首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras -对象检测模型- Xception与VGG

Keras -对象检测模型- Xception与VGG
EN

Stack Overflow用户
提问于 2022-09-28 09:19:12
回答 1查看 64关注 0票数 0

我正在使用Keras的预训练模型来训练对象检测模型(VGG16,VGG19,Xception,.)在具有YOLO坐标的4000多幅图像的数据集上,下面是对培训和验证数据以及模型编译和培训负责的部分数据预处理。

对于VGG16 & VGG19 -我将图像和YOLO坐标调整为推荐的默认图像大小224x224,而对于Xception和InceptionV3,我将调整大小为299x299。

我冻结了Keras应用程序的所有层,只添加了4个顶级密集层,这些层正在我的数据集上进行训练,以充分利用预先训练过的模型的潜力。当我使用VGG16或VGG19时,它工作得很好,我的训练和验证准确率超过了92%,这是很棒的,而且列车/ val的分割似乎是平衡的。然而,当我使用Xception或InceptionV3应用程序时,它总是以10%的精度提前结束,这一点我不明白。

代码语言:javascript
复制
IMAGE_SIZE = (299, 299)

train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)  # val 20%

val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)


train_data = train_datagen.flow_from_dataframe(
    dataframe=df_all, 
    directory=save_dir,                                               
    x_col="image_name", 
    y_col=['yolo_x', 'yolo_y', 'yolo_width', 'yolo_length'], 
    class_mode="raw", 
    target_size=IMAGE_SIZE,
    batch_size=32,
    shuffle=True,
    Subset='training'
)

val_data = val_datagen.flow_from_dataframe(
    dataframe=df_all, 
    directory=save_dir,                                               
    x_col="image_name", 
    y_col=['yolo_x', 'yolo_y', 'yolo_width', 'yolo_length'], 
    class_mode="raw", 
    target_size=IMAGE_SIZE,
    batch_size=32,
    shuffle=False,
    Subset='validation'
)

from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

learning_rate_reduction = ReduceLROnPlateau(monitor='loss', 
                                            patience=5, 
                                            verbose=2, 
                                            factor=0.5,                                            
                                            min_lr=0.000001)

early_stops = EarlyStopping(monitor='loss', 
                            min_delta=0, 
                            patience=10, 
                            verbose=2, 
                            mode='auto')

checkpointer = ModelCheckpoint(filepath = 'cis3115.{epoch:02d}-{accuracy:.6f}.hdf5',
                               verbose=2,
                               save_best_only=True, 
                               save_weights_only = True)


# Select a pre-trained model Xception
#pretrained_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False ,input_shape=[*IMAGE_SIZE, 3])
pretrained_model = tf.keras.applications.Xception(weights='imagenet', include_top=False ,input_shape=[*IMAGE_SIZE, 3])

# Set the following to False so that the pre-trained weights are not changed 
pretrained_model.trainable = False 

model = Sequential()
model.add(pretrained_model)

# Flatten 2D images into 1D data for final layers like traditional neural network
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(32, activation='relu'))

# The final output layer
# Use Sigmoid when predicting YOLO bounding box since that output is between 0 and 1
model.add(Dense(4, activation='sigmoid'))


print ("Pretrained model used:")
pretrained_model.summary()

print ("Final model created:")
model.summary()

# Compile neural network model
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])


# Train the model with the images in the folders
history = model.fit(
        train_data,
        validation_data=val_data,
        batch_size=16,                  # Number of image batches to process per epoch 
        epochs=100,                      # Number of epochs
        callbacks=[learning_rate_reduction, early_stops],
        )

Xception是更复杂的预先训练的模型,因此理论上应该更精确,因此我假设我在建立CNN时做错了什么。

Xception / Inception模型失败的原因是什么?我该换什么?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-10-01 03:22:54

这个问题似乎是在扁平层,因为它创造了大量的参数,它一直在失败。然而,当我用GlobalAveragePooling2D代替扁平时,它工作得很好。

因此,我将其替换为:

代码语言:javascript
复制
model.add(Flatten())

通过这一点:

代码语言:javascript
复制
model.add(GlobalAveragePooling2D())
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73879020

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档