我有一个2D数组,每一行表示一个分类器的输出,该分类器将一些输入分类为3类(数组大小为1000 * 3):
0.3 0.3 0.3
0.3 0.3 1.0
1.0 0.3 0.3
0.3 0.3 0.3
0.3 1.0 0.3
...我想得到分类器对它们“不确定”的所有输入的列表。我把“不确定”定义为没有超过0.8的类别。
为了解决这个问题,我用:
np.where(model1_preds.max(axis=1) < 0.8)这个很好用。
但是现在我有了6个分类器(它们以相同的顺序分析了相同的输入),还有一个数组6 * 1000 * 3表示它们的结果。
我想找到两件事:
我认为总的方向是这样的:
np.stack(np.where(model_preds.max(axis=1) < 0.8) for model_preds in all_preds)但是它不能工作,因为python不知道我在for循环中的意思。
发布于 2017-07-21 21:56:37
可选到np.where
res_all_unsure = preds[:,np.amax(preds, axis=(0,2)) <= 0.8,:]
res_one_unsure = preds[:,preds.max(-1).min(0) <= 0.8,:]https://stackoverflow.com/questions/45246903
复制相似问题