首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.data.Dataset附加元素

tf.data.Dataset附加元素
EN

Stack Overflow用户
提问于 2017-11-14 19:14:14
回答 1查看 4.7K关注 0票数 2

给定具有以下结构的TFRecord:

代码语言:javascript
复制
context_features = {
    "labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
    "filename": tf.FixedLenFeature(shape=[], dtype=tf.string)
}
sequence_features = {
    "labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
}

在解析记录文件时,我希望在运行时将一个元素添加到labels字段。为此数据集编写的迭代器还会创建批处理,并将其与零交换,因此在添加填充值之前在列表末尾添加元素是很重要的。例如,如果我们有三个条目: 1.2、3、4、5、6、7、8、9,并且要添加的元素为10,则填充批应该如下所示:[1, 2, 10, 0, 0], [3, 4, 5, 10, 0], [6, 7, 8, 9, 10]

你能给我推荐一种方法吗?谢谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-11-14 21:59:19

要做到这一点,只需使用tf.concat()将一个元素添加到从tf.parse_single_sequence_example()获得的labels张量中。例如,将10附加到每个标签:

代码语言:javascript
复制
def _parse_labels_function(example):
    context_features = {
        "labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "filename": tf.FixedLenFeature(shape=[], dtype=tf.string)
    }
    sequence_features = {
        "labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
    }

    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    # Append `10` to each label sequence.
    labels = tf.concat([sequence_parsed["labels"], [10]], 0)

    return labels, context_parsed["labels_length"], context_parsed["filename"]

dataset = tf.data.TFRecordDataset(label_record)
dataset = dataset.map(_parse_labels_function)

请注意,我不确定如何在您的程序中使用"labels_length"功能,但在返回之前,您可能还需要添加一个。

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47293577

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档