为了使用Google,我需要一个tf.dataset.Dataset。那么,如何在这样的数据集中使用数据增强呢?
更具体地说,到目前为止,我的代码是:
def get_dataset(batch_size=200):
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
try_gcs=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255.0
label = tf.one_hot(label,10)
return image, label
train_dataset = mnist_train.map(scale).shuffle(10000).batch(batch_size)
test_dataset = mnist_test.map(scale).batch(batch_size)
return train_dataset, test_dataset它被纳入到这方面:
# TPU Strategy ...
with strategy.scope():
model = create_model()
model.compile(loss="categorical_crossentropy",
optimizer="adam",
metrics=["acc"])
train_dataset, test_dataset = get_dataset()
model.fit(train_dataset,
epochs=20,
verbose=1,
validation_data=test_dataset)那么,我如何在这里使用数据增强呢?据我所知,我不能使用tf.keras ImageDataGenerator,对吗?
我试过以下几种方法,但都没有用。
data_generator = ...
model.fit_generator(data_generator.flow(train_dataset, batch_size=32),
steps_per_epoch=len(train_dataset) / 32, epochs=20)这并不奇怪,因为通常,train_x和train_y是作为流函数的两个参数输入的,而不是“打包”到一个tf.dataset.Dataset中。
发布于 2020-05-18 13:58:57
您可以使用tf.image函数。tf.image模块包含各种图像处理功能。
例如,:
您可以在函数def get_dataset中添加以下功能。
tf.float64范围内将每个图像转换为0-1。cache()结果,因为这些结果可以在每个repeat之后重复使用。random_flip_left_right。random_contrast随机改变图像对比度。repeat,图像数量增加了两倍。码-
mnist_train = mnist_train.map(
lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache(
).map(
lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
1000
).
batch(
batch_size
).repeat(2)类似地,您可以使用其他功能,如random_flip_up_down、random_crop函数,将图像垂直翻转(倒转),并将张量随机裁剪到给定的大小。
您的get_dataset函数如下所示-
def get_dataset(batch_size=200):
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
try_gcs=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
train_dataset = mnist_train.map(
lambda image, label: (tf.image.convert_image_dtype(image, tf.float32),label)
).cache(
).map(
lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
1000
).batch(
batch_size
).repeat(2)
test_dataset = mnist_test.map(scale).batch(batch_size)
return train_dataset, test_dataset添加@Andrew建议的链接,该链接在数据增强上给出了端到端的示例,该示例也使用mnist数据集。
希望这能回答你的问题。学习愉快。
https://stackoverflow.com/questions/61760235
复制相似问题