首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow数据增强

Tensorflow数据增强
EN

Stack Overflow用户
提问于 2020-08-08 01:10:32
回答 1查看 665关注 0票数 0

我想转换这个keras数据增强工作流:

代码语言:javascript
复制
datagen = ImageDataGenerator( 
    rescale=1./255,
    rotation_range = 10,
    horizontal_flip = True,
    width_shift_range=0.1,
    height_shift_range=0.1,
    fill_mode = 'nearest')

这是一个代码片段,但这两个函数都不起作用,因为它不支持批处理维度!

代码语言:javascript
复制
import numpy as np
def augment(x, y):
    x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
    x = tf.keras.preprocessing.image.random_rotation(
    x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
    interpolation_order=1)
    return x, y

X = np.random.random(size=(256, 48, 48, 1))
y = np.random.randint(0, 7, size=(256,))
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.map(augment)
dataset = dataset.batch(16, drop_remainder=False)
dataset = dataset.prefetch(buffer_size=1)
EN

回答 1

Stack Overflow用户

发布于 2020-08-08 04:06:35

在运行您的代码时,我得到了以下错误:AttributeError: 'Tensor' object has no attribute 'ndim'。使用tf.data.Dataset运行augment函数似乎是不可能的,因为它不能处理张量。一种解决方法是将增强函数包装在tf.py_function

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

def augment(x, y):
    x = x.numpy()
    x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
    x = tf.keras.preprocessing.image.random_rotation(
    x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
    interpolation_order=1)
    return x, y

X = np.random.random(size=(256, 48, 48, 1))
y = np.random.randint(0, 7, size=(256,))

dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.map(
    lambda x, y: tf.py_function(
        func=augment,
        inp=[x, y],
        Tout=[tf.float32, tf.int64]))
dataset = dataset.batch(16, drop_remainder=False)
dataset = dataset.prefetch(buffer_size=1)

上面的代码应该运行时没有任何错误。如果您经常需要用tf.py_function包装您的函数,那么编写一个装饰器会很方便(也很干净)。如下所示:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

def map_decorator(func):
    def wrapper(*args):
        return tf.py_function(
            func=func,
            inp=[*args],
            Tout=[a.dtype for a in args])
    return wrapper

@map_decorator
def augment(x, y):
    x = x.numpy()
    x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
    x = tf.keras.preprocessing.image.random_rotation(
    x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
    interpolation_order=1)
    return x, y

X = np.random.random(size=(256, 48, 48, 1))
y = np.random.randint(0, 7, size=(256,))

dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.map(augment)
dataset = dataset.batch(16, drop_remainder=False)
dataset = dataset.prefetch(buffer_size=1)

希望它能帮上忙!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63306389

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档