首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow tf.data.Dataset和水桶

TensorFlow tf.data.Dataset和水桶
EN

Stack Overflow用户
提问于 2018-05-30 13:37:22
回答 1查看 2.2K关注 0票数 9

对于LSTM网络,我看到了水桶的巨大改进。

我遇到了TensorFlow文档中的桶形部分 the (tf.contrib)。

虽然在我的网络中,我使用的是tf.data.Dataset API,特别是我正在使用TFRecords,所以我的输入管道如下所示

代码语言:javascript
复制
dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={.....})

如何将斗式方法集成到tf.data.Dataset管道中?

如果重要的话,在TFRecords文件中的每个记录中,我都将序列长度保存为整数。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-05-30 15:28:59

使用bucketing的各种Dataset API用例都能很好地解释这里

bucket_by_sequence_length() 示例:

代码语言:javascript
复制
def elements_gen():
   text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2]]
   label = [1, 2, 1, 2]
   for x, y in zip(text, label):
       yield (x, y)

def element_length_fn(x, y):
   return tf.shape(x)[0]

dataset = tf.data.Dataset.from_generator(generator=elements_gen,
                                     output_shapes=([None],[]),
                                     output_types=(tf.int32, tf.int32))

dataset =   dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=element_length_fn,
                                                              bucket_batch_sizes=[2, 2, 2],
                                                              bucket_boundaries=[0, 8]))

batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:

   for _ in range(2):
      print('Get_next:')
      print(sess.run(batch))

输出:

代码语言:javascript
复制
Get_next:
(array([[1, 2, 3, 0, 0],
   [3, 4, 5, 6, 7]], dtype=int32), array([1, 2], dtype=int32))
Get_next:
(array([[1, 2, 0, 0],
   [8, 9, 0, 2]], dtype=int32), array([1, 2], dtype=int32))
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50606178

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档