我有一组one-hot编码标签,我想看看每个类别有多少个。每个标签可以包含一个或多个类,如下所示:
[1 0 0 0 0 0 0 1 0]我对这个问题的第一个解决方案是像这样使用np.argmax和np.bincount:
newLabels = []
for i in range(len(labels)):
newLabels.append(np.argmax(labels[i]))
newLabels= np.asarray(newLabels)
np.bincount(newLabels)array([1221, 722, 199, 918, 599, 678, 1569, 786, 185])但是接下来发生的事情是,上面的one-hot编码示例将被赋予值0,而第二个值(应该是7)不会被计算在内。
有谁有解决这个问题的办法吗?
发布于 2020-04-06 16:07:42
from collections import Counter
newLabels = Counter()
for label in labels:
for idx, key in enumerate(label):
newLabels[idx]+=key输出应该是一个字典,其中键是标签、索引,值是计数。
发布于 2020-04-06 16:12:05
此问题的解决方案是:
np.sum(Labels, axis=0)https://stackoverflow.com/questions/61055378
复制相似问题