下午好,
我刚开始学习JAX和Haiku,但是我不能在GPU上运行我的代码。我在激活了GPU的Google Colab和Kaggle笔记本上运行了我的代码,但这比禁用GPU需要更多的时间。
此外,当我查看GPU指标时,我发现我只使用了1%的计算能力,但使用了90%的GPU内存。
下面是我的代码(MNIST的MLP):
def mlp(images):
model = hk.Sequential([hk.Linear(128),
jax.nn.relu,
hk.Linear(64),
jax.nn.relu,
hk.Linear(10),
jax.nn.log_softmax])
return model(images)
def loss(params, model, images, labels):
logits = model.apply(params = params, images = images)
labels = jax.nn.one_hot(labels, num_classes = 10)
cross_entropy_loss = -jnp.sum(labels*logits)/len(labels)
return cross_entropy_loss# Initializing the MLP model
mlp = hk.without_apply_rng(hk.transform(mlp))
params = mlp.init(rng = jax.random.PRNGKey(0),
images = next(iter(train_loader))[0])
# Initializing the optimizer
opt = optax.adam(1e-4)
opt_state = opt.init(params = params)@jax.jit
def update(params, opt_state, images, labels):
grads = jax.grad(loss)(params,mlp,images,labels)
updates, opt_state = opt.update(grads, opt_state)
return optax.apply_updates(params, updates), opt_state
def train(params, opt_state, epochs):
for epoch in range(epochs):
for batch_idx, (images, labels) in enumerate(train_loader):
if batch_idx == 0:
print(f"Epoch {epoch} : loss = {loss(params,mlp,images,labels)}")
params, opt_state = update(params, opt_state, images,labels)
%time train(params, opt_state, epochs = 10)如果你知道我做错了什么,你会帮我很多。谢谢。
发布于 2021-11-11 12:43:56
这个问题很难回答,因为不清楚epochs或train_loader包含什么。但一般的回答是:
for循环等控制流将在您的CPU上执行,将内部计算逐一分派到GPU。除非程序遇到阻塞调用,例如打印计算结果,否则这种调度将尽可能是异步的(请参见Asynchronous Dispatch)。考虑到这些事实,我怀疑你的代码运行缓慢并且不会使图形处理器饱和的原因是因为每个update操作都是非常小的计算,因此每个循环中的调度开销占主导地位。通常,这种分派开销是由设备传输引起的(即,如果epochs或train_loader的内容还没有存在于GPU中)。由于异步分派,如果您避免阻塞调用,例如在循环中间打印损失函数,则此分派开销的累积影响可能不是什么问题。一种更好的解决方案可能是将循环推入XLA (如果循环的数量很少,则通过即时编译整个训练过程;如果循环的数量很大,则使用lax control flow ),但这取决于epochs和train_loader的大小以及数据是存在于设备上还是需要传输。
发布于 2021-12-01 12:13:00
几天前我遇到了这个问题,我的进程在cpu上运行,但使用了太多的GPU内存。原因可能是你安装了只有cpu的false jax和jaxlib版本,你可以通过安装如下的gpu版本来解决这个问题:
pip install --upgrade jax==0.2.3 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html注意!您最好检查您的服务器/计算机的cuda驱动程序版本,此外,您可以浏览https://github.com/google/jax
你会知道更多的细节
https://stackoverflow.com/questions/69922894
复制相似问题