首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >具有TensorFlow Dataset API和flat_map的并行线程

具有TensorFlow Dataset API和flat_map的并行线程
EN

Stack Overflow用户
提问于 2017-11-21 10:56:57
回答 1查看 5.4K关注 0票数 20

我正在将TensorFlow代码从旧的队列接口更改为新的数据集API。使用旧接口,我可以将num_threads参数指定到tf.train.shuffle_batch队列。但是,控制Dataset API中线程数量的唯一方法似乎是在map函数中使用num_parallel_calls参数。但是,我使用的是flat_map函数,它没有这样的参数。

问题:有办法控制flat_map函数的线程/进程数量吗?或者是否有方法结合使用mapflat_map,并仍然指定并行调用的数量?

注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在CPU上运行大量的预处理。

这里这里上有两篇与GitHub相关的文章,但我认为他们没有回答这个问题。

下面是我的用例的一个最低限度的代码示例:

代码语言:javascript
复制
with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    # do something with 'dataset'
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-11-21 13:18:21

据我所知,目前flat_map没有提供并行性选项。考虑到大部分计算都是在pre_processing_func中完成的,您可以使用一个并行的map调用,然后是一些缓冲,然后使用一个带有标识lambda函数的flat_map调用,该函数负责处理输出。

代码:

代码语言:javascript
复制
NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # data-augmentation here
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return atificial_samples

dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
                  map(pre_processing_func, num_parallel_calls=NUM_THREADS).
                  prefetch(BUFFER_SIZE).
                  flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
                  shuffle(BUFFER_SIZE)) # my addition, probably necessary though

注意(对我自己和那些试图理解管道的人):

由于pre_processing_func从初始样本(组织在形状(?, 512)的矩阵中)开始生成任意数量的新样本,因此需要flat_map调用将所有生成的矩阵转换为包含单个样本的Datasets (因此是lambda中的tf.data.Dataset.from_tensor_slices(x) ),然后将所有这些数据集平铺成一个包含单个样本的大Dataset

对于.shuffle()来说,将数据集或生成的样本打包在一起可能是个好主意。

票数 14
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47411383

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档