首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在PyTorch中模拟weighted_cross_entropy_with_logits

在PyTorch中模拟weighted_cross_entropy_with_logits
EN

Stack Overflow用户
提问于 2018-03-02 20:52:12
回答 1查看 1K关注 0票数 1

我正试着用PyTorch训练一个模型。有没有什么简单的方法可以从Tensorflow中创建一个像weighted_cross_entropy_with_logits这样的损失?

weighted_cross_entropy_with_logits中的pos_weight参数可以帮助平衡。但是在BCEWithLogitsLoss的参数列表中只有标签的权重。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-03-05 14:28:37

您可以根据需要编写自己的自定义损失函数。例如,您可以这样写:

代码语言:javascript
复制
def weighted_cross_entropy_with_logits(logits, target, pos_weight):
    return targets * -logits.sigmoid().log() * pos_weight + 
               (1 - targets) * -(1 - logits.sigmoid()).log()

这是一个基本的实现。您应该遵循前面提到的here步骤,以确保稳定性并避免溢出。只需使用他们导出的最终公式即可。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49069502

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档