是否可以基于xarray.DataArray创建一个pytree (用于JAX)
关键问题似乎是,data属性的xarray.DataArray被构造为输入数组的numpy.ndarray视图(如DeviceArray)。
import jax.numpy as jnp
import xarray as xr
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
>>> type(da.data)
numpy.ndarray将DataArray扁平化为pytree是相对直接的(除了data之外,所有属性都是辅助的),但是我不知道如何检索DeviceArray,甚至不知道如何分配(这里不能使用at[.].set(.) )。
构建容器类(带有DeviceArray成员)的替代方法要求手动实现xarray的所有相关功能。对于一个单一的特性,例如标记索引,这是可能的,但也是多余的。
发布于 2022-04-21 22:58:10
您可以按照PyTree中的示例将xarray类型注册为自定义文档节点,从而实现您想做的事情。
例如,它可能如下所示:
import jax.numpy as jnp
from jax import tree_util
import xarray as xr
tree_util.register_pytree_node(
xr.DataArray,
lambda x: ((x.data,), {"dims": x.dims, "coords": x.coords}),
lambda kwds, args: xr.DataArray(*args, **kwds)
)
da = xr.DataArray(
data=jnp.array([2.0,3.0]),
dims=("var"),
coords={"var": ["A","B"]},
)
print(da)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'
data, tree = tree_util.tree_flatten(da)
print(data)
# [array([2., 3.], dtype=float32)]
da_reconstructed = tree_util.tree_unflatten(tree, data)
print(da_reconstructed)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
# * var (var) <U1 'A' 'B'我不认为在任何最简单的情况下,JAX转换都不会像预期的那样工作:例如,JAX转换仅限于功能性的、不受影响的代码,JAX数组是不可变的。xarray的操作通常违反这两个约束。
另一个问题:您必须小心地定义扁平函数和非平坦函数,以确保所有数据和元数据都被正确序列化,如果希望在任何JAX函数中使用这些数据和元数据,请注意,您可能会遇到xarray对输入的验证问题;有关更多信息,请参见自定义PyTrees和初始化。
https://stackoverflow.com/questions/71949459
复制相似问题