我最近在Jax中实现了一个两层的GRU网络,但对它的性能感到失望(它无法使用)。
所以,我尝试了一下Pytorch的速度。
最小工作示例
这是我的最小工作示例，输出是在Google Colab上使用GPU运行时生成的（Colab notebook 链接见文末）。
import flax.linen as jnn
import jax
import torch
import torch.nn as tnn
import numpy as np
import jax.numpy as jnp
def keyGen(seed):
    """Yield an endless stream of fresh JAX PRNG subkeys derived from `seed`.

    Each `next()` splits the running carry key and hands back the new
    subkey, so no key is ever reused.
    """
    carry = jax.random.PRNGKey(seed)
    while True:
        carry, subkey = jax.random.split(carry)
        yield subkey
# Reproducible stream of PRNG subkeys used for all random initialisation below.
key = keyGen(1)
# Experiment hyper-parameters (shared by the JAX and the PyTorch model).
hidden_size=200  # width of each GRU hidden state
seq_length = 1000  # number of time-steps unrolled per forward pass
in_features = 6  # input feature dimension per time-step
out_features = 4  # output dimension of the final Dense/Linear head
batch_size = 8  # number of sequences per batch
class RNN_jax(jnn.Module):
    """Two stacked GRU cells followed by a Dense(4) head, for one time-step.

    NOTE: the output is divided by ``jnp.linalg.norm(out)`` — the norm of
    the WHOLE (batch, 4) array, not a per-vector norm.
    """

    @jnn.compact
    def __call__(self, x, carry_gru1, carry_gru2):
        # Submodule creation order matters under nn.compact; kept identical.
        carry_gru1, hidden = jnn.GRUCell()(carry_gru1, x)
        carry_gru2, hidden = jnn.GRUCell()(carry_gru2, hidden)
        out = jnn.Dense(4)(hidden)
        out = out / jnp.linalg.norm(out)
        return out, carry_gru1, carry_gru2
class RNN_torch(tnn.Module):
    """2-layer GRU followed by a linear head; each output vector is
    L2-normalized along the feature dimension (dim=-1).

    NOTE(review): this normalizes per time-step/batch vector, whereas the
    JAX model in this file divides by the norm of its whole output array —
    that discrepancy is present in the original code as well.
    """

    def __init__(self, batch_size, hidden_size, in_features, out_features):
        super().__init__()
        self.gru = tnn.GRU(
            input_size=in_features,
            hidden_size=hidden_size,
            num_layers=2,
        )
        self.dense = tnn.Linear(hidden_size, out_features)
        # Zero initial hidden state, shape (num_layers, batch, hidden).
        # NOTE(review): kept as a plain attribute (not a registered buffer)
        # for compatibility — it will not follow .to(device)/.cuda().
        self.init_carry = torch.zeros((2, batch_size, hidden_size))

    def forward(self, X):
        # X: (seq_length, batch_size, in_features)
        X, _ = self.gru(X, self.init_carry)
        X = self.dense(X)
        # Bug fix: the original `.unsqueeze(-1).repeat((1, 1, 4))` hard-coded
        # out_features == 4 (crashes for any other value) and materialized an
        # unnecessary copy; keepdim + broadcasting is numerically identical
        # and works for any out_features.
        return X / X.norm(dim=-1, keepdim=True)
# Instantiate both models.
rnn_jax = RNN_jax()
rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)
# Random data of shape (seq_length, batch_size, features).
Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
# Convert the same arrays to torch tensors so both models see identical inputs.
Xt = torch.from_numpy(np.array(Xj))
Yt = torch.from_numpy(np.array(Yj))
# Zero initial hidden state for each GRU layer of the JAX model.
initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))
# Initialize flax parameters from a single time-step of input.
params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)
def forward(params, X):
    """Unroll the 2-layer GRU over the whole sequence X.

    X has shape (seq_length, batch_size, in_features); returns the stacked
    per-step outputs with shape (seq_length, batch_size, out_features),
    mirroring the torch model's output layout.

    Reads the module-level `rnn_jax`, `initial_carry_gru1` and
    `initial_carry_gru2` globals.
    """
    carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2
    Yhat = []
    for x in X:  # x.shape = (batch_size, in_features)
        yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
        Yhat.append(yhat)  # yhat.shape = (batch_size, out_features)
    # Bug fix: the original return was commented out (and referenced an
    # undefined name `Y`), so the function silently returned None.
    return jnp.stack(Yhat)
jitted_forward = jax.jit(forward)

结果
# uncompiled jax version
%time forward(params, Xj)CPU times: user 7min 17s, sys: 8.18 s, total: 7min 25s Wall time: 7min 17s
# time for compiling
%time jitted_forward(params, Xj)CPU times: user 8min 9s, sys: 4.46 s, total: 8min 13s Wall time: 8min 12s
# compiled jax version
%timeit jitted_forward(params, Xj)The slowest run took 204.20 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 5: 115 µs per loop
# torch version
%timeit lambda: rnn_torch(Xt)
10000000 loops, best of 5: 65.7 ns per loop
（注意：这里 %timeit 计时的只是创建 lambda 对象本身，rnn_torch(Xt) 并没有被真正调用——65.7 ns 并不是 PyTorch 前向传播的耗时，与上面的 JAX 结果不可直接比较。）
问题
为什么我的Jax实现这么慢?我做错了什么?
另外,为什么编译花了这么长时间?序列并没有那么长..
谢谢您:)
发布于 2021-10-29 13:40:07
JAX代码编译缓慢的原因是在JIT编译期间，JAX会把Python循环完全展开。因此，就XLA编译而言，您的函数实际上非常大：您调用了rnn_jax.apply() 1000次，而编译时间通常与语句数量大约成二次方关系增长。
相比之下,您的pytorch函数不使用Python循环,因此在幕后它依赖于运行速度更快的矢量化操作。
在Python语言中使用for循环处理数据时,很可能代码会很慢:无论您使用的是JAX、torch、numpy、pandas等,都是如此。我建议在JAX中找到一种依赖于矢量化操作而不是依赖于缓慢的Python循环的方法来解决这个问题。
https://stackoverflow.com/questions/69767707
复制相似问题