首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Keras中使用Tensorflow数据集API时出现的问题

在Keras中使用Tensorflow数据集API时出现的问题
EN

Stack Overflow用户
提问于 2019-05-14 23:33:22
回答 3查看 989关注 0票数 2

我正在尝试拟合CNN Keras模型,向其提供由Tensorflow的Datasets处理的数据。然而,尽管遵循了官方文档(请参阅there),我还是一次又一次地遇到相同的异常:

代码语言:javascript
复制
ValueError: No data provided for "conv2d_8_input". Need data for each key in: ['conv2d_8_input']
# conv2d_8 is the first Conv2D layer of my model, see below

我使用的是来自tensorflow-datasets的MNIST数据集,图像被标准化,类标签被转换为单热点编码。您可以从下面的代码中看到摘录。

代码语言:javascript
复制
test_data, train_data = tfds.load("mnist", split=Split.ALL.subsplit([1, 3]))

# [...] Images are normalized using Dataset.map method
# [...] Labels are converted into one-hot encodings as well, using tf.one_hot function

model = keras.Sequential([
    keras.layers.Conv2D(
        32,
        kernel_size=5,
        padding="same",
        input_shape=(28, 28, 1),
        activation="relu",
    ),
    keras.layers.MaxPooling2D(
        (2, 2),
        padding="same"
    ),
    keras.layers.Conv2D(
        64,
        kernel_size=5,
        padding="same",
        activation="relu"
    ),
    keras.layers.MaxPooling2D(
        (2, 2),
        padding="same"
    ),
    keras.layers.Flatten(),
    keras.layers.Dense(
        512,
        activation="relu"
    ),
    keras.layers.Dropout(rate=0.4),
    keras.layers.Dense(10, activation="softmax")
])

model.compile(
    optimizer=tf.train.AdamOptimizer(0.01),
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

train_data = train_data.batch(32).repeat()
test_data = test_data.batch(32).repeat()

model.fit(
    train_data,
    epochs=10,
    steps_per_epoch=30,
    validation_data=test_data,
    validation_steps=3
) # The exception occurs at this step

我不明白为什么它不起作用,我试着给fit方法提供单次迭代器,而不是数据集,但我得到了相同的结果。我不习惯使用PyTorch和TensorFlow (我通常使用Keras ),所以我想我可能遗漏了一些明显的东西。

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2019-05-15 04:01:58

好的,我知道了。我启用了急切执行,以查看Keras是否会产生更精确的异常,我得到了以下结果:

代码语言:javascript
复制
ValueError: Output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`. Found: {'image': <tf.Tensor: id=1012, shape=(32, 28, 28, 1), dtype=float64, numpy=array([...])>, 'label': <tf.Tensor: id=1013, shape=(32, 10), dtype=uint8, numpy=array([...]), dtype=uint8)>}

实际上,我的数据集的组件(图像及其相关标签)都有名称("image“和"label"),因为这是tensorflow_datasets加载它们的方式。因此,数据集上的迭代器产生一个具有两个值的字典:"image“和"label”。

但是,两个值的Keras expects a tuple(inputs, targets) (或三个值为(inputs, targets, sample_wheights)),并且它不喜欢数据集迭代器产生的字典(因此我得到了错误)。

我在model.fit之前添加了以下代码

代码语言:javascript
复制
train_data = train_data.map(lambda x: tuple(x.values()))
test_data = test_data.map(lambda x: tuple(x.values()))

而且它是有效的。

票数 1
EN

Stack Overflow用户

发布于 2019-08-06 09:42:06

对于那些在遵循TF2.0Beta教程加载图像(https://www.tensorflow.org/beta/tutorials/load_data/images)后访问此页面的人:

我可以通过在preprocess_image函数中返回一个元组来避免这个错误

代码语言:javascript
复制
def preprocess_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [192, 192])
image /= 255.0  # normalize to [0,1] range
return (image,image)

我没有在我的用例中使用标签,因此您可能需要进行其他更改才能按照本教程进行操作

票数 3
EN

Stack Overflow用户

发布于 2019-07-03 14:57:23

您可以使用as_supervised将数据作为元组直接从tensorflow-datasets加载

代码语言:javascript
复制
test_data, train_data = tfds.load("mnist", split=tfds.Split.ALL.subsplit([1, 3]), as_supervised=True)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56134016

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档