在theano扫描中连接

时间:2017-06-22 11:03:36

标签: theano

我无法弄清楚以下代码无效的原因:

import theano as th
import theano.tensor as tt

a = tt.vector("a")

res, up = th.scan(
    fn = lambda a : tt.concatenate([a,a]), 
    outputs_info = a,
    n_steps = 2
)

f = th.function(inputs = [a], outputs=res)

f(np.array([1.]))

我希望它能像

一样返回
f = th.function(
    inputs = [a], 
    outputs = tt.concatenate([tt.concatenate([a,a]),tt.concatenate([a,a])])
)

1 个答案:

答案 0 :(得分:0)

theano.scan docstring:

  

......初始状态应具有与输出相同的形状 ...

如果扫描涉及向量上的循环表达式,则该向量的形状不得更改。

lambda a: T.sqr(a)  # OK

lambda a: T.concat([a,a])  # ERROR

原因是,扫描内部使用矩阵在所有时间步骤上存储矢量。如果形状发生变化,矩阵就会变得粗糙。虽然理论上并非不可能实现,但它引入了更多的复杂性和潜在的问题。

所以,是的,scan有点受限。

相关问题