首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras网络训练速度下降

Keras网络训练速度下降
EN

Stack Overflow用户
提问于 2019-02-09 17:55:42
回答 1查看 160关注 0票数 1

为什么每次我在Jupyter笔记本上构建新模型时,Keras模型的训练都需要more time。退出Jupyter/Python并重启会重置训练速度。每次我这样做,The scatterplot看起来都是一样的。

我正在使用Keras 'Sequential‘训练一个普通的MLP,输入层大约有6000个特征,3个隐藏的relu层(大小为2500,800,800),带有batchnorm和dropout,以及一个sigmoid输出,没有什么特别的。

我正在优化(使用GPyOpt,但当我在一个简单的for循环中构建模型时,效果也会出现),我给它提供了一个函数,每次引用它时,它都会构建一个新的上述Keras模型。在该函数中构建模型之前,它调用了函数limitmen(),因为否则我会遇到GPU内存问题:

代码语言:javascript
复制
def limit_mem():
    """
    Clear GPU-memory and tensorflow session.
    """
    K.get_session().close()
    cfg = K.tf.ConfigProto()
    cfg.gpu_options.allow_growth = True
    K.set_session(K.tf.Session(config=cfg))

我在stackoverflow上搜索了一些here后发现了这个函数

代码语言:javascript
复制
def f_beta(precision, recall, beta):
    f_beta_result = (1 + (beta ** 2)) * (precision * recall) / (((beta ** 2) * precision) + recall)
    if isinstance(f_beta_result, np.ndarray):
        f_beta_result[np.isnan(f_beta_result)] = 0
    else:
        if math.isnan(f_beta_result):
            f_beta_result = 0
    return f_beta_result

beta = 1.5  # define beta for f-score

class Metrics(Callback):
    def on_train_begin(self, logs={}):
        self.val_f1s = []
        self.val_f2s = []
        self.val_recalls = []
        self.val_precisions = []
        self.val_briers = []

    def on_epoch_end(self, epoch, logs={}):
        val_predict = (np.asarray(self.model.predict(X_val))).round()
        val_targ = y_val
        _val_precision, _val_recall, _val_f1, _support = precision_recall_fscore_support(val_targ, val_predict, labels=[0,1])
        _val_f2 = f_beta(_val_precision[1], _val_recall[1], beta)
        _val_brier = brier_score_loss(val_targ, val_predict)
#         print(_val_precision)
        self.val_f1s.append(_val_f1[1])
        self.val_f2s.append(_val_f2)
        self.val_recalls.append(_val_recall[1])
        self.val_precisions.append(_val_precision[1])
        self.val_briers.append(_val_brier)
#         print (' — val_f1: %.3f — val_precision: %.3f — val_recall %.3f' % (    _val_f1[1], _val_precision[1], _val_recall[1]))
        return

    def return_metrics(self):
        return self.val_f1s, self.val_f2s, self.val_recalls, self.val_precisions, self.val_briers, np.array(self.val_f2s).argmax()

metrics = Metrics()

# create model
def build_model(dropout=0.9, dense1=2500, dense2=800, dense3=800, lr=0.0001):
    model = Sequential()

    # first layer
    model.add(Dense(dense1, input_dim=X_train.shape[1], init='uniform'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(dropout))

    # second layer
    model.add(Dense(dense2, init='uniform'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(dropout))

    # third layer
    model.add(Dense(dense3, init='uniform'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(dropout))

    # final layer
    model.add(Dense(1, activation='sigmoid'))

    # Compile model
    adam = Adam(lr)
    model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy', 'mae'])
    return model

这或多或少就是构建keras模型的循环:

代码语言:javascript
复制
        for i in range(self.cycle):
            # actually build model
            t_before = time.time()
            self.keras_model = build_model(dropout, dense1, dense2, dense3, lr)

            # train model
            self.hist = self.keras_model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=self.epochs, 
                                             batch_size=1024, verbose=0, callbacks=[metrics], 
                                             class_weight={ 0 : 1, 1 : weight1 })
            t_after = time.time()

有没有人有过同样的经历?您需要更多信息吗?或者这是一个众所周知的问题,只有一个简单的解决方案(或者根本没有解决方案)?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-02-10 15:47:13

代码语言:javascript
复制
from keras import backend as K
K.clear_session() 

做了trick

对于那些制作教程的人来说:也许在代码的末尾添加这个是个好主意。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54605121

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档