首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在TF2中动态更新批处理规范动量?

如何在TF2中动态更新批处理规范动量?
EN

Stack Overflow用户
提问于 2020-12-10 10:51:15
回答 1查看 338关注 0票数 0

我找到了一个PyTorch实现,它将批处理规范momentum参数从第一个时期的0.1衰减到最后一个时期的0.001。对于如何在momentum中使用批处理规范TF2参数,有什么建议吗?(例如,从0.9开始,以0.999结束),这就是PyTorch代码中所做的:

代码语言:javascript
复制
# in training script
momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum))
model_pos_train.set_bn_momentum(momentum)

# model class function
def set_bn_momentum(self, momentum):
    self.expand_bn.momentum = momentum
    for bn in self.layers_bn:
        bn.momentum = momentum

解决方案:

下面所选的答案提供了使用tf.keras.Model.fit() API时的可行解决方案。然而,我使用的是一个定制的训练循环。以下是我所做的:

在每一个时代之后:

代码语言:javascript
复制
mi = 1 - initial_momentum  # i.e., inital_momentum = 0.9, mi = 0.1
mf = 1 - final_momentum  # i.e., final_momentum = 0.999, mf = 0.001
momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
model = set_bn_momentum(model, momentum)

set_bn_momentum函数(归功于这篇文章):

代码语言:javascript
复制
def set_bn_momentum(model, momentum):
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            print(layer.name, layer.momentum)
            setattr(layer, 'momentum', momentum)

    # When we change the layers attributes, the change only happens in the model config file
    model_json = model.to_json()

    # Save the weights before reloading the model.
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)

    # load the model from the config
    model = tf.keras.models.model_from_json(model_json)

    # Reload the model weights
    model.load_weights(tmp_weights_path, by_name=True)
    return model

这种方法没有给训练例程增加很大的开销。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-10 13:25:21

您可以在每个批处理的开始/结束中设置一个操作,这样您就可以在这个时期控制任何参数。

以下是回调的选项:

代码语言:javascript
复制
class CustomCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

你可以获得动量

代码语言:javascript
复制
batch = tf.keras.layers.BatchNormalization()
batch.momentum = 0.001

在模型中,必须指定正确的层。

代码语言:javascript
复制
model.layers[1].momentum = 0.001

您可以在回调找到更多的信息和示例。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65233132

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档