Loading a GPU-trained model onto the CPU in PyTorch

Date: 2020-07-02 06:01:51

Tags: python pytorch

Here is the code I use when trying to load a GPU-trained model onto the CPU:

model_conv.load_state_dict(torch.load(model_file, map_location='cpu'))
model_conv = model_conv.cpu()

The error message is:

Traceback (most recent call last):
  File "prediction.py", line 269, in <module>
    model_conv.load_state_dict(torch.load(resume_file, map_location='cpu'))
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 229, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 377, in _load
    result = unpickler.load()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 348, in persistent_load
    data_type(size), location)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 246, in restore_location
    result = map_location(storage, location)
TypeError: 'str' object is not callable

My PyTorch version is 0.1.12_1. Any ideas how to fix this? I have already checked "how to load the gpu trained model into the cpu?", but that solution does not seem to work in my case.

Any suggestions are appreciated!
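A possible explanation, inferred from the traceback rather than confirmed: the frame `result = map_location(storage, location)` (serialization.py, line 246) suggests that PyTorch 0.1.12_1 calls a non-dict `map_location` as a function, so the string `'cpu'` fails. The pure-Python sketch below is a hypothetical, simplified re-creation of that call, not the real PyTorch source; `restore_location`, `describe`, and the plain string standing in for a storage object are illustrative assumptions.

```python
# Hypothetical, simplified stand-in for the code path at
# torch/serialization.py line 246 in the traceback; NOT the real PyTorch source.
# A plain string plays the role of a storage object.


def restore_location(map_location, storage, location):
    # In this old code path, a non-dict map_location is called like a function.
    result = map_location(storage, location)
    if result is None:
        # Stand-in for falling back to the default (CUDA) restore.
        result = "default restore of %s on %s" % (storage, location)
    return result


def describe(map_location):
    # Try restoring a storage that was saved with the tag 'cuda:0'.
    try:
        return restore_location(map_location, "weights", "cuda:0")
    except TypeError as exc:
        return "TypeError: %s" % exc


print(describe("cpu"))                         # TypeError: 'str' object is not callable
print(describe(lambda storage, loc: storage))  # weights
```

If this reading is right, the corresponding real call would presumably be `torch.load(model_file, map_location=lambda storage, loc: storage)`, the callable form the `torch.load` documentation uses to load all tensors onto the CPU: the lambda simply returns the storage where it was deserialized, instead of moving it to a GPU.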

-- Update -- Without the map_location argument, the error message is:

Traceback (most recent call last):
  File "prediction.py", line 268, in <module>
    model_conv.load_state_dict(torch.load(resume_file))
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 229, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 377, in _load
    result = unpickler.load()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 348, in persistent_load
    data_type(size), location)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 85, in default_restore_location
    result = fn(storage, location)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/serialization.py", line 67, in _cuda_deserialize
    return obj.cuda(device_id)
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/_utils.py", line 57, in _cuda
    with torch.cuda.device(device):
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/cuda/__init__.py", line 124, in __enter__
    _lazy_init()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/cuda/__init__.py", line 84, in _lazy_init
    _check_driver()
  File "/home/ubuntu/anaconda2/lib/python2.7/site-packages/torch/cuda/__init__.py", line 58, in _check_driver
    http://www.nvidia.com/Download/index.aspx""")
AssertionError: 
Found no NVIDIA driver on your system. Please check that you
have an NVIDIA GPU and installed a driver from
http://www.nvidia.com/Download/index.aspx
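Without `map_location`, the frames above show the default path: the storage was saved with a CUDA location tag, and `_cuda_deserialize` calls `obj.cuda(device_id)`, which needs an NVIDIA driver. The sketch below is again hypothetical pure Python rather than PyTorch code (the function names and the `'cuda:0'` tag are assumptions for illustration). It shows how a dict `map_location` that renames that tag to `'cpu'` would bypass the CUDA branch, assuming the installed version accepts the dict form.

```python
# Hypothetical, simplified stand-in for default_restore_location and the
# dict form of map_location; NOT the real PyTorch source.


def default_restore_location(storage, location):
    if location.startswith("cuda"):
        # The real code calls storage.cuda(device_id) here, which raises
        # when no NVIDIA driver is installed.
        raise AssertionError("Found no NVIDIA driver on your system.")
    return storage  # 'cpu' tag: keep the storage on the CPU


def restore_with_dict(map_location, storage, location):
    # Rename the saved location tag if the dict has an entry for it.
    location = map_location.get(location, location)
    return default_restore_location(storage, location)


def describe(restore):
    # Try restoring a storage that was saved with the tag 'cuda:0'.
    try:
        return restore("weights", "cuda:0")
    except AssertionError as exc:
        return "AssertionError: %s" % exc


print(describe(default_restore_location))
# AssertionError: Found no NVIDIA driver on your system.
print(describe(lambda s, loc: restore_with_dict({"cuda:0": "cpu"}, s, loc)))
# weights
```

With PyTorch itself this would correspond to something like `torch.load(model_file, map_location={'cuda:0': 'cpu'})`, assuming the model was saved from `cuda:0`; the dict key has to match the location tag actually stored in the file.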
