首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch RetinaNet列车模型输入

PyTorch RetinaNet列车模型输入
EN

Stack Overflow用户
提问于 2022-10-18 07:58:52
回答 1查看 45关注 0票数 0

我有model = torchvision.models.detection.retinanet_resnet50_fpn_v2(progress=True),想要对它进行自定义数据的培训。为了得到损失,我必须

代码语言:javascript
复制
classification_loss, regression_loss = model(images, targets)

我已经为images创建了一个批处理张量,但是在我的生活中,无法找到我应该如何格式化targets以用于对象检测.每个目标都有一个边框和一个类标签。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-10-18 08:42:57

查看本官方教程:https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

一般来说,targets是一个list of dict,例如

代码语言:javascript
复制
targets = [
    {
        "boxes": torch.as_tensor([[xmin, ymin, xmax, ymax]], dtype=torch.float32),
        "labels": torch.as_tensor([1,], dtype=torch.int64)
    }
]
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74107599

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档