首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何将decode_predictions()用于非Imagenet模型..?

如何将decode_predictions()用于非Imagenet模型..?
EN

Data Science用户
提问于 2019-07-04 12:07:33
回答 1查看 6.3K关注 0票数 3

我知道decode_predictions()只适用于VGG16等模型的imagenet (1000个类),但更适合我的场景。

我的场景

我使用了vgg16预先训练过的模型,并增加了自己的权重。

这是一个非Imagenet模型。我提到了classes=9,因为我只使用了9个类来训练我以前的模型。

因此,现在要找到预测,我可以使用预测()方法,然后我的print(答案)会给出相应的类标签。但实际上,我需要打印类名。这有可能得到类名吗?如果是的话,有人能解释我是怎么做到的吗?

EN

回答 1

Data Science用户

发布于 2019-07-04 16:53:02

在深度学习中,当您执行预测时,您将得到在您的情况下的预测,这是一个包含9个概率的数组。因此,首先对它们执行以下操作。

代码语言:javascript
复制
import numpy as np

prediction = model.prediction(test_data) 
# prediction will contain [[0.6, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05]]

prediction = np.argmax(prediction[0])
# Now predition hold the index of the (0.6) i.e max probability value

在那之后,您应该有一个字典,其中包含0,1,2,..8和值为classname1、classname2、...classname9

这就是将类名作为输出的方式。

票数 1
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/55046

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档