首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在parallel_interleave中使用TensorFlow

如何在parallel_interleave中使用TensorFlow
EN

Stack Overflow用户
提问于 2018-04-26 15:12:34
回答 1查看 3.8K关注 0票数 6

我正在阅读TensorFlow 基准回购中的代码。下面的代码是从TensorFlow文件创建TFRecord数据集的部分:

代码语言:javascript
复制
ds = tf.data.TFRecordDataset.list_files(tfrecord_file_names)
ds = ds.apply(interleave_ops.parallel_interleave(tf.data.TFRecordDataset, cycle_length=10))

我试图更改此代码以直接从JPEG图像文件创建数据集:

代码语言:javascript
复制
ds = tf.data.Dataset.from_tensor_slices(jpeg_file_names)
ds = ds.apply(interleave_ops.parallel_interleave(?, cycle_length=10))

我不知道该写什么?地点。map_func in parallel_interleave()是用于TFRecord文件的tf.data.TFRecordDataset类的__init__(),但我不知道该为TFRecord文件编写什么。

我们不需要在这里进行任何转换。因为我们将压缩两个数据集,然后再进行转换。守则如下:

代码语言:javascript
复制
counter = tf.data.Dataset.range(batch_size)
ds = tf.data.Dataset.zip((ds, counter))
ds = ds.apply( \
     batching.map_and_batch( \
     map_func=preprocess_fn, \
     batch_size=batch_size, \
     num_parallel_batches=num_splits))

因为我们不需要转化?位置上,我尝试使用一个空的map_func,但是有错误"map_funcmust return aDataset`‘object“。我也尝试使用tf.data.Dataset,但是输出显示Dataset是一个不允许放在那里的抽象类。

有人能帮忙吗?非常感谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-06-05 08:59:02

当转换将源数据集的每个元素转换为多个元素到目标数据集时,parallel_interleave非常有用。我不知道为什么他们会像这样在基准测试中使用它,因为他们可以只使用带有并行调用的map

下面是我建议使用parallel_interleave从几个目录中读取图像的方法,每个目录包含一个类:

代码语言:javascript
复制
classes = sorted(glob(directory + '/*/')) # final slash selects directories only
num_classes = len(classes)

labels = np.arange(num_classes, dtype=np.int32)

dirs = DS.from_tensor_slices((classes, labels))               # 1
files = dirs.apply(tf.contrib.data.parallel_interleave(
    get_files, cycle_length=num_classes, block_length=4,      # 2
    sloppy=False)) # False is important ! Otherwise it mixes labels
files = files.cache()
imgs = files.map(read_decode, num_parallel_calls=20)\.        # 3
            .apply(tf.contrib.data.shuffle_and_repeat(100))\
            .batch(batch_size)\
            .prefetch(5)

有三个步骤。首先,我们得到目录及其标签(#1)的列表。

然后,我们将这些映射到一个文件集。但是如果我们做一个简单的.flatmap(),我们最终会得到标签0的所有文件,然后是标签1的所有文件,然后是2等.然后,我们需要真正的大洗牌缓冲器,以获得有意义的洗牌。

因此,我们应用parallel_interleave (#2)。这是get_files()

代码语言:javascript
复制
def get_files(dir_path, label):
    globbed = tf.string_join([dir_path, '*.jpg'])
    files = tf.matching_files(globbed)

    num_files = tf.shape(files)[0] # in the directory
    labels = tf.tile([label], [num_files, ]) # expand label to all files
    return DS.from_tensor_slices((files, labels))

使用parallel_interleave确保每个目录的list_files并行运行,因此当第一个目录列出第一个block_length文件时,第二个目录中的第一个block_length文件也将可用(也来自第3、第4等目录)。此外,生成的数据集将包含每个标签的交错块,例如1 1 1 1 2 2 2 2 3 3 3 3 3 1 1 1 1 ... (针对3个类和block_length=4)。

最后,我们从文件列表(#3)中读取图像。这是read_and_decode()

代码语言:javascript
复制
def read_decode(path, label):
    img = tf.image.decode_image(tf.read_file(path), channels=3)
    img = tf.image.resize_bilinear(tf.expand_dims(img, axis=0), target_size)
    img = tf.squeeze(img, 0)
    img = preprocess_fct(img) # should work with Tensors !

    label = tf.one_hot(label, num_classes)
    img = tf.Print(img, [path, label], 'Read_decode')
    return (img, label)

该函数接受图像路径及其标签,并为每个路径返回一个张量:路径的图像张量和标签的one_hot编码。这也是您可以对图像进行所有转换的地方。在这里,我做调整大小和基本的预处理。

票数 9
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50046505

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档