I made a simple script to try out gradient accumulation with JAX. The idea is to split a large batch size (e.g. 64) into small chunks (e.g. 4) that fit in GPU memory. For each chunk, the resulting gradients, stored in a pytree, are added to the current batch gradients. The update is performed only once all the chunks of the large batch have been computed. In this particular example, we simply try to fit random 512-dimensional vectors to random booleans with a linear layer. Here is the script:
import jax
import jax.numpy as jnp
from jax import jit, random
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal, zeros
from typing import Callable
from dataclasses import dataclass

@dataclass
class Jax_model:
    init_fun: Callable
    apply_fun: Callable

def Dense(input_size: int, output_size: int, init_kernel=normal(), init_bias=zeros):

    def init_fun(key):
        key, sub_key1, sub_key2 = jax.random.split(key, 3)
        params = {
            'I': init_kernel(sub_key1, (input_size, output_size)),
            'I_b': init_bias(sub_key2, (1, output_size)),
        }
        return params

    def apply_fun(params, inputs):
        I, I_b = params['I'], params['I_b']
        logits = inputs @ I + I_b
        return logits

    return Jax_model(init_fun, apply_fun)
def divide_pytree(pytree, div):
    for pt in jax.tree_util.tree_leaves(pytree):
        pt = pt / div
    return pytree

def add_pytrees(pytree1, pytree2):
    for pt1, pt2 in zip(jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2)):
        pt1 = pt1 + pt2
    return pytree1
rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50
model = Dense(model_dim, 1)
rng_key, sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)
@jit
def update(i, current_opt_state, current_batch):
    N = current_batch[0].shape[0]
    K = accumulation_size
    num_gradients = N // K
    accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
    value, grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
    value = value / num_gradients
    grads = divide_pytree(grads, num_gradients)
    for k in range(K, N, K):
        accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
        new_value, new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
        value = value + (new_value / num_gradients)
        grads = add_pytrees(grads, divide_pytree(new_grads, num_gradients))
    return opt_update(i, grads, current_opt_state), value
def loss_func(current_params, current_batch):
    inputs, labels = current_batch
    predictions = model.apply_fun(current_params, inputs)
    loss = jnp.square(labels - predictions).sum()
    return loss
for i in range(n_iter):
    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)
    inputs = jax.random.uniform(sub_key1, (batch_size, model_dim))
    labels = jax.random.uniform(sub_key2, (batch_size, 1)) > 0.5
    batch = inputs, labels
    opt_state, batch_loss = update(i, opt_state, batch)
    print(i, batch_loss)

My doubts are about divide_pytree and add_pytrees. Do they actually modify the current batch gradients, or am I missing something? Also, do you see any speed issues with this code? In particular, should I use jax.lax.fori_loop in place of the traditional Python loop?
Posted on 2021-06-17 17:03:35
Regarding the pytree computations: as written, your functions will return the input unmodified (the loop only rebinds a local variable on each iteration; it never touches the leaves of the pytree). A better approach is to use jax.tree_util.tree_map; for example:

from jax.tree_util import tree_map

def divide_pytree(pytree, div):
    return tree_map(lambda pt: pt / div, pytree)

def add_pytrees(pytree1, pytree2):
    return tree_map(lambda pt1, pt2: pt1 + pt2, pytree1, pytree2)

Regarding performance: anything within a for loop is flattened at JIT compile time, with one repeated copy of all the XLA instructions per loop iteration. If you have 5 iterations, that's not really an issue. If you have 5000, it would slow compilation considerably (because XLA would have to analyze and optimize 5000 explicit copies of the instructions in the loop).
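To see concretely why the original loop version changes nothing, here is a minimal check (divide_pytree_loop is an illustrative name for the question's version, and the toy gradient values are made up):

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map

# The question's version: the loop merely rebinds the local name `pt`
# on each iteration, so the pytree's leaves are never changed.
def divide_pytree_loop(pytree, div):
    for pt in tree_leaves(pytree):
        pt = pt / div
    return pytree

# tree_map version: builds and returns a new pytree with divided leaves.
def divide_pytree(pytree, div):
    return tree_map(lambda pt: pt / div, pytree)

# Toy gradients (dict leaves are traversed in key order: 'b', then 'w').
grads = {'w': jnp.array([2.0, 4.0]), 'b': jnp.array([8.0])}

print(tree_leaves(divide_pytree_loop(grads, 2)))  # leaves unchanged: [8.], [2., 4.]
print(tree_leaves(divide_pytree(grads, 2)))       # leaves halved:    [4.], [1., 2.]
```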
fori_loop can help with this, but it leads to suboptimal code, particularly when run on CPU and GPU.
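For completeness, a structured-loop rewrite using jax.lax.scan (a sketch I'm adding, not from the answer; loss_fn and accumulated_grads_scan are illustrative names) compiles the loop body once rather than unrolling it, and processes chunks sequentially, which is what genuine memory-bounded gradient accumulation needs:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

# Same linear model / squared-error loss as in the question.
def loss_fn(params, inputs, labels):
    preds = inputs @ params['I'] + params['I_b']
    return jnp.square(labels - preds).sum()

def accumulated_grads_scan(params, inputs, labels, num_chunks):
    # Split the big batch into num_chunks slices along a new leading axis:
    # (N, dim) -> (num_chunks, N // num_chunks, dim).
    chunked_x = inputs.reshape(num_chunks, -1, inputs.shape[-1])
    chunked_y = labels.reshape(num_chunks, -1, labels.shape[-1])

    def body(acc, chunk):
        x, y = chunk
        g = jax.grad(loss_fn)(params, x, y)
        # Add this chunk's contribution to the running average, leaf by leaf.
        return tree_map(lambda a, b: a + b / num_chunks, acc, g), None

    zeros = tree_map(jnp.zeros_like, params)
    acc, _ = jax.lax.scan(body, zeros, (chunked_x, chunked_y))
    return acc
```

Because the loss is a sum over examples, the accumulated result equals the full-batch gradient divided by the number of chunks, matching what the question's update intends.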
The best option would be to express the logic of the loop in terms of broadcasted or vmapped operations where possible, with no explicit loop at all.
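Following that suggestion, one way to eliminate the Python loop entirely is to reshape the large batch to (num_chunks, chunk_size, ...) and take per-chunk gradients with jax.vmap, then average the resulting pytrees with tree_map. A sketch under the same linear model as the question (loss_fn and accumulated_grads are illustrative names, not part of the original script):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map
from functools import partial

# Same linear model / squared-error loss as in the question.
def loss_fn(params, inputs, labels):
    preds = inputs @ params['I'] + params['I_b']
    return jnp.square(labels - preds).sum()

# num_chunks must be static so the reshape has a concrete shape under jit.
@partial(jax.jit, static_argnums=3)
def accumulated_grads(params, inputs, labels, num_chunks):
    # (N, dim) -> (num_chunks, N // num_chunks, dim)
    chunked_x = inputs.reshape(num_chunks, -1, inputs.shape[-1])
    chunked_y = labels.reshape(num_chunks, -1, labels.shape[-1])
    # Per-chunk gradients, vmapped over the leading chunk axis.
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, chunked_x, chunked_y)
    # Average each leaf over the chunk axis: equivalent to summing
    # per-chunk gradients each divided by num_chunks.
    return tree_map(lambda g: g.mean(axis=0), grads)
```

Note that vmap evaluates all chunks in one vectorized call, so this removes the unrolled loop and its compile-time cost but does not reduce peak memory; when memory is the real constraint, a sequential construct such as jax.lax.scan is the usual alternative.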
https://stackoverflow.com/questions/68016425