我有一个包含jax数组的亚麻结构数据集。
当我选择转储这个对象并再次加载它时,该数组不再是jax numpy数组,而是转换为numpy数组,下面是再现它的代码:
import flax
import jax.numpy as jnp
import pickle
@flax.struct.dataclass
class A:
data: jnp.ndarray
a = A(data=jnp.zeros((2,2)))
print(a, type(a.data))
with open('file.pickle', 'wb') as handle:
pickle.dump(a, handle)
with open('file.pickle', 'rb') as handle:
loaded_a = pickle.load(handle)
print(loaded_a, type(loaded_a.data))我不想这种行为,我想让它保持原来的类型,有可能吗?
发布于 2022-05-10 19:40:03
更新:此bug已在中修复。从JAX的下一个版本(v.0.3.14)开始,pickle和deepcopy不应该再将JAX数组转换为设备数组。
这是JAX中已知的行为,请参阅https://github.com/google/jax/issues/2632
库开发人员认为这是一种不幸的行为,但尚未对修复进行优先排序。如果你有兴趣的话,你可以考虑一下这个问题。
https://stackoverflow.com/questions/72191025
复制相似问题