我正在将logits与循环中的标签进行比较:
for r in range(logits.shape[0]):
if labels[r] == np.argmax(logits[r]):
guessed += 1.0其中labels是整数标签的一维数组,logits是二维数组,第二维是标签的概率。
上面的解决方案是一个Python循环,效率不是很高。应该有一个常用的numpy或tensorflow快捷方式来做到这一点。你能推荐一个吗?
发布于 2019-02-19 11:29:55
你可以通过np.argmax(logits,axis=1)一次得到所有的最大值。以下代码可以替换for循环,以获得猜测的总数:
guessed = np.sum(labels == np.argmax(logits,axis=1))https://stackoverflow.com/questions/54757088
复制相似问题