首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TF Dataset API用于图像增强

TF Dataset API用于图像增强
EN

Stack Overflow用户
提问于 2018-10-29 19:00:08
回答 2查看 2.7K关注 0票数 1

我正在使用tf Dataset API读取图像及其标签。我喜欢对图像进行多次图像增强,并增加我的训练数据大小。我现在所做的如下所示。

代码语言:javascript
复制
def flip(self, img, lbl):
  image = tf.image.flip_left_right(img)
  return image, lbl

def transpose(self, img, lbl):
  image = tf.image.transpose_image(img)
  return image, lbl

# just read and resize the image.
process_fn = lambda img, lbl: self.read_convert_image(img, lbl, self.args)
flip_fn = lambda img, lbl: self.flip(img,lbl)
transpose_fn = lambda img, lbl: self.transpose(img,lbl)

train_set = self.train_set.repeat()
train_set = train_set.shuffle(args.batch_size)
train_set = train_set.map(process_fn)

fliped_data = train_set.map(flip_fn)
transpose_data = train_set.map(transpose_fn)

train_set = train_set.concatenate(fliped_data)
train_set = train_set.concatenate(transpose_data)

train_set = train_set.batch(args.batch_size)
iterator = train_set.make_one_shot_iterator()

images, labels = iterator.get_next()

有没有更好的方法来做多次增强。上述方法的问题是,如果我添加更多的增强函数,则需要许多映射和连接。

谢谢

EN

回答 2

Stack Overflow用户

发布于 2018-10-30 16:51:34

如果你想自己做扩充,而不是依赖于Keras的ImageDataGenerator,你可以创建一个像img_aug这样的函数,然后在你的模型或Dataset API管道中使用它。下面的代码只是一个伪代码,但它展示了这个想法。你定义了所有的转换,然后你有了一些通用的阈值,超过这个阈值,你可以应用一个转换,并尝试将它们应用到X次(在下面的代码中是4次)

代码语言:javascript
复制
def img_aug(image):
  image = distorted_image

  def h_flip():
    return tf.image.flip_left_right(distorted_image)                
  def v_flip():
    return tf.image.flip_up_down(distorted_image)

  threshold = tf.constant(0.9, dtype=tf.float32)      

  def body(i, distorted_image):
    p_order = tf.random_uniform(shape=[2], minval=0., maxval=1., dtype=tf.float32)
    distorted_image = tf.case({                                      
                               tf.greater(p_order[0], threshold): h_flip,  
                               tf.greater(p_order[1], threshold): v_flip, 
                              }
                              ,default=identity, exclusive=False)
    return (i+1, distorted_image)

  def cond(i, *args):
    return i < 4 # max number of transformations

  parallel_iterations = 1
  tf.while_loop(cond, body, [0,distorted_image], 
                parallel_iterations=parallel_iterations)
  return distorted_image
票数 3
EN

Stack Overflow用户

发布于 2018-10-30 16:35:52

一种简单的图像增强方法是使用包含易于使用的apiTensorflow implemented Keras

它看起来像这样

代码语言:javascript
复制
ImageDataGenerator(rescale=1./255, 
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range = 0.2, 
    horizontal_flip = True)

您已经准备好尽可能多地使用增强的图像。

下面是一个可用的github代码示例Conv_net_with_augmentation

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53044066

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档