首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow 2.0:如何将MapDataset (从TFRecord读取后)转换为可以输入到model.fit的结构

Tensorflow 2.0:如何将MapDataset (从TFRecord读取后)转换为可以输入到model.fit的结构
EN

Stack Overflow用户
提问于 2020-03-09 13:18:00
回答 1查看 8.8K关注 0票数 8

我将训练和验证数据存储在两个单独的TFRecord文件中,其中存储了4个值:信号A (float32 shape (150,))、信号B (float32 shape (150,))、label (标量int64)、id (string)。我的阅读解析功能是:

代码语言:javascript
复制
def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(sample_proto, raw_signal_description)

其中SIGNALS是一个字典映射信号名->信号形状。然后,我阅读了原始数据集:

代码语言:javascript
复制
training_raw = tf.data.TFRecordDataset(<path to training>), compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>), compression_type='GZIP')

并使用map解析值:

代码语言:javascript
复制
training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)

显示training_dataval_data的标题时,我得到:

<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>

和预期的差不多。我还检查了一些值的一致性,它们似乎是正确的。

现在,关于我的问题:如何从MapDataset (具有类似于结构的字典)到可以作为模型输入的东西?

输入到我的模型是对(信号A,标签),虽然在未来我也将使用信号B。

对我来说,最简单的方法似乎是在我想要的元素上创建一个生成器。类似于:

代码语言:javascript
复制
def data_generator(mapdataset):
    for sample in mapdataset:
        yield (sample['Signal A'], sample['label'])

但是,通过这种方法,我失去了一些数据集的便利性,比如批处理,而且也不清楚如何在validation_data参数化的model.fit中使用相同的方法。理想情况下,我只会在映射表示和数据集表示之间进行转换,其中它在信号A张量和标号对上迭代。

编辑:我的最终产品应该具有类似于:<TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)>但不一定是TensorSliceDataset的标题

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-03-09 14:04:15

您可以简单地在解析函数中这样做。例如:

代码语言:javascript
复制
def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)

    return parsed['Signal A'], parsed['label']

如果您在map上使用TFRecordDataset函数,那么您将有一个元组(signal_a, label)数据集,而不是一个字典数据集。您应该能够将其直接放入model.fit中。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60601412

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档