我正试着用PyTorch训练一个模型。有没有什么简单的方法可以从Tensorflow中创建一个像weighted_cross_entropy_with_logits这样的损失?
weighted_cross_entropy_with_logits中的pos_weight参数可以帮助平衡。但是在BCEWithLogitsLoss的参数列表中只有标签的权重。
发布于 2018-03-05 14:28:37
您可以根据需要编写自己的自定义损失函数。例如,您可以这样写:
def weighted_cross_entropy_with_logits(logits, target, pos_weight):
return targets * -logits.sigmoid().log() * pos_weight +
(1 - targets) * -(1 - logits.sigmoid()).log()这是一个基本的实现。您应该遵循前面提到的here步骤,以确保稳定性并避免溢出。只需使用他们导出的最终公式即可。
https://stackoverflow.com/questions/49069502
复制相似问题