我对贾克斯有以下疑问。我将使用官方税务文件中的一个例子来说明这一点:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
opt_state = optimizer.init(params)
@jax.jit
def step(params, opt_state, batch, labels):
loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss_value
for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
params, opt_state, loss_value = step(params, opt_state, batch, labels)
if i % 100 == 0:
print(f'step {i}, loss: {loss_value}')
return params
# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)在本例中,函数step使用变量optimizer,尽管它没有在函数参数中传递(因为函数被抛出,optax.GradientTransformation不是受支持的类型)。但是,相同的函数使用其他变量作为参数传递(即params, opt_state, batch, labels)。我知道jax函数需要是纯的,这样才能被抛出,但是输入(只读)变量呢?如果我通过函数参数访问一个变量,或者直接访问它,因为它在step函数范围内,这有什么区别吗?如果这个变量不是常量,而是在不同的step调用之间被修改了呢?如果直接访问,它们是否被视为静态参数?还是干脆就放弃了,这样就不会考虑修改这些参数了?
具体来说,让我们看一下下面的示例:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
opt_state = optimizer.init(params)
extra_learning_rate = 0.1
@jax.jit
def step(params, opt_state, batch, labels):
loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
updates, opt_state = optimizer.update(grads, opt_state, params)
updates *= extra_learning_rate # not really valid code, but you get the idea
params = optax.apply_updates(params, updates)
return params, opt_state, loss_value
for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
extra_learning_rate = 0.1
params, opt_state, loss_value = step(params, opt_state, batch, labels)
extra_learning_rate = 0.01 # does this affect the next `step` call?
params, opt_state, loss_value = step(params, opt_state, batch, labels)
return paramsvs
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
opt_state = optimizer.init(params)
extra_learning_rate = 0.1
@jax.jit
def step(params, opt_state, batch, labels, extra_lr):
loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
updates, opt_state = optimizer.update(grads, opt_state, params)
updates *= extra_lr # not really valid code, but you get the idea
params = optax.apply_updates(params, updates)
return params, opt_state, loss_value
for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
extra_learning_rate = 0.1
params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
extra_learning_rate = 0.01 # does this now affect the next `step` call?
params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
return params从我有限的实验中,他们的表现不同,因为第二次step调用没有在全球范围内使用新的学习速率,也没有“重新退出”发生,但我想知道是否有任何标准的实践/规则我需要知道。我正在编写一个库,其中性能是基本的,我不想错过一些jit优化,因为我做错了事情。
发布于 2022-09-06 17:59:54
在JIT跟踪期间,JAX将全局值视为跟踪函数的隐式参数。您可以在表示函数的贾克斯普拉中看到这一点。
下面是两个返回等价结果的简单函数,一个带有隐式参数,另一个带有显式:
import jax
import jax.numpy as jnp
def f_explicit(a, b):
return a + b
def f_implicit(b):
return a_global + b
a_global = jnp.arange(5.0)
b = jnp.ones(5)
print(jax.make_jaxpr(f_explicit)(a_global, b))
# { lambda ; a:f32[5] b:f32[5]. let c:f32[5] = add a b in (c,) }
print(jax.make_jaxpr(f_implicit)(b))
# { lambda a:f32[5]; b:f32[5]. let c:f32[5] = add a b in (c,) }注意,两个jaxprs之间唯一的区别是在f_implicit中,a变量出现在分号之前:这是jaxpr表示通过闭包而不是通过显式参数传递参数的方式。但是这两个函数产生的计算是相同的。
尽管如此,需要注意的一个不同之处是,当通过闭包传递的参数是一个可理解的常量时,它将被视为跟踪函数中的静态参数(类似于当显式参数通过static_argnums或static_argnames在jax.jit中标记为静态时):
a_global = 1.0
print(jax.make_jaxpr(f_implicit)(b))
# { lambda ; a:f32[5]. let b:f32[5] = add 1.0 a in (b,) }注意,在jaxpr表示中,常量值是作为add操作的参数直接插入的。为JIT编译的函数获得相同结果的显式方法如下所示:
from functools import partial
@partial(jax.jit, static_argnames=['a'])
def f_explicit(a, b):
return a + bhttps://stackoverflow.com/questions/73621269
复制相似问题