Python - 如何提高复杂递归函数的效率?

时间:2016-08-11 12:51:44

标签: python performance math recursion optimization

在Mathologer的this video中,除了其他事项之外,在9:25显示3种不同的无限和,当视频突然冻结并且大象叮当响起时,挑战观众发现“表达式的可能值。我写了下面的脚本,以近似的方式逼近三个中的最后一个(即1 + 3 ... / 2 ......):

from decimal import Decimal as D, getcontext  # for accurate results

def main(c):  # faster code when functions defined locally (I think)
    def run1(c):
        c += 1
        if c <= DEPTH:
            return D(1) + run3(c)/run2(c)
        else:
            return D(1)

    def run2(c):
        c += 1
        if c <= DEPTH:
            return D(2) + run2(c)/run1(c)
        else:
            return D(2)

    def run3(c):
        c += 1
        if c <= DEPTH:
            return D(3) + run1(c)/run3(c)
        else:
            return D(3)
    return run1(c)

getcontext().prec = 10  # too much precision isn't currently necessary

for x in range(1, 31):
    DEPTH = x
    print(x, main(0))

现在这对于1&lt; = x&lt; = 20ish完全正常,但在此之后每个结果都会开始永久性。我确实意识到这是由于每个DEPTH级别的函数调用数量呈指数级增长。很明显,我无法在任意点上舒适地计算出该系列。但是,程序减慢的时间对我来说太早了,无法清楚地确定它正在收敛的系列的限制(可能是1.75,但我需要更多DEPTH来确定。)

我的问题是:如何尽可能多地从我的脚本中获取(性能方面)?

我试过了:
 1.找到这个问题的数学解决方案。 (没有匹配的结果)
 2.寻找一般优化递归函数的方法。根据多个来源(例如this),Python默认情况下不会优化尾递归,所以我尝试切换到迭代样式,但我几乎没有立即实现这个想法...

感谢任何帮助!

注意:我知道我可以通过数学方式进行此操作,而不是“强制”限制,但我想让我的程序运行良好,现在我已经开始......

2 个答案:

答案 0 :(得分:1)

您可以将run1run2run3函数的结果存储在数组中,以防止每次都重新计算它们,因为在您的示例中,main(1)调用run1(1)run3(2)run2(2),后者又调用run1(3)run2(3)run1(3)(再次)和run3(3)等等。

你可以看到run1(3)被称为评估两次,而这只会随着数量的增加而变得更糟;如果我们计算每个函数被调用的次数,那就是结果:

   run1 run2 run3
1  1    0    0
2  0    1    1
3  1    2    1
4  3    2    3
5  5    6    5
6  11   10   11
7  21   22   21
8  43   42   43
9  85   86   85
   ...
20 160,000 each (approx.)
   ...
30 160 million each (approx.)

这实际上是Pascal三角形的变体,你可以 以数学方式计算出结果;但是既然你在这里要求进行非数学优化,只需注意调用次数如何呈指数增长;它在每次迭代时翻倍。这更糟糕,因为每次调用都会产生数千个后续调用,值更高,这是你想要避免的。

因此,您要做的是存储每个调用的值,以便不需要调用该函数一千次(并且本身会进行数千次调用)以始终获得相同的结果。这称为memoization

以下是伪代码中的示例解决方案:

before calling main, declare the arrays val1, val2, val3, all of size DEPTH, and fill them with -1

function run1(c) # same thing for run2 and run3
    c += 1
    if c <= DEPTH
        local3 = val3(c)     # read run3(c)
        if local3 is -1      # if run3(c) hasn't been computed yet
            local3 = run3(c) # we compute it
            val3(c) = local3 # and store it into the array
        local2 = val2(c)     # same with run2(c)
        if local2 is -1
            local2 = run2(c)
            val2(c) = local2

        return D(1) + local3/local2 # we use the value we got from the array or from the computation
    else
        return D(1)

这里我使用-1,因为你的函数似乎只生成正数,而-1是空单元格的简单占位符。在其他情况下,你可能不得不使用一个对象作为我下面的Cabu。然而,我认为由于检索对象中的属性与读取数组的成本相比,这会更慢,但我可能错了。无论哪种方式,现在你的代码应该更快,更快,成本为O(n)而不是O(2 ^ n)。

这在技术上允许您的代码以恒定速度永远运行,但递归实际上会导致早期堆栈溢出。在此之前,您可能仍然可以达到数千的深度。

修改:在评论中添加了ShadowRanger,您可以保留原始代码,只需在每个@lru_cache(maxsize=n)run1和{之前添加run2 {1}}函数,其中n是DEPTH之上的两个中的第一个幂(例如,如果深度为25,则为32)。这可能需要使用import指令。

答案 1 :(得分:0)

通过一些记忆,你可以达到堆栈溢出:

from decimal import Decimal as D, getcontext  # for accurate results

def main(c):  # faster code when functions defined locally (I think)
    mrun1 = {}  # store partial results of run1, run2 and run3
                # This have not been done in the as parameter of the
                # run function to be able to reset them easily

    def run1(c):
        if c in mrun1:  # if partial result already computed, return it
            return mrun1[c]

        c += 1
        if c <= DEPTH:
            v = D(1) + run3(c) / run2(c)
        else:
            v = D(1)

        mrun1[c] = v  # else store it and return the value
        return v

    def run2(c):
        if c in mrun2:
            return mrun2[c]

        c += 1
        if c <= DEPTH:
            v = D(2) + run2(c) / run1(c)
        else:
            v = D(2)

        mrun2[c] = v
        return v

    def run3(c):
        if c in mrun3:
            return mrun3[c]

        c += 1
        if c <= DEPTH:
            v = D(3) + run1(c) / run3(c)
        else:
            v = D(3)

        mrun3[c] = v
        return v

    return run1(c)

getcontext().prec = 150  # too much precision isn't currently necessary

for x in range(1, 997):
    DEPTH = x
    print(x, main(0))

如果你超过997,Python会堆叠溢出。