首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在多维数组上使用numpy.where

在多维数组上使用numpy.where
EN

Stack Overflow用户
提问于 2017-07-21 21:23:41
回答 1查看 616关注 0票数 2

我有一个2D数组,每一行表示一个分类器的输出,该分类器将一些输入分类为3类(数组大小为1000 * 3):

代码语言:javascript
复制
0.3 0.3 0.3
0.3 0.3 1.0
1.0 0.3 0.3
0.3 0.3 0.3
0.3 1.0 0.3
...

我想得到分类器对它们“不确定”的所有输入的列表。我把“不确定”定义为没有超过0.8的类别。

为了解决这个问题,我用:

代码语言:javascript
复制
np.where(model1_preds.max(axis=1) < 0.8)

这个很好用。

但是现在我有了6个分类器(它们以相同的顺序分析了相同的输入),还有一个数组6 * 1000 * 3表示它们的结果。

我想找到两件事:

  1. 所有的输入,至少一个分类器是“不确定”的。
  2. 所有分类器都“不确定”的所有输入。

我认为总的方向是这样的:

代码语言:javascript
复制
np.stack(np.where(model_preds.max(axis=1) < 0.8) for model_preds in all_preds)

但是它不能工作,因为python不知道我在for循环中的意思。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-07-21 21:56:37

可选到np.where

代码语言:javascript
复制
res_all_unsure = preds[:,np.amax(preds, axis=(0,2)) <= 0.8,:]
res_one_unsure = preds[:,preds.max(-1).min(0) <= 0.8,:]
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45246903

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档