为什么这个Theano.scan函数不会抛出错误?

时间:2015-10-21 15:08:47

标签: python theano

我有来自Theano文档的这段代码来计算多项式

import numpy
import theano
import theano.tensor as T


theano.config.warn.subtensor_merge_bug = False

coefficients = theano.tensor.vector("coefficients")

x = T.scalar("x")

max_coefficients_supported = 1000
# Generate the components of the polynomial
full_range=theano.tensor.arange(max_coefficients_supported)

components, updates = theano.scan(fn=lambda coeff, power, free_var:
                                            coeff * (free_var ** power),
                                    outputs_info=None,
                                    sequences=[coefficients, full_range],
                                    non_sequences=x)

polynomial = components.sum()

calculate_polynomial = theano.function(inputs=[coefficients, x],
outputs=polynomial)

test_coeff = numpy.asarray([1, 0, 2], dtype=numpy.float32)
print calculate_polynomial(test_coeff, 3)

我认为它正在做的是计算1 *(3 ^ 0)+ 0 *(3 ^ 1)+ 2 *(3 ^ 2),这是我的输出19.0。

我的问题是full_range是一个包含1000个元素的[0,1,2,...,998,999]数组。这被用作Theano扫描参数中的序列。

为什么不扫描扫描功能中的序列参数。它有两个序列作为输入,即系数和full_range。这意味着在扫描函数的每次迭代中,它从每个输入中选择一个值并运行代码

coeff * (free_var ** power)

其中free_var是3.当coeff中的元素数量仅为3而幂的元素数量(即full_range输入)具有1000个元素时,代码如何成功运行?我在这里遗漏了一些东西。

由于

1 个答案:

答案 0 :(得分:1)

当多个序列传递给scan时,它将迭代与最短序列中的元素一样多次。

The documentation说:

  

给定多个不均匀长度的序列,扫描将截断到   最短的。这样就可以安全地通过一个非常长的范围   我们需要做到一般性,因为一个范围必须有它的长度   在创建时指定。

相关问题