Ragged/inhomogeneous arrays with JAX

Stack Overflow user
Asked 2022-10-17 00:03:03
Answers: 1 | Views: 60 | Followers: 0 | Votes: 0

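What is the recommended way to implement array-like behavior/methods in JAX for ragged/inhomogeneous data (data with one or more uneven dimensions)?

Two principal options come to mind:

  1. Make the data homogeneous (pad it) and use masks
  2. Flatten the data and implement custom methods (i.e. broadcasting and reductions)

Clearly, option 1 is preferable because it requires less implementation overhead (and therefore less validation/testing). Where memory complexity is a concern, and the padded arrays should not be allocated at all, are there better alternatives to option 2 that can still leverage the highly optimized array methods?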
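To make the two options concrete, here is a minimal sketch of a per-row sum in both layouts. The data, the names `lengths`, `padded_np` and `segment_ids`, and the use of `jax.ops.segment_sum` as the reduction for the flat layout are illustrative assumptions, not something prescribed by the question.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical ragged data: three rows with 4, 2 and 1 valid entries.
lengths = np.array([4, 2, 1])
max_len = 4  # static padded width

# Dense backing values (only the masked entries are "real").
padded_np = np.arange(lengths.size * max_len, dtype=np.float32)
padded_np = padded_np.reshape(lengths.size, max_len)
mask = lengths[:, None] > np.arange(max_len)[None, :]

# Option 1: keep the padded array and zero out invalid entries before reducing.
row_sums_masked = jnp.where(mask, jnp.asarray(padded_np), 0.0).sum(axis=1)

# Option 2: store only the valid entries, plus one row id per entry.
# The Boolean selection happens in NumPy, so the flat length is static for JAX.
flat = jnp.asarray(padded_np[mask])
segment_ids = np.repeat(np.arange(lengths.size), lengths)
row_sums_flat = jax.ops.segment_sum(flat, segment_ids, num_segments=lengths.size)
```

Both reductions give the same per-row sums. The padded version stores `lengths.size * max_len` values, while the flat version stores only `lengths.sum()` values plus the integer ids.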

Edit: a concrete example that includes sparsity is implemented below.

import jax as jx
import jax.numpy as jnp
jx.config.update("jax_enable_x64", True)


# Problem specific variables (static)
n_vars = 3 # Number of variable sets
n_smps = 10 # Maximum number of set elements
p_smps = 0.2 # Representation of problem sparsity


# Each set contains a differing number of elements (binomial random for example)
n_lvls = jx.random.bernoulli(
    jx.random.PRNGKey(0),
    p_smps,
    (n_vars, n_smps)
).sum(axis=1, dtype='i4')


# Derived quantities depend on constant coefficients (uniform random for example)
a_vars = jx.random.uniform(jx.random.PRNGKey(1), (n_vars, ), dtype='f8')

b_vars = jx.random.uniform(jx.random.PRNGKey(2), (n_vars, ), dtype='f8')
b_vars = 10.0*b_vars

c_vars = jx.random.uniform(jx.random.PRNGKey(3), (n_vars, ), dtype='f8')
c_vars = 2.0*c_vars

The state of this problem is essentially represented by 7 elements. The following is an implementation of option 1:

### Homogeneous with mask ###

# Define the level index array
i_smps = jnp.arange(n_smps, dtype='i4')
mask = n_lvls[:,None]>i_smps[None,:]

# Generate an initial state that respects the unity axiom
x_vars = 1.0/(1.0+n_lvls[:,None]*i_smps[None,:]).astype('f8')
x_vars = jnp.where(mask, x_vars, 0.0)
x_vars = x_vars/x_vars.sum()


# Generate a coefficient tensor
P_vars = a_vars[:,None]+b_vars[:,None]*i_smps[None,:]
P_vars = jnp.where(mask, P_vars, 0.0)


# Determine a scalar moment
scalar_moment = (x_vars*c_vars[:,None]).sum()
# >>> DeviceArray(0.66574861, dtype=float64)

# Determine a transition tensor
trans_tens = (P_vars[:,:,None,None]-P_vars[None,None,:,:])
trans_tens = trans_tens*x_vars[None,None,:,:]*x_vars[:,:,None,None]
trans_tens.sum(axis=(2,3))
# >>>  DeviceArray([[-0.37032842,  0.16153429,  0.22063015,  0.24335933, ...

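Enforcing homogeneity increases this number to 30. In addition, computing the derived quantities involves a large number of operations on zeros.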
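The counts quoted above can be checked with a few lines of arithmetic. This sketch assumes the set sizes are `[4, 2, 1]`, which is what the index arrays printed in the answer imply; the variable names are illustrative.

```python
import numpy as np

# Assumed set sizes, inferred from the index output shown in the answer.
n_lvls = np.array([4, 2, 1])
n_vars, n_smps = 3, 10

n_active = int(n_lvls.sum())   # elements that carry state
n_padded = n_vars * n_smps     # elements allocated by the padded layout

# The transition tensor pairs every element with every other element,
# so its size grows with the square of the element count.
pairs_active = n_active ** 2
pairs_padded = n_padded ** 2

print(n_active, n_padded)          # 7 30
print(pairs_active, pairs_padded)  # 49 900
```

Under these assumptions the padded transition tensor holds 900 entries, of which only 49 correspond to pairs of real elements.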


1 Answer

Stack Overflow user

Answered 2022-10-18 11:35:49

Here is an approach that uses Boolean indexing to flatten the data and reduce the sparsity.

# Required for static Boolean indexing
import numpy as np
n_lvls = np.asarray(n_lvls, dtype='i4')

### Flatten with Boolean Indexing ###

i_grid = np.arange(n_lvls.size, dtype='i4')
j_grid = np.arange(n_lvls.max(), dtype='i4')

# Determine the boolean mask
mask = n_lvls[:,None]>j_grid[None,:]
bc_size = (i_grid.size, j_grid.size)

i = np.broadcast_to(i_grid[:,None], bc_size)[mask]
# >>> array([0, 0, 0, 0, 1, 1, 2], dtype=int32)

j = np.broadcast_to(j_grid[None,:], bc_size)[mask]
# >>> array([0, 1, 2, 3, 0, 1, 0], dtype=int32)

# Generate an initial state that respects the unity axiom
x = 1.0/(1.0+n_lvls[i]*j)
x = x/x.sum()


# Generate a coefficient tensor
P_vars = a_vars[i]+b_vars[i]*j


# Determine scalar moment
scalar_moment = (x*c_vars[i]).sum()
# >>> DeviceArray(0.66574861, dtype=float64)


# Determine transition tensor
trans_tens = (P_vars[:,None]-P_vars[None,:])
trans_tens = trans_tens*x[None,:]*x[:,None]
trans_tens.sum(axis=1)
# >>> DeviceArray([-0.37032842,  0.16153429,  0.22063015,  0.24335933, ...

There may be a more memory-efficient way than Boolean indexing to determine i and j.
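One possible way to build the index arrays without first materializing the broadcast grids and the Boolean mask is to use `np.repeat` together with cumulative offsets. This is a sketch under the assumption that the set sizes are `[4, 2, 1]` (as implied by the printed index arrays); the name `starts` is illustrative.

```python
import numpy as np

# Assumed set sizes, matching the index output shown in the answer.
n_lvls = np.array([4, 2, 1], dtype='i4')

# Row index: repeat each variable index once per element in its set.
i = np.repeat(np.arange(n_lvls.size, dtype='i4'), n_lvls)

# Offset of the first element of each set within the flat array.
starts = np.cumsum(n_lvls) - n_lvls

# Within-set index: flat position minus the start offset of the owning set.
j = (np.arange(n_lvls.sum()) - np.repeat(starts, n_lvls)).astype('i4')

print(i)  # [0 0 0 0 1 1 2]
print(j)  # [0 1 2 3 0 1 0]
```

The temporaries here all have the flat length (7 in this case) rather than the size of the padded grid.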

Votes: 0
Original page content provided by Stack Overflow. Translation support provided by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/74091461
