首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >多维jax.isin()

多维jax.isin()
EN

Stack Overflow用户
提问于 2022-10-21 13:14:48
回答 1查看 56关注 0票数 1

我正在尝试过滤一个三元组的数组。我要筛选的标准是,另一个三元组是否至少包含一个具有相同的第一个第三个元素的元素。E.g

代码语言:javascript
复制
import jax.numpy as jnp
array1 = jnp.array(
  [
    [0,1,2],
    [1,0,2],
    [0,3,3],
    [3,0,1],
    [0,1,1],
    [1,0,3],
  ]
)
array2 = jnp.array([[0,1,3],[0,3,2]])
# the mask to filter the first array1 should look like this:
jnp.array([True,False,True,False,False,False])

使用jax实现此掩码的有效计算方法是什么?我期待着你的意见。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-10-21 15:00:13

要做到这一点,可以减少通过广播的平等检查:

代码语言:javascript
复制
import jax.numpy as jnp
array1 = jnp.array(
  [
    [0,1,2],
    [1,0,2],
    [0,3,3],
    [3,0,1],
    [0,1,1],
    [1,0,3],
  ]
)
array2 = jnp.array([[0,1,2],[0,3,2]])  # note adjustment to match first entry of array1

mask = (array1[:, None] == array2[None, :]).all(-1).any(-1)
print(mask)
# [ True False False False False False]

XLA没有任何类似于二进制搜索的原语,所以通常最好的方法是生成完全相等的矩阵并进行约简。如果您像GPU/TPU一样在加速器上运行代码,这种向量化操作是高效的并行化操作,因此在实践中将非常有效地计算它。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74154196

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档