在vgg模型中修改正向功能

时间:2019-06-26 12:17:40

标签: python pytorch

我需要修改VGG16中现有的forward方法,以便它可以通过两个分类器并返回值

我尝试手动创建自定义转发方法并覆盖现有方法,但出现以下错误

vgg.forward = forward

forward()缺少1个必需的位置参数:'x'

我的自定义转发功能

def forward(self,x):
    x = self.features(x)
    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.classifier(x)
    y = self.classifier_2(x)
    return x,y

我用另外一个分类器将默认vgg16_bn修改为

vgg = models.vgg16_bn()
final_in_features = vgg.classifier[6].in_features
mod_classifier = list(vgg.classifier.children())[:-1]
mod_classifier.extend([nn.Linear(final_in_features, 10)])
vgg.add_module('classifier_2',vgg.classifier)

添加上述分类器后,我的模型如下所示

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
  (classifier_2): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )

我的卷积层结果应该通过两个单独的FFN层传递。那么我该如何修改我的前进通行证

1 个答案:

答案 0 :(得分:1)

我认为实现所需目标的最佳方法是创建扩展nn.Module的新模型。我会做类似的事情:

from torchvision import models
from torch import nn

class MyVgg (nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        vgg = models.vgg16_bn(pretrained=True)

        # Here you get the bottleneck/feature extractor
        self.vgg_feature_extractor = nn.Sequential(*list(vgg.children())[:-1])

        # Now you can include your classifiers
        self.classifier1 = nn.Sequential(layers1)
        self.classifier2 = nn.Sequential(layers2)

    # Set your own forward pass
    def forward(self, img, extra_info=None):

        x = self.vgg_convs(img)
        x = x.view(x.size(0), -1)
        x1 = self.classifier1(x)
        x2 = self.classifier2(x)

        return x1, x2

让我知道它是否对您有帮助。 祝你好运。

相关问题