首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pytorch中的多标签分类

pytorch中的多标签分类
EN

Stack Overflow用户
提问于 2018-10-17 21:16:40
回答 1查看 22.2K关注 0票数 18

我有一个多标签分类问题。我有11个类,大约4k个例子。每个示例可以有1到4-5个标签。目前,我正在使用log_loss分别为每个类训练一个分类器。正如你所期望的,训练11个分类器需要相当多的时间,我想尝试另一种方法,只训练1个分类器。其思想是,这个分类器的最后一层将有11个节点,并将按类输出实数,该实数将被sigmoid转换为proba。我想要优化的损失是所有类的log_loss的平均值。

不幸的是,我是pytorch的新手,即使通过阅读损失的源代码,我也不能确定现有的损失是否完全符合我的要求,或者我是否应该创建一个新的损失,如果是这样的话,我真的不知道该怎么做。

更具体地说,我希望为批次的每个元素提供一个大小为11的向量(每个标签包含一个实数(越接近无穷大,该类预测为1)),以及一个大小为11的向量(每个真标签包含1),并能够计算所有11个标签的平均log_loss,并基于该损失优化我的分类器。

任何帮助都将不胜感激:)

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-10-18 00:18:47

您正在寻找torch.nn.BCELoss。示例代码如下:

代码语言:javascript
复制
import torch

batch_size = 2
num_classes = 11

loss_fn = torch.nn.BCELoss()

outputs_before_sigmoid = torch.randn(batch_size, num_classes)
sigmoid_outputs = torch.sigmoid(outputs_before_sigmoid)
target_classes = torch.randint(0, 2, (batch_size, num_classes))  # randints in [0, 2).

loss = loss_fn(sigmoid_outputs, target_classes)

# alternatively, use BCE with logits, on outputs before sigmoid.
loss_fn_2 = torch.nn.BCEWithLogitsLoss()
loss2 = loss_fn_2(outputs_before_sigmoid, target_classes)
assert loss == loss2
票数 20
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52855843

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档