我正在学习如何在tensorflow 2.0和Udemy课程的Keras中从头开始创建MNIST模型。
因此,我获得了mnist数据集,如下所示
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']一切都很好,即使我对我的模型进行了97%的测试,我也很高兴。
当我尝试做一些与课程不同的事情时,问题就开始了。我尝试使用matplotlib plt.imshow()打印mnist_dataset中的一些示例,但完全失败了。然后我开始了一些研究,我得到了一个解决方案,我需要像这样获得数据集:
mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']其中mnistt是我可以使用matplotlib操作和打印的数据集。
因此,我的问题如下:我从哪里可以获得有关tfds.load()类型的信息,以及如何根据需要正确地操作它们?(对于像我这样的tensorflow初学者来说,这在某种程度上是可扩展的)。
发布于 2019-11-04 03:51:47
tfds.load方法的主要调用包含您需要的所有内容:
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)name="mnist" (mnist)with_info=True您正在指定要使用的构建器,您正在请求tfds.load返回包含所有您需要知道的info对象的所有信息关于返回的name="mnist" ->您正在请求tfds.load仅获取监督学习任务所需的数据集的元素(图像和标签对)。您第一次尝试使用mnist_dataset获取数据(与matplotlib一起使用)失败了,因为您可以从
print(mnist_info) #run me!数据集包含两个不同的拆分:train和test。
tfds.core.DatasetInfo(
name='mnist',
version=1.0.0,
description='The MNIST database of handwritten digits.',
urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
features=FeaturesDict({
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
}),
total_num_examples=70000,
splits={
'test': 10000,
'train': 60000,
},
supervised_keys=('image', 'label'),
citation="""@article{lecun2010mnist,
title={MNIST handwritten digit database},
author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
volume={2},
year={2010}
}""",
redistribution_info=,
)因此,tfds.load返回的对象是一个字典
{
"train": <train dataset>,
"test": <test dataset>
}实际上,在示例的下一行中,您将以这种方式提取"train“和"test”数据集:
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']从mnist_info对象中,您可以获得操作数据集所需的所有信息:拆分数量、数据类型(例如," image“是一个数据类型为tf.uint8的28x28x1图像)等等……
发布于 2020-07-11 14:31:56
我在使用此代码加载mnist时遇到错误
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)错误消息初始化()缺少两个必需的位置参数:'op‘和’‘
源码Udemy课程
发布于 2020-08-06 05:26:50
尝尝这个
x_train, y_train = Next(iter(mnist_train))
然后绘制x_train图
https://stackoverflow.com/questions/58683675
复制相似问题