
Does Keras have a way to stop training immediately?
Stack Overflow user
Asked on 2020-06-03 09:00:56
2 answers · 1.8K views · 0 following · Score 4

I am writing a custom early-stopping callback for my tf.keras training. For this I can set the variable self.model.stop_training = True in a callback method such as on_epoch_end(). However, Keras only stops training once the current epoch has finished, even if I set this variable in the middle of an epoch, e.g. in on_batch_end().

So my question is: does Keras have a way to stop training immediately, even in the middle of the current epoch?

2 Answers

Stack Overflow user

Accepted answer

Posted on 2020-06-03 11:45:43

You can use the model.stop_training attribute to stop training.

For example, if we want to stop training at the third batch of the second epoch, you can do it like this:

import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
import numpy as np
import pandas as pd

class My_Callback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        # Remember the current epoch index so on_batch_end can check it
        self.epoch = epoch

    def on_batch_end(self, batch, logs=None):
        if self.epoch == 1 and batch == 3:
            print(f"\nStopping at Epoch {self.epoch}, Batch {batch}")
            self.model.stop_training = True


X_train = np.random.random((100, 3))
y_train = pd.get_dummies(np.argmax(X_train[:, :3], axis=1)).values

clf = Sequential()
clf.add(Dense(9, activation='relu', input_dim=3))
clf.add(Dense(3, activation='softmax'))
clf.compile(loss='categorical_crossentropy', optimizer=SGD())

clf.fit(X_train, y_train, epochs=10, batch_size=16, callbacks=[My_Callback()])

Output:

Epoch 1/10
100/100 [==============================] - 0s 337us/step - loss: 1.0860
Epoch 2/10
 16/100 [===>..........................] - ETA: 0s - loss: 1.0830
Stopping at Epoch 1, Batch 3
<keras.callbacks.callbacks.History at 0x7ff2e3eeee10>
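As a side note, in current tf.keras the batch-level hook is named on_train_batch_end. A hedged sketch of the same idea for tf.keras 2.x (the StopAtBatch name and the toy data are illustrative, not from the answer above):

```python
import numpy as np
import tensorflow as tf

class StopAtBatch(tf.keras.callbacks.Callback):
    """Stop training once a given batch of a given epoch has finished."""
    def __init__(self, stop_epoch=1, stop_batch=3):
        super().__init__()
        self.stop_epoch = stop_epoch
        self.stop_batch = stop_batch

    def on_epoch_begin(self, epoch, logs=None):
        # Remember the current epoch so the batch hook can check it
        self.epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        if self.epoch == self.stop_epoch and batch == self.stop_batch:
            self.model.stop_training = True

X_train = np.random.random((100, 3))
y_train = tf.keras.utils.to_categorical(np.argmax(X_train, axis=1), num_classes=3)

clf = tf.keras.Sequential([
    tf.keras.layers.Dense(9, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])
clf.compile(loss='categorical_crossentropy', optimizer='sgd')
history = clf.fit(X_train, y_train, epochs=10, batch_size=16,
                  verbose=0, callbacks=[StopAtBatch()])
```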
Score: 1

Stack Overflow user

Posted on 2020-06-03 11:17:44

In Keras you can use EarlyStopping to stop training when a monitored quantity has stopped improving. From your question it is not clear what condition you want to stop on. If you just want to monitor a value as EarlyStopping does, but stop after a batch instead of after an epoch, you can rewrite the EarlyStopping class, implementing the logic in on_batch_end instead of on_epoch_end. (Note that per-batch logs typically only contain training metrics such as loss, so you would monitor e.g. 'loss' rather than 'val_loss'.)

import warnings

import numpy as np
from keras.callbacks import Callback


class EarlyBatchStopping(Callback):

    def __init__(self,
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 verbose=0,
                 mode='auto',
                 baseline=None,
                 restore_best_weights=False):
        super(EarlyBatchStopping, self).__init__()

        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def on_batch_end(self, batch, logs=None):
        current = self.get_monitor_value(logs)
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = batch  # records the batch index at which we stopped
                self.model.stop_training = True
                if self.restore_best_weights:
                    if self.verbose > 0:
                        print('Restoring model weights from the end of '
                              'the best batch')
                    self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

    def get_monitor_value(self, logs):
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                'Early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' %
                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
        return monitor_value

If you have different logic, you can use on_batch_end and set self.model.stop_training = True according to your own condition — but I think you get the idea.
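For example, a minimal sketch of batch-level stopping on a custom condition, assuming tf.keras 2.x where the batch hook is named on_train_batch_end (the StopOnLowLoss name, the toy data, and the 0.5 threshold are all illustrative choices, not from the answer):

```python
import numpy as np
import tensorflow as tf

class StopOnLowLoss(tf.keras.callbacks.Callback):
    """Stop training as soon as a batch's training loss falls below a threshold."""
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def on_train_batch_end(self, batch, logs=None):
        # Per-batch logs contain training metrics such as 'loss'
        if logs is not None and logs.get('loss', np.inf) < self.threshold:
            self.model.stop_training = True

X_train = np.random.random((100, 3))
y_train = tf.keras.utils.to_categorical(np.argmax(X_train, axis=1), num_classes=3)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(9, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])
model.compile(loss='categorical_crossentropy', optimizer='sgd')
model.fit(X_train, y_train, epochs=50, batch_size=16,
          verbose=0, callbacks=[StopOnLowLoss(threshold=0.5)])
```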

Score: 2
Original content provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/62168914
