Pytorch:如何将模型动物园预训练模型映射到新GPU

时间:2018-05-19 22:37:02

标签: pytorch

我正在尝试加载一个预先训练过的模型

model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'}

当我使用以下代码时,它总是将模型加载到cuda:0。如果我想将它加载到cuda:3?

,该怎么办?
model = ResNet(BasicBlock, [3, 4, 6, 3]) 
device = 3
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], 
                      map_location=lambda storage, loc: storage.cuda(device)))

1 个答案:

答案 0 :(得分:0)

这应该适合你:

device = torch.device('cuda')
model = ResNet(BasicBlock, [3, 4, 6, 3]) 
with torch.cuda.device(3):
    model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], 
                          map_location=lambda storage, loc: storage.cuda(device)))

我认为这适用于版本0.4.0及更高版本,您可以查看0.4.0中的其他一些示例。迁移指南: https://pytorch.org/2018/04/22/0_4_0-migration-guide.html