
TensorFlow image augmentation: datagen - ValueError

Stack Overflow user
Asked on 2021-12-07 16:04:42
1 answer, 177 views, 0 followers, 0 votes

Using TensorFlow 2.6, Python 3.9 and the CIFAR-10 dataset, I am trying to train a simple model defined as follows:

Code language: Python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def conv6_cnn():
    """
    Function to define the architecture of a neural network model
    following the Conv-6 architecture for the CIFAR-10 dataset
    (the model is later pruned).
    
    Conv-6 architecture:
    64, 64, pool   -- convolutional layers
    128, 128, pool -- convolutional layers
    256, 256, pool -- convolutional layers
    256, 256, 10   -- fully connected layers
    
    Output: returns the designed neural network model (not yet compiled)
    """
    
    model = Sequential()
    
    model.add(
        Conv2D(
            filters = 64, kernel_size = (3, 3),
            activation='relu', kernel_initializer = tf.keras.initializers.GlorotNormal(),
            strides = (1, 1), padding = 'same',
            input_shape=(32, 32, 3)
        )    
    )
        
    model.add(
        Conv2D(
            filters = 64, kernel_size = (3, 3),
            activation='relu', kernel_initializer = tf.keras.initializers.GlorotNormal(),
            strides = (1, 1), padding = 'same'
        )
    )
    
    model.add(
        MaxPooling2D(
            pool_size = (2, 2),
            strides = (2, 2)
        )
    )
    
    model.add(
        Conv2D(
            filters = 128, kernel_size = (3, 3),
            activation='relu', kernel_initializer = tf.keras.initializers.GlorotNormal(),
            strides = (1, 1), padding = 'same'
        )
    )

    model.add(
        Conv2D(
            filters = 128, kernel_size = (3, 3),
            activation='relu', kernel_initializer = tf.keras.initializers.GlorotNormal(),
            strides = (1, 1), padding = 'same'
        )
    )

    model.add(
        MaxPooling2D(
            pool_size = (2, 2),
            strides = (2, 2)
        )
    )

    model.add(
        Conv2D(
            filters = 256, kernel_size = (3, 3),
            activation='relu', kernel_initializer = tf.keras.initializers.GlorotNormal(),
            strides = (1, 1), padding = 'same'
        )
    )

    model.add(
        Conv2D(
            filters = 256, kernel_size = (3, 3),
            activation='relu', kernel_initializer = tf.keras.initializers.GlorotNormal(),
            strides = (1, 1), padding = 'same'
        )
    )

    model.add(
        MaxPooling2D(
            pool_size = (2, 2),
            strides = (2, 2)
        )
    )
    
    model.add(Flatten())
    
    model.add(
        Dense(
            units = 256, activation = 'relu',
            kernel_initializer = tf.keras.initializers.GlorotNormal()
        )
    )
    
    model.add(
        Dense(
            units = 256, activation = 'relu',
            kernel_initializer = tf.keras.initializers.GlorotNormal()
        )
    )
    
    model.add(
        Dense(
            units = 10, activation = 'softmax'
        )
    )
    
    return model

# Initialize a Conv-6 CNN object-
model = conv6_cnn()

# Define data Augmentation using ImageDataGenerator:

# Initialize and define the image data generator-
datagen = ImageDataGenerator(
    # featurewise_center=True,
    # featurewise_std_normalization=True,
    rotation_range = 90,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    horizontal_flip = True
)

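# Compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied).
# This only matters when featurewise_center, featurewise_std_normalization
# or zca_whitening is enabled; with the options above it changes nothing.
# X_train holds the CIFAR-10 training images (loading code not shown).
datagen.fit(X_train)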

# Compile defined model; optimizer and loss_fn are defined elsewhere (not shown)-
model.compile(
    optimizer = optimizer,
    loss = loss_fn,
    metrics = ['accuracy']
    )

# Define early stopping criterion-
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor = 'val_loss', min_delta = 0.001,
    patience = 4, verbose = 0,
    mode = 'auto', baseline = None,
    restore_best_weights = True
)
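As a quick sanity check of the Conv-6 layout in the docstring, the feature-map sizes and parameter counts can be worked out by hand. The sketch below is plain Python arithmetic, not Keras code. It assumes 3x3 kernels with biases, 'same' padding (so each convolution keeps the spatial size) and 2x2 pooling with stride 2 (so each pooling layer halves it), as in conv6_cnn() above. The helper names conv_params, dense_params and conv6_summary are hypothetical.

```python
def conv_params(k, c_in, c_out):
    """Conv2D weights (k*k*c_in*c_out) plus one bias per filter."""
    return k * k * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Dense weight matrix plus one bias per unit."""
    return n_in * n_out + n_out


def conv6_summary(input_hw=32, input_channels=3):
    """Return (final spatial size, Flatten() width, total parameter count)."""
    hw, channels, total = input_hw, input_channels, 0
    for filters in (64, 128, 256):            # three conv-conv-pool stages
        total += conv_params(3, channels, filters)
        total += conv_params(3, filters, filters)
        channels = filters
        hw //= 2                               # MaxPooling2D(pool_size=2, strides=2)
    flat = hw * hw * channels                  # width of the Flatten() output
    for n_in, n_out in ((flat, 256), (256, 256), (256, 10)):
        total += dense_params(n_in, n_out)
    return hw, flat, total


print(conv6_summary())  # → (4, 4096, 2262602)
```

Under these assumptions, model.summary() for the network above should report 2,262,602 trainable parameters, about 1.05 million of which sit in the first Dense layer fed by the 4096-wide Flatten() output.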

When I train this CNN model without any data augmentation, using the following code, it seems to work without problems:

Code language: Python
# Train model without any data augmentation-
history = model.fit(
    x = X_train, y = y_train,
    batch_size = batch_size, epochs = num_epochs,
    callbacks = [early_stopping],
    validation_data = (X_test, y_test)
    )
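In this run the early_stopping callback has something to monitor, because passing validation_data makes Keras report val_loss after every epoch. The rule it applies with the settings above (minimize val_loss, min_delta=0.001, patience=4, no baseline) can be sketched in plain Python. This is a simplified model of tf.keras.callbacks.EarlyStopping written for illustration, not its actual source; the function name early_stopping_epoch and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, min_delta=0.001, patience=4):
    """Simplified model of EarlyStopping(monitor='val_loss', mode='auto',
    baseline=None): val_loss is minimized, and an epoch only counts as an
    improvement if it beats the best value so far by more than min_delta.

    Returns (epoch at which training stops, or None if it never stops,
             epoch with the best val_loss)."""
    best, best_epoch, wait = float("inf"), None, 0
    for epoch, current in enumerate(val_losses):
        wait += 1
        if current < best - min_delta:   # real improvement: remember it, reset counter
            best, best_epoch, wait = current, epoch, 0
        if wait >= patience:             # `patience` epochs without improvement
            return epoch, best_epoch
    return None, best_epoch


# Epochs 2 and 3 improve by less than min_delta, so they do not reset the counter.
losses = [1.0, 0.9, 0.8995, 0.8992, 0.85, 0.86, 0.87, 0.88, 0.89, 0.80]
print(early_stopping_epoch(losses))  # → (8, 4)
```

With restore_best_weights=True, Keras then rolls the model back to the weights saved at the best epoch (epoch 4 in this made-up sequence; epochs are 0-based, as in the callback's epoch argument).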

However, when using data (image) augmentation:

Code language: Python
# Train model on batches with real-time data augmentation-
training_history = model.fit(
    datagen.flow(
        X_train, y_train,
        batch_size = batch_size, subset = 'training'
        ),
        validation_data = (X_test, y_test),
        steps_per_epoch = len(X_train) / batch_size,
        epochs = num_epochs,
        callbacks = [early_stopping]
        )

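it raises the error:

ValueError: Training and validation subsets have different number of classes after the split. If your numpy arrays are sorted by the label, you might want to shuffle them.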
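The message comes from a consistency check that ImageDataGenerator.flow() runs whenever subset is given: it cuts the arrays at int(len(x) * validation_split), treats the first part as the validation subset and the rest as the training subset, and requires both parts to contain the same set of label values. The sketch below is a simplified NumPy re-implementation of that check, modeled on the keras-preprocessing NumpyArrayIterator rather than copied from it; the function name check_split and the label arrays are hypothetical.

```python
import numpy as np


def check_split(y, validation_split, subset):
    """Simplified model of the class check done by ImageDataGenerator.flow()
    when `subset` is given (illustrative, not the Keras source)."""
    if subset not in ("training", "validation"):
        raise ValueError("subset must be 'training' or 'validation'")
    split_idx = int(len(y) * validation_split)  # first split_idx samples -> validation
    if not np.array_equal(np.unique(y[:split_idx]), np.unique(y[split_idx:])):
        raise ValueError("Training and validation subsets have different "
                         "number of classes after the split.")
    return y[:split_idx] if subset == "validation" else y[split_idx:]


labels = np.arange(100) % 10        # 0, 1, ..., 9, 0, 1, ... (well mixed)
sorted_labels = np.sort(labels)     # 0, 0, ..., 1, 1, ... (sorted by class)

cases = (("mixed", labels, 0.0), ("sorted", sorted_labels, 0.2), ("mixed", labels, 0.2))
for name, y, split in cases:
    try:
        kept = check_split(y, split, "training")
        print(f"{name} labels, validation_split={split}: {len(kept)} training samples")
    except ValueError as err:
        print(f"{name} labels, validation_split={split}: ValueError: {err}")
```

The first case is the one in the question: with the default validation_split of 0.0 the validation slice is empty, so the class sets can never match. The second case is what the hint about shuffling refers to: label-sorted arrays leave whole classes on one side of the cut even with a non-zero split.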

1 answer

Stack Overflow user

Accepted answer

Posted on 2021-12-08 09:56:49

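You just need to remove the argument subset='training', because you have not set validation_split in the ImageDataGenerator. The two parameters must be set together for this to work; otherwise you cannot use them. With the default validation_split of 0.0, the validation part of the split is empty, so the check that both subsets contain the same classes always fails and raises the ValueError above. The corrected code: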

Code language: Python
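import tensorflow as tf

# `model` is the Conv-6 network returned by conv6_cnn() from the question
model = conv6_cnn()

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)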

datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    # featurewise_center=True,
    # featurewise_std_normalization=True,
    rotation_range = 90,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    horizontal_flip = True
)

datagen.fit(x_train)

# Compile defined model-
model.compile(
    optimizer = tf.keras.optimizers.Adam(),
    loss = tf.keras.losses.CategoricalCrossentropy(),
    metrics = ['accuracy']
    )

# Define early stopping criterion-
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor = 'val_loss', min_delta = 0.001,
    patience = 4, verbose = 0,
    mode = 'auto', baseline = None,
    restore_best_weights = True
)
batch_size = 32

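training_history = model.fit(
    datagen.flow(
        x_train, y_train,
        batch_size = batch_size
    ),
    # EarlyStopping monitors 'val_loss', so validation data must be passed
    validation_data = (x_test, y_test),
    steps_per_epoch = len(x_train) // batch_size,
    epochs = 2,
    callbacks = [early_stopping]
)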
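The answer also changes steps_per_epoch from len(X_train) / batch_size to len(x_train) // batch_size. The difference is easiest to see with CIFAR-10's 50,000 training images and batch_size = 32, the values used above. The iterator returned by datagen.flow() yields ceil(n / batch_size) batches per epoch, the last one partial; the variable names below are illustrative.

```python
import math

n_train = 50_000   # CIFAR-10 training images
batch_size = 32    # as in the answer

float_steps = n_train / batch_size                   # question: 1562.5 (not an integer)
floor_steps = n_train // batch_size                  # answer: 1562 full batches
batches_per_epoch = math.ceil(n_train / batch_size)  # batches the iterator yields: 1563
skipped = n_train - floor_steps * batch_size         # images in the unused partial batch

print(float_steps, floor_steps, batches_per_epoch, skipped)  # → 1562.5 1562 1563 16
```

Because the object returned by flow() is a Keras Sequence whose len() equals batches_per_epoch, steps_per_epoch can also simply be omitted; Keras then runs one full pass over the data per epoch, including the partial last batch.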

For more information, check the documentation.
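If you do want ImageDataGenerator to hold out part of the training data, the rule above means setting both parameters: validation_split in the constructor and subset in each flow() call. A minimal sketch, assuming random stand-in arrays instead of CIFAR-10 and a 20% split (both are illustrative choices, not part of the original answer):

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 100 random 32x32 RGB images, labels interleaved (not sorted).
rng = np.random.default_rng(0)
x = rng.random((100, 32, 32, 3), dtype=np.float32)
y = tf.keras.utils.to_categorical(np.arange(100) % 10, 10)

# validation_split reserves the FIRST 20% of the arrays for subset='validation'.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=90,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    validation_split=0.2,
)

train_iter = datagen.flow(x, y, batch_size=10, subset="training")    # last 80 samples
val_iter = datagen.flow(x, y, batch_size=10, subset="validation")    # first 20 samples

print(len(train_iter), len(val_iter))  # → 8 2
```

The model could then be trained with model.fit(train_iter, validation_data=val_iter, ...). Both iterators come from the same generator, so the validation batches are augmented as well, and the arrays should be shuffled before calling flow(), because the split simply takes the first 20% of them.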

Votes: 2
Page content originally provided by Stack Overflow; translation support by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/70263348
