首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras CNN模型训练时间

Keras CNN模型训练时间
EN

Stack Overflow用户
提问于 2018-12-18 01:01:20
回答 2查看 971关注 0票数 0

我已经创建了一个二进制图像分类器,当我运行我的代码时,我会得到1到2个小时的训练时间。我怎样才能减少呢?这就是我的代码:

代码语言:javascript
复制
from keras.layers import Conv2D, MaxPooling2D, Input
from keras.layers import Input, Dense, Flatten
from keras.models import Model

num_classes = 2 
# This returns a tensor
inputs = Input(shape=(150,150,3))

x = Conv2D(16,(1,1), padding = 'same', activation = 'relu',)(inputs)
x = Conv2D(16,(3,3), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D((3,3), strides = (1,1), padding = 'same')(inputs)

x = Conv2D(32,(1,1), padding = 'same', activation = 'relu',)(inputs)
x = Conv2D(32,(3,3), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D((3,3), strides = (1,1), padding = 'same')(inputs)

x = Conv2D(64,(1,1), padding = 'same', activation = 'relu',)(inputs)
x = Conv2D(64,(3,3), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D((3,3), strides = (1,1), padding = 'same')(inputs)

x = Conv2D(128,(1,1), padding = 'same', activation = 'relu',)(inputs)
x = Conv2D(128,(3,3), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D((3,3), strides = (1,1), padding = 'same')(inputs)


x = Flatten()(inputs)
predictions = Dense(num_classes, activation='sigmoid')(x)

# This creates a model that includes
# the Input layer and three Dense layers
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
          loss='sparse_categorical_crossentropy', #https://github.com/keras- 
team/keras/issues/5034
          #loss='binary_crossentropy',
          metrics=['accuracy'])

我使用ImageDataGenerator对我的图像进行预处理:

代码语言:javascript
复制
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    "/content/drive/apagdata/train",
    target_size=(150,150),
    batch_size=32,
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    "/content/drive/apagdata/test",
    target_size=(150,150),
    batch_size=32,
    class_mode='binary')

model.fit_generator(
    train_generator,
    steps_per_epoch=2000,
    epochs=50,
    validation_data=validation_generator,
    validation_steps=800)
model.save_weights('first_try.h5')

我的数据集总共只有169张图片。原因是,在对所有数据实现我的模型之前,我正在尝试构建一个基本模型。

EN

回答 2

Stack Overflow用户

发布于 2018-12-18 01:28:57

参见下面的更新

除了使用GPU之外,还有一些可能的速度改进:

  1. 如果图像是二进制的,那么您的输入可以是Input(shape=(150,150)),而不是3通道。
  2. 与其使用16 -- 32 -- 64 -- 128 conv过滤器,不如尝试更小的(如8 -- 16 -- 16 -- 32 ),检查何时停止获得更高的验证准确性。
  3. 实际上,您可以完全放弃1 x 1卷积。1 x 1卷积通常用于3 x 35 x 5卷积之前的降维。在你的情况下,这可能是不必要的。即使您想要使用它,1 x 1卷积也应该输出较少的信道数,以便3 x 3卷积可以在较小的信道上工作,输出更多的通道,从而使整个过程更快。例如,150x150x256输入-> 3x3x256卷积比150x150x256输入-> 1x1x64卷积-> 3x3x256卷积慢。看看如何将256-D输出首先映射到64-D,然后再映射到256-D。你不能这么做
  4. 您可以通过在max池操作中增加对(2,2)的步幅来更多地对输出进行子示例,至少在最后三个max池层中是这样。
  5. 与速度无关,但是如果您使用的是带有一个热标签的categorical_crossentropy,则可能需要使用softmax而不是sigmoid (除非您在多标签中有问题)。

在GPU上,这应该不会超过10分钟,因为你只有170张图片。

更新

实际上,我并没有注意到架构本身的正确性,但是您的评论让我产生了不同的想法。

你的网络基本上是这样的:

代码语言:javascript
复制
inputs = Input(shape=(150,150,3))
x = MaxPooling2D((3,3), strides = (1,1), padding = 'same')(inputs)
x = Flatten()(inputs)
predictions = Dense(num_classes, activation='sigmoid')(x)

这是因为,不是将第二次卷积的输出作为最大池的输入,而是将其丢弃,并再次对输入执行最大值池。见此:

代码语言:javascript
复制
x = Conv2D(16,(1,1), padding = 'same', activation = 'relu',)(inputs)
x = Conv2D(16,(3,3), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D((3,3), strides = (1,1), padding = 'same')(inputs)  # should be x here instead of inputs

除此之外,@Alaroff的答案是主要原因。

票数 0
EN

Stack Overflow用户

发布于 2018-12-18 17:27:44

您使用的是带有fit_generatorbatch_size=32,并指定了steps_per_epoch=2000,这意味着您最终将向网络提供总共64000幅图像。您通常希望使用类似于每一个时代的唯一图像,例如steps_per_epoch = no_train_samples // batch_size + 1 (同样也适用于validation_steps)。但是,考虑到您只有169个训练样本,您可能需要进行更多的步骤,以尽量减少在不同时期之间的列车和验证阶段之间切换的Keras开销。

另外,考虑在训练之前降低训练图像的采样(您的输入是150×150 px ),以节省解压缩开销(特别是如果您有大的PNG文件要处理)。

此外,在Colab上,并不总是授予整个GPU,但可能与其他用户共享。这一事实也可能影响整体业绩。

最后,要注意你的数据是从哪里来的。这是你安装的谷歌硬盘吗?的低带宽/高延迟I/O也可能是上的一个问题。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53825007

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档