首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用Keras添加和访问辅助tf.Dataset属性

使用Keras添加和访问辅助tf.Dataset属性
EN

Stack Overflow用户
提问于 2018-11-10 07:32:24
回答 1查看 197关注 0票数 0

我使用tf.py_func调用将数据(要素、标签和sample_weights)从文件解析到tf.Dataset

代码语言:javascript
复制
dataset = tf.data.Dataset.from_tensor_slices((records, labels, sample_weights))    
dataset = dataset.map(
   lambda filename, label, sample_weight: tuple(tf.py_func(
     self._my_parse_function, [filename, label, sample_weights], [tf.float32, label.dtype, tf.float32])))

数据是可变长度的一维序列,因此我还将序列填充到my_parse_function中的固定长度。

我使用tensorflow.python.keras.models.Sequential.fit(...)训练数据(它现在接受数据集作为输入,包括使用sample_weights的数据集)和tensorflow.python.keras.models.Sequential.predict来预测输出。

一旦我有了预测,我想做一些后处理,以使输出有意义。例如,我想将填充的数据截断为实际的序列长度。此外,我想确切地知道数据来自哪个文件,因为我不确定数据集迭代器是否保证排序,特别是如果使用批处理(我也会对数据集进行批处理),或者涉及多GPU或多工作者(我希望尝试多场景)。即使订单是“有保证的”,这也是一个很好的理智检查。

这些信息,文件名(即一个字符串)和序列长度(即一个整数),目前还不能方便地访问,所以我想将这两个属性添加到数据集元素中,并能够在调用期间/之后检索它们以进行预测。

做到这一点的最佳方法是什么?

谢谢

EN

回答 1

Stack Overflow用户

发布于 2018-11-16 02:28:11

作为一种变通方法,我将这些辅助信息存储在my_parse_fn中的一个“全局”字典中,这样它就可以在通过tf.Dataset进行的每次迭代中存储(和重新存储)这些辅助信息。这目前还可以,因为在训练集中只有大约1000个示例,所以存储1000个字符串和整数不是问题。但是如果辅助信息更大,或者训练集更大,这种方法就不会有很好的可扩展性。在我的例子中,每个训练示例的输入数据都非常大,大约50MB,这就是为什么从文件中读取tf.Dataset (即在每个时期)很重要。

我仍然认为,能够使用这些信息更方便地扩展tf.Dataset会很有帮助。我还注意到,当我向像dataset.tag这样的tf.Dataset添加一个字段来标识,比如说,dataset.tag = ' training ',dataset.tag ='validation‘或dataset.tag = 'test’集合时,该字段在训练的迭代中无法存活。

因此,在这种情况下,我想知道如何扩展tf.Dataset

在另一个问题上,看起来tf.Dataset元素的顺序在迭代过程中是遵守的,所以预测,比如说来自tensorflow.python.keras.models.Sequential.predict(...)的预测是按照文件is呈现给my_parse_fn的顺序排序的(至少批处理遵循这种顺序,但我仍然不知道多GPU场景是否也是如此)。

感谢你的见解。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53234585

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档