首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >JAX -JAX函数:参数与“全局”变量

JAX -JAX函数:参数与“全局”变量
EN

Stack Overflow用户
提问于 2022-09-06 11:27:50
回答 1查看 169关注 0票数 1

我对贾克斯有以下疑问。我将使用官方税务文件中的一个例子来说明这一点:

代码语言:javascript
复制
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)

在本例中,函数step使用变量optimizer,尽管它没有在函数参数中传递(因为函数被抛出,optax.GradientTransformation不是受支持的类型)。但是,相同的函数使用其他变量作为参数传递(即params, opt_state, batch, labels)。我知道jax函数需要是纯的,这样才能被抛出,但是输入(只读)变量呢?如果我通过函数参数访问一个变量,或者直接访问它,因为它在step函数范围内,这有什么区别吗?如果这个变量不是常量,而是在不同的step调用之间被修改了呢?如果直接访问,它们是否被视为静态参数?还是干脆就放弃了,这样就不会考虑修改这些参数了?

具体来说,让我们看一下下面的示例:

代码语言:javascript
复制
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    updates *= extra_learning_rate # not really valid code, but you get the idea
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    extra_learning_rate = 0.01 # does this affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels)

  return params

vs

代码语言:javascript
复制
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels, extra_lr):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    updates *= extra_lr # not really valid code, but you get the idea
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
    extra_learning_rate = 0.01 # does this now affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)

  return params

从我有限的实验中,他们的表现不同,因为第二次step调用没有在全球范围内使用新的学习速率,也没有“重新退出”发生,但我想知道是否有任何标准的实践/规则我需要知道。我正在编写一个库,其中性能是基本的,我不想错过一些jit优化,因为我做错了事情。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-09-06 17:59:54

在JIT跟踪期间,JAX将全局值视为跟踪函数的隐式参数。您可以在表示函数的贾克斯普拉中看到这一点。

下面是两个返回等价结果的简单函数,一个带有隐式参数,另一个带有显式:

代码语言:javascript
复制
import jax
import jax.numpy as jnp

def f_explicit(a, b):
  return a + b

def f_implicit(b):
  return a_global + b

a_global = jnp.arange(5.0)
b = jnp.ones(5)

print(jax.make_jaxpr(f_explicit)(a_global, b))
# { lambda ; a:f32[5] b:f32[5]. let c:f32[5] = add a b in (c,) }

print(jax.make_jaxpr(f_implicit)(b))
# { lambda a:f32[5]; b:f32[5]. let c:f32[5] = add a b in (c,) }

注意,两个jaxprs之间唯一的区别是在f_implicit中,a变量出现在分号之前:这是jaxpr表示通过闭包而不是通过显式参数传递参数的方式。但是这两个函数产生的计算是相同的。

尽管如此,需要注意的一个不同之处是,当通过闭包传递的参数是一个可理解的常量时,它将被视为跟踪函数中的静态参数(类似于当显式参数通过static_argnumsstatic_argnamesjax.jit中标记为静态时):

代码语言:javascript
复制
a_global = 1.0
print(jax.make_jaxpr(f_implicit)(b))
# { lambda ; a:f32[5]. let b:f32[5] = add 1.0 a in (b,) }

注意,在jaxpr表示中,常量值是作为add操作的参数直接插入的。为JIT编译的函数获得相同结果的显式方法如下所示:

代码语言:javascript
复制
from functools import partial

@partial(jax.jit, static_argnames=['a'])
def f_explicit(a, b):
  return a + b
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73621269

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档