首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.data.Dataset中的队列容量

tf.data.Dataset中的队列容量
EN

Stack Overflow用户
提问于 2017-11-13 00:41:56
回答 1查看 981关注 0票数 2

我对Tensorflow的新输入管道机制有问题。当我使用tf.data.Dataset创建数据管道时,它会解码jpeg图像,然后将其加载到队列中,它会尝试将尽可能多的图像加载到队列中。如果加载图像的吞吐量大于我的模型处理的图像的吞吐量,则内存使用量会无限制地增加。

下面是使用tf.data.Dataset构建管道的代码片段

代码语言:javascript
复制
def _imread(file_name, label):
  _raw = tf.read_file(file_name)
  _decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
  _resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
  _scaled = (_resized / 127.5) - 1.0
  return _scaled, label

n_samples = image_files.shape.as_list()[0]
dset = tf.data.Dataset.from_tensor_slices((image_files, labels))
dset = dset.shuffle(n_samples, None)
dset = dset.repeat(hps.n_epochs)
dset = dset.map(_imread, hps.batch_size * 32)
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)

在这里,image_files是一个常量张量,包含30k图像的文件名。在_imread中将图像大小调整为256x256x3。

如果使用以下代码片段构建管道:

代码语言:javascript
复制
# refer to "https://www.tensorflow.org/programmers_guide/datasets"
def _imread(file_name, hps):
  _raw = tf.read_file(file_name)
  _decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
  _resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
  _scaled = (_resized / 127.5) - 1.0
  return _scaled

n_samples = image_files.shape.as_list()[0]

image_file, label = tf.train.slice_input_producer(
  [image_files, labels],
  num_epochs=hps.n_epochs,
  shuffle=True,
  seed=None,
  capacity=n_samples,
)

# Decode image.
image = _imread(image_file, 

images, labels = tf.train.shuffle_batch(
  tensors=[image, label],
  batch_size=hps.batch_size,
  capacity=hps.batch_size * 64,
  min_after_dequeue=hps.batch_size * 8,
  num_threads=32,
  seed=None,
  enqueue_many=False,
  allow_smaller_final_batch=True
)

那么,在整个训练过程中,内存使用几乎是恒定的。如何让tf.data.Dataset加载固定数量的样本?我用tf.data.Dataset创建的管道正确吗?我认为tf.data.Dataset.shuffle中的buffer_size参数是针对image_fileslabels的。所以这对于存储30k的字符串应该不是问题,对吧?即使要加载30k图像,也需要30000*256*256*3*8/(1024*1024*1024)=43 of的内存。然而,它使用了51 of系统内存中的59 of。

EN

回答 1

Stack Overflow用户

发布于 2017-11-13 01:01:55

这将缓冲n_samples,它看起来是您的整个数据集。你可能想要减少这里的缓冲。

代码语言:javascript
复制
dset = dset.shuffle(n_samples, None)

你不妨永远重复,重复不会缓冲(Does tf.data.Dataset.repeat() buffer the entire dataset in memory?)

代码语言:javascript
复制
dset = dset.repeat()

您正在进行批处理,然后预取批处理的hps.batch_size #。唉哟!

代码语言:javascript
复制
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)

让我们用hps.batch_size=1000来做一个具体的例子。上面的第一行创建了一批1000个图像。上面的第二行创建2000批每1000张图像,总共缓冲了2,000,000张图像。糟了!

你的意思是:

代码语言:javascript
复制
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(2)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47251280

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档