首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在MXnet中使用自定义Iterator时,如何预测测试集的标签?

在MXnet中使用自定义Iterator时,如何预测测试集的标签?
EN

Stack Overflow用户
提问于 2017-04-28 16:31:39
回答 1查看 463关注 0票数 0

我有一个大数据集(大约20 2GB用于培训,2GB用于测试),我想使用MXnet和R。由于内存不足,我搜索迭代器来加载由自定义迭代器设置的培训和测试集,并找到了解决方案。

现在,我可以使用本页上的代码对模型进行培训,但问题是,如果我使用保存迭代器读取测试集,如下所示:

代码语言:javascript
复制
test.iter <- CustomCSVIter$new(iter = NULL, data.csv = "test.csv", data.shape = 480, batch.size = batch.size)

然后,预测命令不工作,在页面中没有预测模板;

代码语言:javascript
复制
preds <- predict(model, test.iter)

因此,我的具体问题是,如果我使用页面上的代码构建模型,我如何读取测试集并预测评估过程的标签?我的测试集和训练集在这种格式中。

谢谢你的帮助

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-04-12 20:46:27

就像你解释的那样。您只需使用模型和迭代器调用“预测”:

代码语言:javascript
复制
preds = predict(model, test.iter)

这里唯一的诀窍是预测是按列显示的。我的意思是,如果使用整个你所指的样本,执行它并添加以下行:

代码语言:javascript
复制
test.iter <- CustomCSVIter$new(iter = NULL, data.csv = "mnist_train.csv", data.shape = 28, batch.size = batch.size)
preds = predict(model, test.iter)

preds[,1] # index of the sample to see in the column position

你收到:

代码语言:javascript
复制
 [1] 5.882561e-11 2.826923e-11 7.873914e-11 2.760162e-04 1.221306e-12 9.997239e-01 4.567645e-11 3.177564e-08 1.763889e-07 3.578671e-09

这显示了训练集第一个元素的softmax输出。如果您试图通过只编写preds来打印所有内容,那么您将只看到空值,因为RStudio打印限制为1000 --实际数据将没有机会出现。

请注意,我重用培训数据进行预测。我这样做,因为我不想调整迭代器的代码,它需要能够使用前面有标签和没有标签的数据(培训和测试集)。在现实场景中,您需要调整迭代器,这样它就可以使用和不带标签。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43684975

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档