是否可以删除在调用tf.data.Dataset.cache()之后生成的内存缓存
这是我想做的事。数据集的增强非常昂贵,因此当前代码或多或少:
data = tf.data.Dataset(...) \
.map(<expensive_augmentation>) \
.cache() \
# .shuffle().batch() etc. 然而,这意味着在data上的每一次迭代都会看到数据样本的相同的增广版本。相反,我想要做的是在几个时代中使用缓存,然后重新开始,或者等效地执行类似Dataset.map(<augmentation>).fleeting_cache().repeat(8)的操作。这是可能的吗?
发布于 2022-02-08 01:49:23
缓存生命周期与数据集相关联,因此您可以通过重新创建数据集来实现这一点:
def create_dataset():
dataset = tf.data.Dataset(...)
dataset = dataset.map(<expensive_augmentation>)
dataset = dataset.shuffle(...)
dataset = dataset.batch(...)
return dataset
for epoch in range(num_epochs):
# Drop the cache every 8 epochs.
if epoch % 8 == 0: dataset = create_dataset()
for batch in dataset:
train(batch)https://stackoverflow.com/questions/65037119
复制相似问题