我使用tflearn的mnist数据集来预测手写数字。
一切正常,但我的标签是one_hot。tflearn中是否有与Tensorflow中的argmax()相同的函数?
发布于 2017-07-20 04:42:59
您可以简单地这样做:
pred = model.predict(test_data)
print([ np.where(r==1)[0][0] for r in np.round(pred) ])最好的。
https://stackoverflow.com/questions/43920064
复制相似问题