首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用于xarray的JAX pytree

用于xarray的JAX pytree
EN

Stack Overflow用户
提问于 2022-04-21 05:43:51
回答 1查看 361关注 0票数 2

是否可以基于xarray.DataArray创建一个pytree (用于JAX)

关键问题似乎是,data属性的xarray.DataArray被构造为输入数组的numpy.ndarray视图(如DeviceArray)。

代码语言:javascript
复制
import jax.numpy as jnp
import xarray as xr

da = xr.DataArray(
    data=jnp.array([2.0,3.0]),
    dims=("var"),
    coords={"var": ["A","B"]},
)

>>> type(da.data)

    numpy.ndarray

DataArray扁平化为pytree是相对直接的(除了data之外,所有属性都是辅助的),但是我不知道如何检索DeviceArray,甚至不知道如何分配(这里不能使用at[.].set(.) )。

构建容器类(带有DeviceArray成员)的替代方法要求手动实现xarray的所有相关功能。对于一个单一的特性,例如标记索引,这是可能的,但也是多余的。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-04-21 22:58:10

您可以按照PyTree中的示例将xarray类型注册为自定义文档节点,从而实现您想做的事情。

例如,它可能如下所示:

代码语言:javascript
复制
import jax.numpy as jnp
from jax import tree_util
import xarray as xr

tree_util.register_pytree_node(
    xr.DataArray,
    lambda x: ((x.data,), {"dims": x.dims, "coords": x.coords}),
    lambda kwds, args: xr.DataArray(*args, **kwds)
)

da = xr.DataArray(
    data=jnp.array([2.0,3.0]),
    dims=("var"),
    coords={"var": ["A","B"]},
)
print(da)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
#   * var      (var) <U1 'A' 'B'

data, tree = tree_util.tree_flatten(da)
print(data)
# [array([2., 3.], dtype=float32)]

da_reconstructed = tree_util.tree_unflatten(tree, data)
print(da_reconstructed)
# <xarray.DataArray (var: 2)>
# array([2., 3.], dtype=float32)
# Coordinates:
#   * var      (var) <U1 'A' 'B'

我不认为在任何最简单的情况下,JAX转换都不会像预期的那样工作:例如,JAX转换仅限于功能性的、不受影响的代码,JAX数组是不可变的。xarray的操作通常违反这两个约束。

另一个问题:您必须小心地定义扁平函数和非平坦函数,以确保所有数据和元数据都被正确序列化,如果希望在任何JAX函数中使用这些数据和元数据,请注意,您可能会遇到xarray对输入的验证问题;有关更多信息,请参见自定义PyTrees和初始化

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71949459

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档