首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >jax中的泡菜更改类型

jax中的泡菜更改类型
EN

Stack Overflow用户
提问于 2022-05-10 17:58:25
回答 1查看 127关注 0票数 1

我有一个包含jax数组的亚麻结构数据集。

当我选择转储这个对象并再次加载它时,该数组不再是jax numpy数组,而是转换为numpy数组,下面是再现它的代码:

代码语言:javascript
复制
import flax
import jax.numpy as jnp
import pickle

@flax.struct.dataclass
class A:
    data: jnp.ndarray

a = A(data=jnp.zeros((2,2)))
print(a, type(a.data))



with open('file.pickle', 'wb') as handle:
    pickle.dump(a, handle)
    
with open('file.pickle', 'rb') as handle:
    loaded_a = pickle.load(handle)

print(loaded_a, type(loaded_a.data))

我不想这种行为,我想让它保持原来的类型,有可能吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-05-10 19:40:03

更新:此bug已在中修复。从JAX的下一个版本(v.0.3.14)开始,pickledeepcopy不应该再将JAX数组转换为设备数组。

这是JAX中已知的行为,请参阅https://github.com/google/jax/issues/2632

库开发人员认为这是一种不幸的行为,但尚未对修复进行优先排序。如果你有兴趣的话,你可以考虑一下这个问题。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72191025

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档