
ResNet model takes too long to train

Stack Overflow user
Asked 2020-09-26 00:42:21
Answers: 2 · Views: 491 · Followers: 0 · Votes: 0

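I'm following this tutorial to learn transfer learning for my model. As you can see, after the first epoch each of his epochs takes about 1 second.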

```text
Epoch 1/100
1080/1080 [==============================] - 10s 10ms/step - loss: 3.6862 - acc: 0.2000
Epoch 2/100
1080/1080 [==============================] - 1s 1ms/step - loss: 3.0746 - acc: 0.2574
Epoch 3/100
1080/1080 [==============================] - 1s 1ms/step - loss: 2.6839 - acc: 0.3185
Epoch 4/100
1080/1080 [==============================] - 1s 1ms/step - loss: 2.3929 - acc: 0.3583
Epoch 5/100
1080/1080 [==============================] - 1s 1ms/step - loss: 2.1382 - acc: 0.3870
Epoch 6/100
1080/1080 [==============================] - 1s 1ms/step - loss: 1.7810 - acc: 0.4593
```

However, when I follow almost the same code for my CIFAR model, a single epoch takes about an hour to run.

```text
Train on 50000 samples
 3744/50000 [=>............................] - ETA: 43:38 - loss: 3.3223 - acc: 0.1760
```

My code is:

```python
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras import Model

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

base_model = ResNet50(weights= None, include_top=False, input_shape= (32,32,3))

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.4)(x)
predictions = Dense(10 , activation= 'softmax')(x)
model = Model(inputs = base_model.input, outputs = predictions)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

hist = model.fit(x_train, y_train)
```

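Note that I'm using the CIFAR-10 dataset for this model. Is something wrong with my code or my data? How can I improve this? One hour per epoch is far too long. I also have an NVIDIA MX-110 (2 GB), which TensorFlow is of course using.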


2 Answers

Stack Overflow user

Accepted answer

Posted 2020-09-26 14:47:30

I copied and ran your code, but to get it to run I had to make the following changes:

```python
import tensorflow as tf  # missing from the question's code, which uses tf.keras
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras import Model

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print(len(x_train))
x_train = x_train / 255.0
x_test = x_test / 255.0

# to_categorical was never imported in the question, so use the full path
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

base_model = ResNet50(weights=None, include_top=False, input_shape=(32, 32, 3))

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.4)(x)
predictions = Dense(10, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

# No batch_size given, so Keras uses its default of 32
hist = model.fit(x_train, y_train, epochs=2)
```

The result for 2 epochs (the first line, 50000, is printed by print(len(x_train))):

```text
50000
Epoch 1/2
1563/1563 [==============================] - 58s 37ms/step - loss: 2.8654 - acc: 0.2537
Epoch 2/2
1563/1563 [==============================] - 51s 33ms/step - loss: 2.5331 - acc: 0.2748
```

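According to the model.fit documentation, if you don't specify a batch size it defaults to 32, so 50,000 samples / 32 gives 1563 steps (1562.5 rounded up, since the last batch is partial). For some reason, in your code the batch size defaulted to 1; I don't know why. So set batch_size=50, and then you will need 1000 steps. To speed things up further, I would set weights="imagenet" and freeze the layers in the base model: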

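```python
# Use pretrained ImageNet weights instead of weights=None
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(32, 32, 3))

# Freeze all base-model layers so only the new classification head is trained.
# Do this before model.compile(): changes to trainable take effect when the model is (re)compiled.
for layer in base_model.layers:
    layer.trainable = False

# ...build the same GlobalAveragePooling2D / Dropout / Dense head and compile as above...

hist = model.fit(x_train, y_train, batch_size=50, epochs=2)
```

With batch_size=50, weights="imagenet", and the base model frozen, you get:

```text
50000
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 5s 0us/step
Epoch 1/2
1000/1000 [==============================] - 16s 16ms/step - loss: 2.5101 - acc: 0.1487
Epoch 2/2
1000/1000 [==============================] - 10s 10ms/step - loss: 2.1159 - acc: 0.2249
```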
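The step counts in this answer come from ceiling division: Keras runs ceil(num_samples / batch_size) batches per epoch, with a smaller final batch when the division is not exact. The minimal plain-Python sketch below checks those numbers and roughly recovers the logged epoch times from the rounded per-step timings. The helper names are hypothetical illustrations, not Keras APIs.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches in one epoch; the last batch may be partial."""
    return math.ceil(num_samples / batch_size)


def estimated_epoch_seconds(num_samples: int, batch_size: int, seconds_per_step: float) -> float:
    """Rough epoch time from a per-step timing such as '37ms/step' in a Keras log."""
    return steps_per_epoch(num_samples, batch_size) * seconds_per_step


CIFAR10_TRAIN = 50_000  # size of the CIFAR-10 training set

print(steps_per_epoch(CIFAR10_TRAIN, 32))  # 1563 (Keras default batch size)
print(steps_per_epoch(CIFAR10_TRAIN, 50))  # 1000
print(steps_per_epoch(CIFAR10_TRAIN, 1))   # 50000 (one sample per step)
print(round(estimated_epoch_seconds(CIFAR10_TRAIN, 32, 0.037)))  # 58
```

The estimate is only approximate because Keras rounds the per-step time it prints, and the first epoch also includes one-off setup cost.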
Votes: 1

Stack Overflow user

Posted 2020-09-26 01:49:51

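It doesn't look like you are batching your data. As a result, each forward pass of the model sees only one training example, which is very inefficient.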

Try setting a batch size in the model.fit() call:

```python
# num_epochs is a placeholder: set it to the number of epochs you want to train
hist = model.fit(x_train, y_train, batch_size=16, epochs=num_epochs,
                 validation_data=(x_test, y_test), shuffle=True)
```

Tune your batch size so that it is the largest that fits in your GPU's memory: try a few different values, then settle on one.
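Trying "a few different values" can be made systematic. The minimal sketch below assumes a hypothetical predicate fits(batch_size) that reports whether one training step at that size runs without running out of GPU memory. In TensorFlow such a predicate might run a single step and return False on tf.errors.ResourceExhaustedError, but that wiring is an assumption and is not shown here. The search doubles the batch size until a failure, then binary-searches the gap, and it assumes memory use grows monotonically with batch size.

```python
from typing import Callable


def largest_batch_size(fits: Callable[[int], bool], start: int = 16, limit: int = 4096) -> int:
    """Return the largest batch size in [start, limit] for which fits() is True.

    fits is a hypothetical, caller-supplied predicate. It is assumed to be
    monotonic: if a batch size fits in memory, every smaller one fits too.
    """
    if not fits(start):
        raise ValueError(f"batch_size={start} does not fit; try a smaller start")

    # Phase 1: keep doubling while the doubled size stays within the limit and fits.
    lo = start
    while lo * 2 <= limit and fits(lo * 2):
        lo *= 2

    # hi is an exclusive upper bound: the doubled size that failed,
    # or one past the limit if doubling stopped at the limit.
    hi = min(lo * 2, limit + 1)

    # Phase 2: binary search strictly between lo (known to fit) and hi.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    return lo


if __name__ == "__main__":
    # Stand-in predicate for illustration only: pretend anything up to 200 fits.
    print(largest_batch_size(lambda b: b <= 200))  # 200
```

In practice you might round the result down (for example to a power of two) to leave some memory headroom on a small card.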

Votes: 1
Original page content provided by Stack Overflow; translation support by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/64068204
