首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >"tfa.losses.triplet_semihard_loss“是如何被调用的?

"tfa.losses.triplet_semihard_loss“是如何被调用的?
EN

Stack Overflow用户
提问于 2019-12-09 14:59:13
回答 1查看 818关注 0票数 0

在Tensorflow addons中,有两次提到三元组损失,一个是基类tfa.losses.triplet_semihard_loss,另一个是tfa.losses.TripletSemiHardLoss,它是由用户初始化的子类,反过来隐式地调用基类。在这段属于子类的代码中:

代码语言:javascript
复制
    def __init__(self, margin=1.0, name=None):
        super(TripletSemiHardLoss, self).__init__(
            name=name, reduction=tf.keras.losses.Reduction.NONE)
        self.margin = margin

    def call(self, y_true, y_pred):
        return triplet_semihard_loss(y_true, y_pred, self.margin)

我不明白call方法是怎么回事,它返回基类函数,给出了y_truey_pred ndarray,但是它们到底是从哪里来的呢?根据Tensorflow文档指南,子类在模型compile语句中初始化为:

代码语言:javascript
复制
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())

然后将模型拟合为:

代码语言:javascript
复制
history = model.fit(
    train_dataset,
    epochs=5)

train_dataset结构是一个包含嵌入数据和相应的整数标签的元组,但是子类如何认识到这是要操作的数据呢?那么call方法也是隐式调用的吗?

EN

回答 1

Stack Overflow用户

发布于 2019-12-09 19:30:03

当类的实例被调用时,__call__被调用。y_truey_pred分别包含真实的标签和模型预测的标签。Tensorflow(tf.keras)在内部将您提供给y_true的标签转换为看到的here,并使用model.fit()对数据进行训练。

所有tf.keras损失都是以这种形式实现的,即具有两个参数y_truey_pred的函数,如here所示。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59243942

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档