我已经开始使用tf.data.Dataset作为将数据加载到keras模型中的一种方式,因为它们看起来比keras的ImageDataGenerator快得多,而且比数组上的培训更节省内存。
一个人认为我无法理解我的想法,那就是我似乎找不到一种方法来访问数据集的len()。Keras的ImageDataGenerator有一个名为n的属性,用于此目的。这使得我的代码非常难看,因为我需要在片段的各个部分对长度进行硬编码(例如,了解一个时代有多少次迭代)。
有什么我能解决这个问题的主意吗?
一个示例脚本:
# Generator
def make_mnist_train_generator(batch_size):
(x_train, y_train), (_,_) = tf.keras.mnist.load_data()
x_train = x_train.reshape((-1, 28, 28, 1))
x_train = x_train.astype(np.float32) / 225.
y_train = tf.keras.utils.to_categorical(y_train, 10)
ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
ds = ds.shuffle(buffer_size=len(x_train))
ds = ds.repeat()
ds = ds.batch(batch_size=batch_size)
ds = ds.prefetch(buffer_size=1)
return ds
model = ... # create a tf.keras model
batch_size 256
gen = make_mnist_train_generator(batch_size)
# Training
model.fit(gen, epochs=50, steps_per_epoch=60000//batch_size+1) # Hard coded size of generator发布于 2020-12-29 01:25:25
tl;dr
不幸的是,tf.data.Dataset是一个生成器,而且没有,没有找到其大小的固有方法。
但是..。
一般来说,当您使用.from_tensor_slices()时,您可以通过在这个方法中添加的参数来知道它的大小,在您的例子中是x_train。您唯一的问题是在函数中创建它。
要绕过这个问题,您可以做的一个简单的方法是自己添加一个__len__属性!我发现你能做到的最简单的方法是:
ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})在你的例子中,它看起来应该是这样的:
def make_mnist_train_generator(batch_size):
(x_train, y_train), (_,_) = tf.keras.mnist.load_data()
x_train = x_train.reshape((-1, 28, 28, 1))
x_train = x_train.astype(np.float32) / 225.
y_train = tf.keras.utils.to_categorical(y_train, 10)
ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
ds = ds.shuffle(buffer_size=len(x_train))
ds = ds.repeat()
ds = ds.batch(batch_size=batch_size)
ds = ds.prefetch(buffer_size=1)
ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})
return ds
gen = make_mnist_train_generator(batch_size)
model.fit(gen, epochs=50, steps_per_epoch=len(gen)//batch_size+1) # Hard coded size of generator为什么要这么做?
我在过去做过这件事,它令人惊讶地有用。您希望您的生成器有一个len()的原因有很多。一些例子是:
如果希望将生成器放在单独的模块中,并导入生成器
的数据的其他人使用
https://stackoverflow.com/questions/65486078
复制相似问题