Jax/Flax (very) slow RNN forward pass compared to pyTorch?

Stack Overflow user
Asked on 2021-10-29 10:56:32
1 answer · 281 views · 0 followers · 2 votes

I recently implemented a two-layer GRU network in Jax and was disappointed by its performance (it was unusable).

So I had a look at how fast Pytorch is by comparison.

Minimal working example

This is my minimal working example; the output was created on Google Colab with a GPU backend (notebook in colab).

import flax.linen as jnn 
import jax
import torch
import torch.nn as tnn
import numpy as np 
import jax.numpy as jnp

def keyGen(seed):
    key1 = jax.random.PRNGKey(seed)
    while True:
        key1, key2 = jax.random.split(key1)
        yield key2
key = keyGen(1)

hidden_size=200
seq_length = 1000
in_features = 6
out_features = 4
batch_size = 8

class RNN_jax(jnn.Module):

    @jnn.compact
    def __call__(self, x, carry_gru1, carry_gru2):
        carry_gru1, x = jnn.GRUCell()(carry_gru1, x)
        carry_gru2, x = jnn.GRUCell()(carry_gru2, x)
        x = jnn.Dense(4)(x)
        x = x/jnp.linalg.norm(x)
        return x, carry_gru1, carry_gru2

class RNN_torch(tnn.Module):
    def __init__(self, batch_size, hidden_size, in_features, out_features):
        super().__init__()

        self.gru = tnn.GRU(
            input_size=in_features, 
            hidden_size=hidden_size,
            num_layers=2
            )
        
        self.dense = tnn.Linear(hidden_size, out_features)

        self.init_carry = torch.zeros((2, batch_size, hidden_size))

    def forward(self, X):
        X, final_carry = self.gru(X, self.init_carry)
        X = self.dense(X)
        return X/X.norm(dim=-1).unsqueeze(-1).repeat((1, 1, 4))

rnn_jax = RNN_jax()
rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)

Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
Xt = torch.from_numpy(np.array(Xj))
Yt = torch.from_numpy(np.array(Yj))

initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))

params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)

def forward(params, X):
    
    carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2

    Yhat = []
    for x in X: # x.shape = (batch_size, in_features)
        yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
        Yhat.append(yhat) # yhat.shape = (batch_size, out_features)

    #return jnp.concatenate(Yhat, axis=0)

jitted_forward = jax.jit(forward)

Results

# uncompiled jax version
%time forward(params, Xj)

CPU times: user 7min 17s, sys: 8.18 s, total: 7min 25s Wall time: 7min 17s

# time for compiling
%time jitted_forward(params, Xj)

CPU times: user 8min 9s, sys: 4.46 s, total: 8min 13s Wall time: 8min 12s

# compiled jax version
%timeit jitted_forward(params, Xj)

The slowest run took 204.20 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 5: 115 µs per loop

# torch version
%timeit lambda: rnn_torch(Xt)

10000000 loops, best of 5: 65.7 ns per loop

Questions

Why is my Jax implementation so slow? What am I doing wrong?

Also, why does compilation take so long? The sequence isn't that long...

Thank you :)


1 Answer

Stack Overflow user

Accepted answer

Answered on 2021-10-29 13:40:07

The reason your JAX code compiles slowly is that during JIT compilation, JAX unrolls Python loops. So in terms of the XLA compilation, your function is actually very large: you call rnn_jax.apply() 1000 times, and compilation time tends to be roughly quadratic in the number of statements.
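
To see the unrolling directly, here is a small toy illustration (my own sketch, not the question's model): the traced program (jaxpr) contains one equation per loop iteration, and that is the program XLA then has to compile.

import jax
import jax.numpy as jnp

# A Python for-loop is unrolled at trace time, so the jaxpr grows
# linearly with the number of iterations.
def unrolled(x, n):
    for _ in range(n):
        x = jnp.sin(x)
    return x

print(len(jax.make_jaxpr(lambda x: unrolled(x, 5))(1.0).jaxpr.eqns))    # 5 equations
print(len(jax.make_jaxpr(lambda x: unrolled(x, 500))(1.0).jaxpr.eqns))  # 500 equations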

By contrast, your pytorch function uses no Python loop, so under the hood it relies on vectorized operations that run much faster.

In general, any time you process data with a Python for-loop you can expect your code to be slow: this is true whether you're using JAX, torch, numpy, pandas, etc. I would suggest finding an approach to your problem in JAX that relies on vectorized operations rather than on slow Python loops.
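
For example, here is a minimal sketch of such a rewrite (not part of the original answer; it assumes the rnn_jax module, params, and the initial carries defined in the question). jax.lax.scan replaces the Python loop, so the step function is traced and compiled once instead of being unrolled 1000 times:

import jax

def forward_scan(params, X):
    def step(carries, x):  # x.shape = (batch_size, in_features)
        carry_gru1, carry_gru2 = carries
        yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
        return (carry_gru1, carry_gru2), yhat

    # scan over the leading (time) axis of X: (seq_length, batch_size, in_features)
    _, Yhat = jax.lax.scan(step, (initial_carry_gru1, initial_carry_gru2), X)
    return Yhat  # (seq_length, batch_size, out_features)

jitted_forward_scan = jax.jit(forward_scan)

flax also offers a lifted nn.scan transform that expresses the same idea at the module level.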

1 vote
The original page content is provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/69767707