首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >了解如何在Keras中使用Albumentations

了解如何在Keras中使用Albumentations
EN

Stack Overflow用户
提问于 2022-02-24 12:28:35
回答 1查看 602关注 0票数 1

我正在尝试理解如何构建一个数据增强管道,并使用Albumentations来输入Keras模型。我遵循这个示例-> https://albumentations.ai/docs/examples/tensorflow-example/,其中他们创建了dataset对象PrefetchDataset并将其传递给model.fit()。见下面的代码:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from functools import partial
from albumentations import (
    Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast, HorizontalFlip,
    Rotate
)
AUTOTUNE = tf.data.experimental.AUTOTUNE

#load data
data, info= tfds.load(name="tf_flowers", split="train", as_supervised=True, with_info=True)
data

# augmentations
transforms = Compose([
            Rotate(limit=40),
            RandomBrightness(limit=0.1),
            JpegCompression(quality_lower=85, quality_upper=100, p=0.5),
            HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            RandomContrast(limit=0.2, p=0.5),
            HorizontalFlip(),
        ])


def aug_fn(image, img_size):
    data = {"image":image}
    aug_data = transforms(**data)
    aug_img = aug_data["image"]
    aug_img = tf.cast(aug_img/255.0, tf.float32)
    aug_img = tf.image.resize(aug_img, size=[img_size, img_size])
    return aug_img

def process_data(image, label, img_size):
    aug_img = tf.numpy_function(func=aug_fn, inp=[image, img_size], Tout=tf.float32)
    return aug_img, label

# create dataset
ds_alb = data.map(partial(process_data, img_size=120),
                  num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)


# ...And then just pass this dataset object to the model

def create_model(input_shape):
    pass
    # define model layers...

model = create_model(input_shape)

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics='accuracy', run_eagerly=True)
model.fit(ds_alb, epochs=2)

我的问题是:这个对象ds_alb是否在训练期间为每批返回不同的图像(根据随机参数设置)?我读过代码,似乎Compose中的所有增强都只执行了一次,在这里:

ds_alb = data.map(partial(process_data, img_size=120), num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

但我相信,建造这条管道的目的是在每一次传送不同的增强图像,而不是只增加一次.有件事我不明白。如何检查是否正在生成不同的图像?

EN

回答 1

Stack Overflow用户

发布于 2022-02-24 12:44:56

我不确定每个批tf.numpy_function是否会调用该函数一次还是多次。但是有一个简单的方法来测试它。在aug_fn中放置一个打印,并取两批:

代码语言:javascript
复制
def aug_fn(...):
    print("Called")
    ....

batch = next(ds_alb) #depending on the type of generator, you might need ds_alb[0]   
another_batch = next(ds_alb) #or ds_alb[1]

#also check the shape of the batch to make sure there are many images

如果“调用”出现了多次,它将对每个图像进行一次转换。如果只出现两次,则每批进行一次转换。如果它只出现一次,那么代码中就有一个问题。

老实说,我更喜欢“每批一个转换”,因为这将意味着性能上的巨大提高,这一点很重要。有时预处理可能是限制性能的操作。

,但这难道不超过增强的目的吗?

不是的!您向模型提供了许多批,并且想必,虽然您对整个批处理具有相同的预处理,但是您有不同的源映像!当然,你的下一批会有不同的预处理。

当然,由于您将对许多时代进行培训,同样的图像将在后面的一个具有不同预处理的时代再次出现。最后,通行证的数量将是如此之大,在一个批处理中使用相同的预处理将不会是一个模型的问题。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71251982

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档