首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从火炬视觉预训练模型中获取模型类标签

从火炬视觉预训练模型中获取模型类标签
EN

Stack Overflow用户
提问于 2020-08-15 18:00:40
回答 1查看 2.5K关注 0票数 2

我正在使用一个预训练的Alexnet模型(没有微调)从火炬视觉。问题是,即使我能够在某些数据上运行模型并获得输出概率分布,但我无法找到类标签来将其映射到

遵循这个正式文件

代码语言:javascript
复制
import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
model.eval()
代码语言:javascript
复制
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

按照处理图像的一些步骤,我可以用它作为一个(1,1000)小向量得到单个图像的输出,我将使用softmax得到一个概率分布-

代码语言:javascript
复制
#Output - 

tensor([-1.6531e+00, -4.3505e+00, -1.8172e+00, -4.2143e+00, -3.1914e+00,
         3.4163e-01,  1.0877e+00,  5.9350e+00,  8.0425e+00, -7.0242e-01,
        -9.4130e-01, -6.0822e-01, -2.4097e-01, -1.9946e+00, -1.5288e+00,
        -3.2656e+00, -5.5800e-01,  1.0524e+00,  1.9211e-01, -4.7202e+00,
        -3.3880e+00,  4.3048e+00, -1.0997e+00,  4.6132e+00, -5.7404e-03,
        -5.3437e+00, -4.7378e+00, -3.3974e+00, -4.1287e+00,  2.9064e-01,
        -3.2955e+00, -6.7051e+00, -4.7232e+00, -4.1778e+00, -2.1859e+00,
        -2.9469e+00,  3.0465e+00, -3.5882e+00, -6.3890e+00, -4.4203e+00,
        -3.3685e+00, -5.0983e+00, -4.9006e+00, -5.5235e+00, -3.7233e+00,
        -4.0204e+00,  2.6998e-01, -4.4702e+00, -5.6617e+00, -5.4880e+00,
        -2.6801e+00, -3.2129e+00, -1.6294e+00, -5.2289e+00, -2.7495e+00,
        -2.6286e+00, -1.8206e+00, -2.3196e+00, -5.2806e+00, -3.7652e+00,
        -3.0987e+00, -4.1421e+00, -5.2531e+00, -4.6505e+00, -3.5815e+00,
        -4.0189e+00, -4.0008e+00, -4.5512e+00, -3.2248e+00, -7.7903e+00,
        -1.4484e+00, -3.8347e+00, -4.5611e+00, -4.3681e+00,  2.7234e-01,
        -4.0162e+00, -4.2136e+00, -5.4524e+00,  1.1744e+00, -4.7785e+00,
        -1.8335e+00,  4.1288e-01,  2.2239e+00, -9.9919e-02,  4.8216e+00,
        -8.4304e-01,  5.6911e-01, -4.0484e+00, -3.3013e+00,  2.8698e+00,
        -1.1419e+00, -9.1690e-01, -2.9284e+00, -2.6097e+00, -1.8213e-01,
        -2.5429e+00, -2.1095e+00,  2.2419e+00, -1.6280e+00,  7.4458e+00,
         2.3184e+00, -5.7408e+00, -7.4332e-01, -5.4066e+00,  1.5177e+01,
        -4.4737e-02,  1.8237e+00, -3.7741e+00,  9.2271e-01, -4.3687e-01,
        -1.4003e+00, -4.3026e+00,  6.3782e-01, -1.0808e+00, -1.4173e+00,
         2.6194e+00, -3.8418e+00,  1.1598e+00, -2.6876e+00, -3.6103e+00,
        -4.9281e+00, -4.1411e+00, -3.3603e+00, -3.4296e+00, -1.4997e+00,
        -2.8381e+00, -1.2843e+00,  1.5745e+00, -1.7449e+00,  4.2903e-01,
         3.1234e-01, -2.8206e+00,  3.6688e-01, -2.1033e+00,  1.6481e+00,
         1.4222e+00, -2.7303e+00, -3.6292e+00,  1.2864e+00, -2.5541e+00,
        -2.9663e+00, -4.1575e+00, -3.1954e+00, -4.6487e-01,  1.8916e+00,
        -7.4721e-01,  4.5986e+00, -2.5443e+00, -6.2003e+00, -1.3215e+00,
        -2.6225e+00,  9.9639e+00,  9.7772e+00,  9.6715e+00,  9.0857e+00,...

我从哪里得到分类标签?我找不到任何方法让我从模型对象中得到它。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-08-15 18:31:32

不幸的是,您不能直接从torchvision模型中获取类标签名称。但是,这些模型是在ImageNet数据集(因此是1000个类)上进行培训的。

据我所知,您必须将类名映射到web上;没有办法将其从火炬上删除。以前,您可以使用ImageNet直接下载torchvision.datasets.ImageNet,它有一个内置的标签到类名转换器。现在下载链接是不可公开的,需要手动下载,然后datasets.ImageNet才能使用它。

因此,您可以简单地搜索类来标记ImageNet在线映射,而不是下载数据或尝试使用torch。例如,在这里尝试

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63429260

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档