Theano:调用Theano函数的论据

时间:2016-05-02 18:41:36

标签: python numpy theano

python的本质使我很难找到一种如何调用theano函数的正式定义。

当给出长度为4的矩阵batch列表时,我调用

validationFunction(batch[0],batch[1],batch[2],batch[3])

这可行。

当我打电话

validationFunction(batch)

validationFunction((list(batch))
抱怨道:

        validationError += validationFunction(batch) # [0],batch[1], batch[2], batch[3])
  File "/usr/lib64/python2.7/site-packages/theano/compile/function_module.py", line 786, in __call__
    allow_downcast=s.allow_downcast)
  File "/usr/lib64/python2.7/site-packages/theano/tensor/type.py", line 149, in filter
    converted_data = theano._asarray(data, self.dtype)
  File "/usr/lib64/python2.7/site-packages/theano/misc/safe_asarray.py", line 33, in _asarray
    rval = numpy.asarray(a, dtype=dtype, order=order)
  File "/usr/lib64/python2.7/site-packages/numpy/core/numeric.py", line 474, in asarray
    return array(a, dtype, copy=False, order=order)
ValueError: ('Bad input argument to theano function with name "dummy.py:96"  at index 0(0-based)', 'could not broadcast input array from shape (7,3) into shape (7)')

我有一个符号输入变量列表和相应的minibatches列表。批处理的形式如下:         print(" batch singular = \ n0:{} \ n1:{} \ n2:{} \ n3:{}" .format(batch [0],batch [1],batch [2]分批[3]))

0:[[3.0 3.0 2.0]
 [3.0 2.0 5.0]
 [2.0 5.0 3.0]
 [5.0 3.0 4.0]
 [3.0 4.0 3.0]
 [4.0 3.0 2.0]
 [3.0 2.0 6.0]]
1:[[5.0 3.0 4.0]
 [3.0 4.0 3.0]
 [4.0 3.0 2.0]
 [3.0 2.0 6.0]
 [2.0 6.0 6.0]
 [6.0 6.0 6.0]
 [6.0 6.0 2.0]]
2:[[3.0 2.0 14.0]
 [2.0 2.0 14.0]
 [6.0 2.0 14.0]
 [6.0 2.0 14.0]
 [6.0 2.0 14.0]
 [2.0 2.0 14.0]
 [4.0 2.0 14.0]]
3:[[2.0]
 [6.0]
 [6.0]
 [6.0]
 [2.0]
 [4.0]
 [4.0]]

基本上,如何在不硬编码1 ... n的情况下调用validationFunction(a [0],a [1],...,a [n-1])?什么参数的定义?

定义了该功能

validationFunction= theano.function(inputVars + [targetVar], testLoss)

其中inputVars是theano矩阵的列表,targetVar是theano矩阵。我应该以不同的方式定义功能吗? inputVars + [targetVar]创建了我的三个输入和一个目标的列表。

我真的花了很多时间与theano及其风格,但有些事情记录得太紧凑。

  

输入可以作为变量或In实例给出。在实例中也   有一个变量,但他们附加了一些关于如何的额外信息   应该使用与该变量对应的调用时参数。   同样,Out实例可以附加有关输出的信息   变量应该返回。

1 个答案:

答案 0 :(得分:0)

我在stackoverflow中找到了解决方案,我只需要调用它:

validationFunction(*batch)

而不是

validationFunction(batch)

哦亲爱的,我学到的python越多,我就越开始喜欢所有那些详细的样板声明和接口定义来自java的东西。