首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于tf.dataset.Dataset的数据增强

基于tf.dataset.Dataset的数据增强
EN

Stack Overflow用户
提问于 2020-05-12 19:23:02
回答 1查看 5.4K关注 0票数 6

为了使用Google,我需要一个tf.dataset.Dataset。那么,如何在这样的数据集中使用数据增强呢?

更具体地说,到目前为止,我的代码是:

代码语言:javascript
复制
def get_dataset(batch_size=200):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                             try_gcs=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0

    label = tf.one_hot(label,10)

    return image, label

  train_dataset = mnist_train.map(scale).shuffle(10000).batch(batch_size)
  test_dataset = mnist_test.map(scale).batch(batch_size)

  return train_dataset, test_dataset

它被纳入到这方面:

代码语言:javascript
复制
# TPU Strategy ...
with strategy.scope():
  model = create_model()
  model.compile(loss="categorical_crossentropy",
                optimizer="adam",
                metrics=["acc"])

train_dataset, test_dataset = get_dataset()

model.fit(train_dataset,
          epochs=20,
          verbose=1,
          validation_data=test_dataset)

那么,我如何在这里使用数据增强呢?据我所知,我不能使用tf.keras ImageDataGenerator,对吗?

我试过以下几种方法,但都没有用。

代码语言:javascript
复制
data_generator = ...

model.fit_generator(data_generator.flow(train_dataset, batch_size=32),
                    steps_per_epoch=len(train_dataset) / 32, epochs=20)

这并不奇怪,因为通常,train_x和train_y是作为流函数的两个参数输入的,而不是“打包”到一个tf.dataset.Dataset中。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-05-18 13:58:57

您可以使用tf.image函数。tf.image模块包含各种图像处理功能。

例如,

您可以在函数def get_dataset中添加以下功能。

  • tf.float64范围内将每个图像转换为0-1
  • cache()结果,因为这些结果可以在每个repeat之后重复使用。
  • 使用left_to_right随机翻转每个图像使用random_flip_left_right
  • 利用random_contrast随机改变图像对比度。
  • 通过重复所有步骤的repeat,图像数量增加了两倍。

码-

代码语言:javascript
复制
mnist_train = mnist_train.map(
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache(
).map(
    lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
    1000
).
batch(
    batch_size
).repeat(2)

类似地,您可以使用其他功能,如random_flip_up_downrandom_crop函数,将图像垂直翻转(倒转),并将张量随机裁剪到给定的大小。

您的get_dataset函数如下所示-

代码语言:javascript
复制
def get_dataset(batch_size=200):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                             try_gcs=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  train_dataset = mnist_train.map(
               lambda image, label: (tf.image.convert_image_dtype(image, tf.float32),label)
              ).cache(
              ).map(
                    lambda image, label: (tf.image.random_flip_left_right(image), label)
              ).map(
                    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
              ).shuffle(
                    1000
              ).batch(
                    batch_size
              ).repeat(2)

  test_dataset = mnist_test.map(scale).batch(batch_size)

  return train_dataset, test_dataset

添加@Andrew建议的链接,该链接在数据增强上给出了端到端的示例,该示例也使用mnist数据集。

希望这能回答你的问题。学习愉快。

票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61760235

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档