我对贾克斯很陌生。
我正在实现一个变分自动编码器(VAE)使用Jax和亚麻。在培训期间,我取样了一个潜在的代码(从编码器推断的分布,我使用flax.linen.nn模块的组合来实现)。关键的是,除了通过解码器传递这些代码(这是VAE的标准代码)之外,我还将代码传递给外部函数( MuJoCo物理引擎),该函数试图将其分配给NumPy数组。这并不令人意外地导致以下错误:
TracerArrayConversionError: JAX对象上调用的numpy.ndarray转换方法数组().
基本上,我需要向MuJoCo传递一个具体的numpy数组。如何使我的变量成为一个NumPy数组,以便尽可能使用抽象跟踪器以计算效率高的方式实现我的模型?
下面是我所面临问题的一个最小的工作示例--我相信需要安装mujoco (https://mujoco.org/)才能运行这个问题:
import jax
import jax.numpy as np
import numpy as onp
import gym
from jax import jit
# create an instance of an open AI gym environment
env = gym.make('Humanoid-v3')
env.reset()
def this_fails(env, x):
# this gives a TracerArrayConversionError
env.sim.data.qpos[:] = x
return env, x
x = np.arange(len(env.sim.data.qpos))
jit_this_fails = jax.jit(this_fails, static_argnums = 0)
env, x = jit_this_fails(env, x)发布于 2022-08-23 21:00:54
注:这是对“任择议定书”最初所写问题的回答。这个问题已经被编辑了多次,不再问它最初问了什么。
在过去,这类东西不被支持,但是您可以使用JAXVersion0.3.17中的新jax.pure_callback特性来实现这一点,在我编写这个版本时还没有发布这个特性。
例如,假设您希望从JAX jit编译函数中调用基于numpy的函数;为了简单起见,我们将使用np.sin。您可以先尝试这样的方法:
import jax
import jax.numpy as jnp
import numpy as np
@jax.jit
def this_fails(x):
# Call a numpy function...
return np.sin(x)
x = jnp.arange(5.0)
this_fails(x)jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function this_fails at tmp.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError结果是一个TracerConversionError,因为您试图将一个跟踪的JAX值传递到一个函数中,这个函数需要一个numpy数组(请注意:有关JAX和相关主题的介绍,请参见如何在JAX中思考 )。
在JAXVersion0.3.17或更高版本中,您可以使用jax.pure_callback解决这个问题
@jax.jit
def numpy_callback(x):
# Need to forward-declare the shape & dtype of the expected output.
result_shape = jax.core.ShapedArray(x.shape, x.dtype)
return jax.pure_callback(np.sin, result_shape, x)
x = jnp.arange(5.0)
print(numpy_callback(x))[ 0. 0.841471 0.9092974 0.14112 -0.7568025]请记住几个注意事项:
vmap函数,它将导致多个回调的for循环(如果回调函数本地处理批处理,则可以指定vectorized=True )。grad和jacobian这样的自定义转换不适用于这个函数,因为JAX无法对所做的计算进行推理。如果您想要将它与自定义转换一起使用,您可以像在自定义衍生规则中那样定义自定义梯度,尽管这需要访问一个计算回调函数梯度的函数。所有这些都还没有在JAX网站上记录下来,但我们希望尽快为pure_callback编写文档!
https://stackoverflow.com/questions/73464697
复制相似问题