首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么我的JAX + Haiku代码不能在GPU上运行?

为什么我的JAX + Haiku代码不能在GPU上运行?
EN

Stack Overflow用户
提问于 2021-11-11 03:37:13
回答 2查看 266关注 0票数 1

下午好,

我刚开始学习JAX和Haiku,但是我不能在GPU上运行我的代码。我在激活了GPU的Google Colab和Kaggle笔记本上运行了我的代码,但这比禁用GPU需要更多的时间。

此外,当我查看GPU指标时,我发现我只使用了1%的计算能力,但使用了90%的GPU内存。

下面是我的代码(MNIST的MLP):

代码语言:javascript
复制
def mlp(images):
  model = hk.Sequential([hk.Linear(128),
                      jax.nn.relu,
                      hk.Linear(64),
                      jax.nn.relu,
                      hk.Linear(10),
                      jax.nn.log_softmax])
  return model(images)

def loss(params, model, images, labels):
  logits = model.apply(params = params, images = images)
  labels = jax.nn.one_hot(labels, num_classes = 10)
  cross_entropy_loss = -jnp.sum(labels*logits)/len(labels)
  return cross_entropy_loss
代码语言:javascript
复制
# Initializing the MLP model
mlp = hk.without_apply_rng(hk.transform(mlp))
params = mlp.init(rng = jax.random.PRNGKey(0),
                  images = next(iter(train_loader))[0])

# Initializing the optimizer
opt = optax.adam(1e-4)
opt_state = opt.init(params = params)
代码语言:javascript
复制
@jax.jit
def update(params, opt_state, images, labels):
  grads = jax.grad(loss)(params,mlp,images,labels)
  updates, opt_state = opt.update(grads, opt_state)
  return optax.apply_updates(params, updates), opt_state

def train(params, opt_state, epochs):
  for epoch in range(epochs):
    for batch_idx, (images, labels) in enumerate(train_loader):
      if batch_idx == 0:
        print(f"Epoch {epoch} : loss = {loss(params,mlp,images,labels)}")
      params, opt_state = update(params, opt_state, images,labels)

%time train(params, opt_state, epochs = 10)

如果你知道我做错了什么,你会帮我很多。谢谢。

EN

回答 2

Stack Overflow用户

发布于 2021-11-11 12:43:56

这个问题很难回答,因为不清楚epochstrain_loader包含什么。但一般的回答是:

  • 默认情况下,JAX将总是在启动时预先分配90%的GPU内存(参见GPU Memory Allocation),因此这并不表示您的计算有多少内存。for循环等控制流将在您的CPU上执行,将内部计算逐一分派到GPU。除非程序遇到阻塞调用,例如打印计算结果,否则这种调度将尽可能是异步的(请参见Asynchronous Dispatch)。

考虑到这些事实,我怀疑你的代码运行缓慢并且不会使图形处理器饱和的原因是因为每个update操作都是非常小的计算,因此每个循环中的调度开销占主导地位。通常,这种分派开销是由设备传输引起的(即,如果epochstrain_loader的内容还没有存在于GPU中)。由于异步分派,如果您避免阻塞调用,例如在循环中间打印损失函数,则此分派开销的累积影响可能不是什么问题。一种更好的解决方案可能是将循环推入XLA (如果循环的数量很少,则通过即时编译整个训练过程;如果循环的数量很大,则使用lax control flow ),但这取决于epochstrain_loader的大小以及数据是存在于设备上还是需要传输。

票数 0
EN

Stack Overflow用户

发布于 2021-12-01 12:13:00

几天前我遇到了这个问题,我的进程在cpu上运行,但使用了太多的GPU内存。原因可能是你安装了只有cpu的false jax和jaxlib版本,你可以通过安装如下的gpu版本来解决这个问题:

代码语言:javascript
复制
   pip install --upgrade jax==0.2.3 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

注意!您最好检查您的服务器/计算机的cuda驱动程序版本,此外,您可以浏览https://github.com/google/jax

你会知道更多的细节

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69922894

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档