首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow Keras,如何将减肥添加到不可训练的变量中?

Tensorflow Keras,如何将减肥添加到不可训练的变量中?
EN

Stack Overflow用户
提问于 2019-07-27 17:08:34
回答 1查看 902关注 0票数 1

我有我的自定义损失类和一个回调来更新权重,这是我从这里这里获得的。第二个链接有点不太适合我的情况,因为我们需要访问丢失历史记录和准确性,以便更新权重,所以我认为从第一个链接回调是最好的方法。

这是我得到的密码

代码语言:javascript
复制
class AdaptiveLossCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        super(AdaptiveLossCallback, self).__init__()
        self.weight1 = tf.Variable(1.0, trainable=False, name='weight1', dtype=tf.float32)
        self.weight2 = tf.Variable(0.0, trainable=False, name='weight2', dtype=tf.float32)

    def on_epoch_end(self, epoch, logs=None):
        if epoch == 49:
            self.weight1 = tf.assign(self.weight1 , tf.constant(0.5))
            self.weight2 = tf.assign(self.weight2 , tf.constant(0.5))
        elif epoch == 74:
            self.weight1 = tf.assign(self.weight1 , tf.constant(0.0))
            self.weight2 = tf.assign(self.weight2 , tf.constant(1.0))


class CustomLoss(tf.keras.losses.Loss):
    def __init__(self,
                 adaptive_loss=None,
                 from_logits=False,
                 reduction=losses_utils.ReductionV2.AUTO,
                 name=None):
        super(CustomLoss, self).__init__(reduction=reduction)
        self.from_logits = from_logits
        self.adaptive_loss = adaptive_loss

    def call(self, y_true, y_pred):
        ...
        weight1 = self.adaptive_loss.weight1
        weight2 = self.adaptive_loss.weight2
        return weight1 * loss1 + weight2 * loss2

但我似乎不能让它起作用。当我运行这个时,我会说

尝试使用未初始化的值weight1

我试过这个之后

代码语言:javascript
复制
session = tf.keras.backend.get_session()
session.run(tf.global_variables_initializer())
model.fit(...)

它似乎有效,但权重值根本没有更新。

我做错了什么,怎么才能解决这个问题?是否有更好的方法向Keras模型中添加可变变量?

谢谢

PS。我不能使用Keras模型loss_weights,因为我只有一个输出

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-07-28 07:39:51

问题是损失函数中的权重引用不是用tf.assign更新的。要适当更新损失系数,可以执行以下操作:

( a) K.set_value(self.weightX, update_value)

( b) sess.run(self.weightX.assign(update_tensor))

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57234360

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档