给定具有以下结构的TFRecord:
context_features = {
"labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
"filename": tf.FixedLenFeature(shape=[], dtype=tf.string)
}
sequence_features = {
"labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
}在解析记录文件时,我希望在运行时将一个元素添加到labels字段。为此数据集编写的迭代器还会创建批处理,并将其与零交换,因此在添加填充值之前在列表末尾添加元素是很重要的。例如,如果我们有三个条目: 1.2、3、4、5、6、7、8、9,并且要添加的元素为10,则填充批应该如下所示:[1, 2, 10, 0, 0], [3, 4, 5, 10, 0], [6, 7, 8, 9, 10]
你能给我推荐一种方法吗?谢谢
发布于 2017-11-14 21:59:19
要做到这一点,只需使用tf.concat()将一个元素添加到从tf.parse_single_sequence_example()获得的labels张量中。例如,将10附加到每个标签:
def _parse_labels_function(example):
context_features = {
"labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
"filename": tf.FixedLenFeature(shape=[], dtype=tf.string)
}
sequence_features = {
"labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
}
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
serialized=example,
context_features=context_features,
sequence_features=sequence_features
)
# Append `10` to each label sequence.
labels = tf.concat([sequence_parsed["labels"], [10]], 0)
return labels, context_parsed["labels_length"], context_parsed["filename"]
dataset = tf.data.TFRecordDataset(label_record)
dataset = dataset.map(_parse_labels_function)请注意,我不确定如何在您的程序中使用"labels_length"功能,但在返回之前,您可能还需要添加一个。
https://stackoverflow.com/questions/47293577
复制相似问题