首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >ND4j INDArray值索引和布尔运算

ND4j INDArray值索引和布尔运算
EN

Stack Overflow用户
提问于 2021-06-06 19:38:32
回答 2查看 102关注 0票数 2

下面是numpy代码:

代码语言:javascript
复制
import numpy as np

arr1 = np.array([[0, 1, np.nan], [3, np.nan, 5], [np.nan, 7, 8]])
arr2 = np.array([[np.nan, 7, 6], [5, np.nan, 3], [2, 1, np.nan]])

print(arr1)
print(arr2)

arr1为:[ 0.1.南南7. 8.]

arr2是:[nan 7.6.2.1. nan]

然后我就做了:

代码语言:javascript
复制
idx1 = np.isnan(arr1)
idx2 = np.isnan(arr2)

idx = idx1 | idx2

arr1[idx] = -1
arr2[idx] = -1

print(arr1)
print(arr2)

arr1变为:[-1。1. -1.-1。7. -1.]

arr2变为:[-1。7. -1.-1。1. -1.]

然后我想用scala和ND4j重写这段代码:

代码语言:javascript
复制
import org.nd4j.linalg.factory.Nd4j

val arr1 = Nd4j.create(Array(Array(0, 1, Double.NaN), Array(3, Double.NaN, 5), Array(Double.NaN, 7, 8)))
val arr2 = Nd4j.create(Array(Array(Double.NaN, 7, 6), Array(5, Double.NaN, 3), Array(2, 1, Double.NaN)))

println(arr1)
println(arr2)

val idx1 = arr1.isNaN
val idx2 = arr2.isNaN

val idx = idx1 | idx2 // error

arr1.putWhereWithMask(idx, -1)
arr2.putWhereWithMask(idx, -1)

println(arr1)
println(arr2)

此代码未编译。如何修改?谢谢!

EN

回答 2

Stack Overflow用户

发布于 2021-06-06 20:18:40

您可以使用BooleanIndexing类实现这一点:https://github.com/eclipse/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java

下面是使用nan的测试示例:

代码语言:javascript
复制
 INDArray array = Nd4j.create(new double[] {0, 1, 2, 3, Double.NaN, 5, 6, 7, 8, 9});
int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.isNan()))
                .getDouble(0);
票数 0
EN

Stack Overflow用户

发布于 2021-06-08 23:05:03

经过大量的搜索和尝试,我找到了一个解决方案:

代码语言:javascript
复制
import org.nd4j.linalg.factory.Nd4j

val arr1 = Nd4j.create(Array(Array(0, 1, Double.NaN), Array(3, Double.NaN, 5), Array(Double.NaN, 7, 8)))
val arr2 = Nd4j.create(Array(Array(Double.NaN, 7, 6), Array(5, Double.NaN, 3), Array(2, 1, Double.NaN)))

println(arr1)
println(arr2)

val idx1 = arr1.isNaN
val idx2 = arr2.isNaN

val idx = Nd4j.createUninitialized(DataType.BOOL, idx1.shape():_*)

Nd4j.exec(new Or(idx1, idx2, idx))

val minusOnes = Nd4j.zerosLike(arr1).subi(1.0)

val arr1_modified = Nd4j.where(idx, arr1, minusOnes)(0)
val arr2_modified = Nd4j.where(idx, arr2, minusOnes)(0)

println(arr1_modified)
println(arr2_modified)

其他解决方案也很受欢迎。谢谢!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67858668

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档