我对Tensorflow的新输入管道机制有问题。当我使用tf.data.Dataset创建数据管道时,它会解码jpeg图像,然后将其加载到队列中,它会尝试将尽可能多的图像加载到队列中。如果加载图像的吞吐量大于我的模型处理的图像的吞吐量,则内存使用量会无限制地增加。
下面是使用tf.data.Dataset构建管道的代码片段
def _imread(file_name, label):
_raw = tf.read_file(file_name)
_decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
_resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
_scaled = (_resized / 127.5) - 1.0
return _scaled, label
n_samples = image_files.shape.as_list()[0]
dset = tf.data.Dataset.from_tensor_slices((image_files, labels))
dset = dset.shuffle(n_samples, None)
dset = dset.repeat(hps.n_epochs)
dset = dset.map(_imread, hps.batch_size * 32)
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)在这里,image_files是一个常量张量,包含30k图像的文件名。在_imread中将图像大小调整为256x256x3。
如果使用以下代码片段构建管道:
# refer to "https://www.tensorflow.org/programmers_guide/datasets"
def _imread(file_name, hps):
_raw = tf.read_file(file_name)
_decoded = tf.image.decode_jpeg(_raw, channels=hps.im_ch)
_resized = tf.image.resize_images(_decoded, [hps.im_width, hps.im_height])
_scaled = (_resized / 127.5) - 1.0
return _scaled
n_samples = image_files.shape.as_list()[0]
image_file, label = tf.train.slice_input_producer(
[image_files, labels],
num_epochs=hps.n_epochs,
shuffle=True,
seed=None,
capacity=n_samples,
)
# Decode image.
image = _imread(image_file,
images, labels = tf.train.shuffle_batch(
tensors=[image, label],
batch_size=hps.batch_size,
capacity=hps.batch_size * 64,
min_after_dequeue=hps.batch_size * 8,
num_threads=32,
seed=None,
enqueue_many=False,
allow_smaller_final_batch=True
)那么,在整个训练过程中,内存使用几乎是恒定的。如何让tf.data.Dataset加载固定数量的样本?我用tf.data.Dataset创建的管道正确吗?我认为tf.data.Dataset.shuffle中的buffer_size参数是针对image_files和labels的。所以这对于存储30k的字符串应该不是问题,对吧?即使要加载30k图像,也需要30000*256*256*3*8/(1024*1024*1024)=43 of的内存。然而,它使用了51 of系统内存中的59 of。
发布于 2017-11-13 01:01:55
这将缓冲n_samples,它看起来是您的整个数据集。你可能想要减少这里的缓冲。
dset = dset.shuffle(n_samples, None)你不妨永远重复,重复不会缓冲(Does tf.data.Dataset.repeat() buffer the entire dataset in memory?)
dset = dset.repeat()您正在进行批处理,然后预取批处理的hps.batch_size #。唉哟!
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(hps.batch_size * 2)让我们用hps.batch_size=1000来做一个具体的例子。上面的第一行创建了一批1000个图像。上面的第二行创建2000批每1000张图像,总共缓冲了2,000,000张图像。糟了!
你的意思是:
dset = dset.batch(hps.batch_size)
dset = dset.prefetch(2)https://stackoverflow.com/questions/47251280
复制相似问题