首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >NotImplementedError:学习进度计划必须覆盖get_config

NotImplementedError:学习进度计划必须覆盖get_config
EN

Stack Overflow用户
提问于 2020-05-02 09:17:37
回答 1查看 2.5K关注 0票数 6

我使用tf.keras创建了一个自定义计划,在保存模型时遇到了这个错误:

NotImplementedError:学习进度计划必须覆盖get_config

这门课看起来如下:

代码语言:javascript
复制
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps**-1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        config = {
            'd_model':self.d_model,
            'warmup_steps':self.warmup_steps

        }
        base_config = super(CustomSchedule, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-05-21 05:30:33

当您使用自定义子类模型时,保存模型体系结构有点棘手。相反,只使用Model.save_weights()来保存权重比较容易。

如果将代码更改为此,您将不会看到该错误:

代码语言:javascript
复制
  def get_config(self):
    config = {
    'd_model': self.d_model,
    'warmup_steps': self.warmup_steps,

     }
    return config
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61557024

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档