首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我正在尝试将JAX对象分配给需要具体值的NumPy数组--请按需要工作。

我正在尝试将JAX对象分配给需要具体值的NumPy数组--请按需要工作。
EN

Stack Overflow用户
提问于 2022-08-23 20:24:39
回答 1查看 501关注 0票数 -1

我对贾克斯很陌生。

我正在实现一个变分自动编码器(VAE)使用Jax和亚麻。在培训期间,我取样了一个潜在的代码(从编码器推断的分布,我使用flax.linen.nn模块的组合来实现)。关键的是,除了通过解码器传递这些代码(这是VAE的标准代码)之外,我还将代码传递给外部函数( MuJoCo物理引擎),该函数试图将其分配给NumPy数组。这并不令人意外地导致以下错误:

TracerArrayConversionError: JAX对象上调用的numpy.ndarray转换方法数组().

基本上,我需要向MuJoCo传递一个具体的numpy数组。如何使我的变量成为一个NumPy数组,以便尽可能使用抽象跟踪器以计算效率高的方式实现我的模型?

下面是我所面临问题的一个最小的工作示例--我相信需要安装mujoco (https://mujoco.org/)才能运行这个问题:

代码语言:javascript
复制
import jax
import jax.numpy as np
import numpy as onp
import gym
from jax import jit

# create an instance of an open AI gym environment
env = gym.make('Humanoid-v3')
env.reset()

def this_fails(env, x):
    
    # this gives a TracerArrayConversionError
    env.sim.data.qpos[:] = x

    return env, x

x = np.arange(len(env.sim.data.qpos))
jit_this_fails = jax.jit(this_fails, static_argnums = 0)
env, x = jit_this_fails(env, x)
EN

回答 1

Stack Overflow用户

发布于 2022-08-23 21:00:54

注:这是对“任择议定书”最初所写问题的回答。这个问题已经被编辑了多次,不再问它最初问了什么。

在过去,这类东西不被支持,但是您可以使用JAXVersion0.3.17中的新jax.pure_callback特性来实现这一点,在我编写这个版本时还没有发布这个特性。

例如,假设您希望从JAX jit编译函数中调用基于numpy的函数;为了简单起见,我们将使用np.sin。您可以先尝试这样的方法:

代码语言:javascript
复制
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def this_fails(x):
  # Call a numpy function...
  return np.sin(x)

x = jnp.arange(5.0)
this_fails(x)
代码语言:javascript
复制
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function this_fails at tmp.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

结果是一个TracerConversionError,因为您试图将一个跟踪的JAX值传递到一个函数中,这个函数需要一个numpy数组(请注意:有关JAX和相关主题的介绍,请参见如何在JAX中思考 )。

在JAXVersion0.3.17或更高版本中,您可以使用jax.pure_callback解决这个问题

代码语言:javascript
复制
@jax.jit
def numpy_callback(x):
  # Need to forward-declare the shape & dtype of the expected output.
  result_shape = jax.core.ShapedArray(x.shape, x.dtype)
  return jax.pure_callback(np.sin, result_shape, x)

x = jnp.arange(5.0)
print(numpy_callback(x))
代码语言:javascript
复制
[ 0.         0.841471   0.9092974  0.14112   -0.7568025]

请记住几个注意事项:

  • 结果执行将依赖于对主机的回调,因此在GPU/TPU这样的加速器上,特别是在分布式/多主机设置中,它将非常慢。但是,在本地CPU执行的情况下,它避免了缓冲区副本,并且可以很好地执行。
  • 如果vmap函数,它将导致多个回调的for循环(如果回调函数本地处理批处理,则可以指定vectorized=True )。
  • gradjacobian这样的自定义转换不适用于这个函数,因为JAX无法对所做的计算进行推理。如果您想要将它与自定义转换一起使用,您可以像在自定义衍生规则中那样定义自定义梯度,尽管这需要访问一个计算回调函数梯度的函数。

所有这些都还没有在JAX网站上记录下来,但我们希望尽快为pure_callback编写文档!

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73464697

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档