在Tensorflow addons中,有两次提到三元组损失,一个是基类tfa.losses.triplet_semihard_loss,另一个是tfa.losses.TripletSemiHardLoss,它是由用户初始化的子类,反过来隐式地调用基类。在这段属于子类的代码中:
def __init__(self, margin=1.0, name=None):
super(TripletSemiHardLoss, self).__init__(
name=name, reduction=tf.keras.losses.Reduction.NONE)
self.margin = margin
def call(self, y_true, y_pred):
return triplet_semihard_loss(y_true, y_pred, self.margin)我不明白call方法是怎么回事,它返回基类函数,给出了y_true和y_pred ndarray,但是它们到底是从哪里来的呢?根据Tensorflow文档指南,子类在模型compile语句中初始化为:
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tfa.losses.TripletSemiHardLoss())然后将模型拟合为:
history = model.fit(
train_dataset,
epochs=5)train_dataset结构是一个包含嵌入数据和相应的整数标签的元组,但是子类如何认识到这是要操作的数据呢?那么call方法也是隐式调用的吗?
https://stackoverflow.com/questions/59243942
复制相似问题