我有我的自定义损失类和一个回调来更新权重,这是我从这里,这里获得的。第二个链接有点不太适合我的情况,因为我们需要访问丢失历史记录和准确性,以便更新权重,所以我认为从第一个链接回调是最好的方法。
这是我得到的密码
class AdaptiveLossCallback(tf.keras.callbacks.Callback):
def __init__(self):
super(AdaptiveLossCallback, self).__init__()
self.weight1 = tf.Variable(1.0, trainable=False, name='weight1', dtype=tf.float32)
self.weight2 = tf.Variable(0.0, trainable=False, name='weight2', dtype=tf.float32)
def on_epoch_end(self, epoch, logs=None):
if epoch == 49:
self.weight1 = tf.assign(self.weight1 , tf.constant(0.5))
self.weight2 = tf.assign(self.weight2 , tf.constant(0.5))
elif epoch == 74:
self.weight1 = tf.assign(self.weight1 , tf.constant(0.0))
self.weight2 = tf.assign(self.weight2 , tf.constant(1.0))
class CustomLoss(tf.keras.losses.Loss):
def __init__(self,
adaptive_loss=None,
from_logits=False,
reduction=losses_utils.ReductionV2.AUTO,
name=None):
super(CustomLoss, self).__init__(reduction=reduction)
self.from_logits = from_logits
self.adaptive_loss = adaptive_loss
def call(self, y_true, y_pred):
...
weight1 = self.adaptive_loss.weight1
weight2 = self.adaptive_loss.weight2
return weight1 * loss1 + weight2 * loss2但我似乎不能让它起作用。当我运行这个时,我会说
尝试使用未初始化的值weight1
我试过这个之后
session = tf.keras.backend.get_session()
session.run(tf.global_variables_initializer())
model.fit(...)它似乎有效,但权重值根本没有更新。
我做错了什么,怎么才能解决这个问题?是否有更好的方法向Keras模型中添加可变变量?
谢谢
PS。我不能使用Keras模型loss_weights,因为我只有一个输出
发布于 2019-07-28 07:39:51
问题是损失函数中的权重引用不是用tf.assign更新的。要适当更新损失系数,可以执行以下操作:
( a) K.set_value(self.weightX, update_value)
或
( b) sess.run(self.weightX.assign(update_tensor))
https://stackoverflow.com/questions/57234360
复制相似问题