将基于GPU构建的theano模型转换为CPU?

时间:2014-08-11 06:28:42

标签: cuda gpu pickle theano deep-learning

我有一些基于gpu的深度学习模型的pickle文件。我正试图在生产中使用它们。但是,当我尝试在服务器上取消它们时,我收到以下错误。

  

追踪(最近的呼叫最后):
    文件“score.py”,第30行,在       model =(cPickle.load(file))
    在CudaNdarray_unpickler中输入文件“/usr/local/python2.7/lib/python2.7/site-packages/Theano-0.6.0-py2.7.egg/theano/sandbox/cuda/type.py”,第485行登记/>       return cuda.CudaNdarray(npa)
  AttributeError :(“'NoneType'对象没有属性'CudaNdarray'”,(数组([[0.011515,0.01171047,0.10408644,..., - 0.0343636,
           0.04944979,-0.06583775],
         [-0.03771918,0.080524,-0.10609912,...,0.11019105,
          -0.0570752,0.02100536],
         [-0.03628891,-0.07109226,-0.00932018,...,0.04316209,
           0.02817888,0.05785328],
         ...,
         [0.0703947,-0.00172865,-0.05942701,..., - 0.00999349,
           0.01624184,0.09832744],
         [-0.09029484,-0.11509365,-0.07193922,...,0.10658887,
           0.17730837,0.01104965],
         [0.06659461,-0.02492988,0.02271739,..., - 0.0646857,
           0.03879852,0.08779807]],dtype = float32),))

我在本地机器上检查了cudaNdarray包,但是没有安装,但我仍然能够解开它们。但在服务器中,我无法做到。如何让它们在没有GPU的服务器上运行?

4 个答案:

答案 0 :(得分:4)

pylearn2中有一个脚本可以满足你的需要:

pylearn2/scripts/gpu_pkl_to_cpu_pkl.py

答案 1 :(得分:2)

相关的Theano代码为here

从那里看,你可以设置一个选项config.experimental.unpickle_gpu_on_cpu,这会使CudaNdarray_unpickler返回底层的原始Numpy数组。

答案 2 :(得分:1)

这对我有用。注意:除非设置了以下环境变量,否则这不起作用:export THEANO_FLAGS='device=cpu'

import os
from pylearn2.utils import serial
import pylearn2.config.yaml_parse as yaml_parse

if __name__=="__main__":

_, in_path, out_path = sys.argv
os.environ['THEANO_FLAGS']="device=cpu"

model = serial.load(in_path)

model2 = yaml_parse.load(model.yaml_src)
model2.set_param_values(model.get_param_values())

serial.save(out_path, model2)

答案 3 :(得分:0)

我只是通过保存参数W&来解决这个问题。 b,但不是整个模型。您可以使用以下命令保存参数:http://deeplearning.net/software/theano/tutorial/loading_and_saving.html?highlight=saving%20load#robust-serialization 这可以将CudaNdarray保存为numpy数组。然后你需要通过numpy.load()读取params,最后使用theano.shared()将numpy数组转换为tensorSharedVariable。

相关问题