首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow2.x 2.x中按子类tf.keras.losses.Loss类自定义损失

如何在Tensorflow2.x 2.x中按子类tf.keras.losses.Loss类自定义损失
EN

Stack Overflow用户
提问于 2020-05-14 14:12:49
回答 1查看 3.8K关注 0票数 4

当我在Tensorflow的网站上看到导轨时,我发现了两种自定义损失的方法。第一个是定义一个损失函数,就像:

代码语言:javascript
复制
def basic_loss_function(y_true, y_pred):
    return tf.math.reduce_mean(tf.abs(y_true - y_pred))

为了简单起见,我们假设批处理的大小也是1,所以y_truey_pred的形状都是(1,c),其中c是类的数目。在这种方法中,我们给出了两个向量y_truey_pred,并返回一个值(Scala)。

然后,第二个方法是子类tf.keras.losses.Loss类,指南中的代码是:

代码语言:javascript
复制
class WeightedBinaryCrossEntropy(keras.losses.Loss):
    """
    Args:
      pos_weight: Scalar to affect the positive labels of the loss function.
      weight: Scalar to affect the entirety of the loss function.
      from_logits: Whether to compute loss from logits or the probability.
      reduction: Type of tf.keras.losses.Reduction to apply to loss.
      name: Name of the loss function.
    """
    def __init__(self, pos_weight, weight, from_logits=False,
                 reduction=keras.losses.Reduction.AUTO,
                 name='weighted_binary_crossentropy'):
        super().__init__(reduction=reduction, name=name)
        self.pos_weight = pos_weight
        self.weight = weight
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        ce = tf.losses.binary_crossentropy(
            y_true, y_pred, from_logits=self.from_logits)[:,None]
        ce = self.weight * (ce*(1-y_true) + self.pos_weight*ce*(y_true))
        return ce

在调用方法中,我们通常给出两个向量y_truey_pred,但是我注意到它返回ce,它是一个形状为(1,c)的向量!

那么在上面的玩具例子中有什么问题吗?或者Tensorflow2.x 2.x背后有魔力?

EN

回答 1

Stack Overflow用户

发布于 2020-05-15 07:27:12

除了实现之外,两者之间的主要区别是损失函数的类型。第一种是L1损失(定义绝对差的平均值,主要用于类似问题的回归),第二种是二进制交叉熵(用于分类)。它们并不是相同损失的不同实现,这在您所链接的指南中有说明。

在多标签、多类分类设置中,二进制交叉熵为每个类输出一个值,就好像它们彼此独立一样。

编辑:

在第二个损失函数中,reduction参数控制输出的聚合方式,例如。默认情况下,您的代码使用keras.losses.Reduction.AUTO,如果您检查源代码,这将转换为对批处理的求和。这意味着,最终的损失将是一个向量,但是还有其他可用的减少,您可以在文档中检查它们。我相信,即使你不定义约简取损失向量中损失元素之和,TF优化器也会这样做,以避免向量反向传播的错误。向量的反向传播会导致权重问题,从而“贡献”每一个损失元素。但是,我还没有在源代码中检查这一点。:)

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61799546

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档