从Torchvision预训练模型中获取模型类标签

时间:2020-08-15 18:00:40

标签: python pytorch torchvision

我使用的是Torchvision提供的经过预先训练的Alexnet模型(无微调)。 问题是,即使我能够对某些数据运行模型并获得输出概率分布,也无法找到将其映射到的类标签

遵循此official documentation

import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
model.eval()
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

按照一些处理图像的步骤,我能够使用它来获取单个图像的输出,作为(1,1000)暗淡矢量,我将使用softmax来获得概率分布-

#Output - 

tensor([-1.6531e+00, -4.3505e+00, -1.8172e+00, -4.2143e+00, -3.1914e+00,
         3.4163e-01,  1.0877e+00,  5.9350e+00,  8.0425e+00, -7.0242e-01,
        -9.4130e-01, -6.0822e-01, -2.4097e-01, -1.9946e+00, -1.5288e+00,
        -3.2656e+00, -5.5800e-01,  1.0524e+00,  1.9211e-01, -4.7202e+00,
        -3.3880e+00,  4.3048e+00, -1.0997e+00,  4.6132e+00, -5.7404e-03,
        -5.3437e+00, -4.7378e+00, -3.3974e+00, -4.1287e+00,  2.9064e-01,
        -3.2955e+00, -6.7051e+00, -4.7232e+00, -4.1778e+00, -2.1859e+00,
        -2.9469e+00,  3.0465e+00, -3.5882e+00, -6.3890e+00, -4.4203e+00,
        -3.3685e+00, -5.0983e+00, -4.9006e+00, -5.5235e+00, -3.7233e+00,
        -4.0204e+00,  2.6998e-01, -4.4702e+00, -5.6617e+00, -5.4880e+00,
        -2.6801e+00, -3.2129e+00, -1.6294e+00, -5.2289e+00, -2.7495e+00,
        -2.6286e+00, -1.8206e+00, -2.3196e+00, -5.2806e+00, -3.7652e+00,
        -3.0987e+00, -4.1421e+00, -5.2531e+00, -4.6505e+00, -3.5815e+00,
        -4.0189e+00, -4.0008e+00, -4.5512e+00, -3.2248e+00, -7.7903e+00,
        -1.4484e+00, -3.8347e+00, -4.5611e+00, -4.3681e+00,  2.7234e-01,
        -4.0162e+00, -4.2136e+00, -5.4524e+00,  1.1744e+00, -4.7785e+00,
        -1.8335e+00,  4.1288e-01,  2.2239e+00, -9.9919e-02,  4.8216e+00,
        -8.4304e-01,  5.6911e-01, -4.0484e+00, -3.3013e+00,  2.8698e+00,
        -1.1419e+00, -9.1690e-01, -2.9284e+00, -2.6097e+00, -1.8213e-01,
        -2.5429e+00, -2.1095e+00,  2.2419e+00, -1.6280e+00,  7.4458e+00,
         2.3184e+00, -5.7408e+00, -7.4332e-01, -5.4066e+00,  1.5177e+01,
        -4.4737e-02,  1.8237e+00, -3.7741e+00,  9.2271e-01, -4.3687e-01,
        -1.4003e+00, -4.3026e+00,  6.3782e-01, -1.0808e+00, -1.4173e+00,
         2.6194e+00, -3.8418e+00,  1.1598e+00, -2.6876e+00, -3.6103e+00,
        -4.9281e+00, -4.1411e+00, -3.3603e+00, -3.4296e+00, -1.4997e+00,
        -2.8381e+00, -1.2843e+00,  1.5745e+00, -1.7449e+00,  4.2903e-01,
         3.1234e-01, -2.8206e+00,  3.6688e-01, -2.1033e+00,  1.6481e+00,
         1.4222e+00, -2.7303e+00, -3.6292e+00,  1.2864e+00, -2.5541e+00,
        -2.9663e+00, -4.1575e+00, -3.1954e+00, -4.6487e-01,  1.8916e+00,
        -7.4721e-01,  4.5986e+00, -2.5443e+00, -6.2003e+00, -1.3215e+00,
        -2.6225e+00,  9.9639e+00,  9.7772e+00,  9.6715e+00,  9.0857e+00,...

我从哪里获得课程标签?我找不到任何可以从模型对象中获取方法的方法。

1 个答案:

答案 0 :(得分:0)

不幸的是,您不能直接从Torchvision模型获取类标签名称。但是,这些模型是在ImageNet数据集上训练的(因此有1000个类)。

据我所知,您必须从网络上获取类名映射;没有办法把它从火炬上拿下来。以前,您可以使用torchvision.datasets.ImageNet直接下载ImageNet,它具有一个内置的标签到类名转换器。现在,下载链接不再公开可用,需要手动下载,然后才能被数据集使用。ImageNet。

因此,您可以简单地在线搜索类以标记ImageNet的标签,而无需下载数据或尝试使用手电筒。 Try here for example