我想知道在TensorFlow 2.4中遍历tf.data.Dataset最有效的方法是什么。
我使用了典型的:
for example in dataset:
code但是,我测量了墙时间,因为我的数据集很大,所以计算循环需要太多时间。有没有其他方法可以减少计算时间?
发布于 2021-08-01 17:52:29
您可以使用.map(map_func)函数,这是对数据集中的每个样本应用一些预处理的有效方法。它对数据集的每个样本并行运行map_func。您甚至可以通过num_parallel_calls参数设置并行调用的数量。[Reference]
以下是来自tensorflow网站的一个示例:
dataset = tf.data.Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1) # instead of adding 1 to each sample in a for loop
list(dataset.as_numpy_iterator()) # ==> [ 2, 3, 4, 5, 6 ]您也可以传递函数:
def my_map(x): # if dataset has y, it should be like "def my_map(x,y)" and "return x,y"
return x+1
dataset = tf.data.Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(my_map) # instead of adding 1 to each sample in a for loop
list(dataset.as_numpy_iterator()) # ==> [ 2, 3, 4, 5, 6 ]https://stackoverflow.com/questions/68612779
复制相似问题