首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >JAX梯度积累

JAX梯度积累
EN

Stack Overflow用户
提问于 2021-06-17 09:17:50
回答 1查看 619关注 0票数 2

我制作了一个简单的脚本来尝试使用JAX进行梯度积累。这样做的目的是让大批处理大小(例如64)被分割成适合GPU内存的小块(例如4)。对于每个chunck,存储在pytree中的结果梯度被添加到当前批处理梯度中。只有在计算大批的所有块时才会进行更新。在这个特殊的例子中,我们只是尝试将随机512维向量拟合到具有线性层的随机布尔函数中。下面是脚本:

代码语言:javascript
复制
import jax
import jax.numpy as jnp
from jax import jit, random
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal, zeros
from typing import Callable
from dataclasses import dataclass

@dataclass
class Jax_model:
    init_fun: Callable
    apply_fun: Callable


def Dense(input_size: int, output_size: int, init_kernel=normal(), init_bias=zeros):

    def init_fun(key):
        key, sub_key1, sub_key2 = jax.random.split(key, 3)
        params = {
            'I': init_kernel(sub_key1, (input_size, output_size) ),
            'I_b': init_bias(sub_key2, (1,output_size) ),
        }
        return params

    def apply_fun(params, inputs):
        I, I_b, = params['I'], params['I_b']
        logits = inputs @ I + I_b
        return logits

    return Jax_model(init_fun, apply_fun)


def divide_pytree(pytree, div):
    for pt in jax.tree_util.tree_leaves(pytree):
        pt = pt / div
    return pytree


def add_pytrees(pytree1, pytree2):
    for pt1, pt2 in zip( jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2) ):
        pt1 = pt1 + pt2
    return pytree1


rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50

model = Dense(model_dim, 1)
rng_key, sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)

@jit
def update(i, current_opt_state, current_batch):
    N = current_batch[0].shape[0]
    K = accumulation_size
    num_gradients = N//K
    accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
    value, grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
    value = value / num_gradients
    grads = divide_pytree(grads, num_gradients)
    for k in range(K,N,K):
        accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
        new_value, new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
        value = value + (new_value / num_gradients)
        grads = add_pytrees(grads, divide_pytree(new_grads, num_gradients))
    return opt_update(i, grads, current_opt_state), value

def loss_func(current_params, current_batch):
    inputs, labels = current_batch
    predictions = model.apply_fun(current_params, inputs)
    loss = jnp.square(labels-predictions).sum()
    return loss

for i in range(n_iter):
    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)
    inputs = jax.random.uniform(sub_key1, (batch_size, model_dim))
    labels = jax.random.uniform(sub_key2, (batch_size, 1)) > 0.5
    batch = inputs, labels
    opt_state, batch_loss = update(i, opt_state, batch)
    print(i, batch_loss)

我对divide_pytreeadd_pytrees有疑问。它是否实际上修改了当前的批处理梯度,还是我遗漏了什么?此外,您看到这个代码的速度问题吗?特别是,我应该使用jax.lax.fori_loop来代替传统的python循环吗?

相关链接:

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-06-17 17:03:35

关于pytree计算:正如编写的那样,您的函数将返回未经修改的输入。更好的方法是使用jax.tree_util.tree_map;例如:

代码语言:javascript
复制
from jax.tree_util import tree_map

def divide_pytree(pytree, div):
  return tree_map(lambda pt: pt / div, pytree)

def add_pytrees(pytree1, pytree2):
  return tree_map(lambda pt1, pt2: pt1 + pt2, pytree1, pytree2)

关于性能:在JIT编译时,for循环中的任何内容都将被扁平化,每个循环的每个迭代都有一个所有XLA指令的重复副本。如果您有5个迭代,这实际上不是一个问题。如果您有5000,这将大大降低编译时间(因为XLA需要分析和优化循环中的5000条指令的显式副本)。

fori_loop可以提供帮助,但不会导致最佳代码,特别是在CPU和GPU上运行时。

最好是在可能的情况下使用广播或vmapped操作来表示循环的逻辑,而不需要显式循环。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68016425

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档