有没有办法检查函数在python中是否递归?

时间:2016-04-16 08:59:56

标签: python recursion python-internals

我想为练习编写测试功能,以确保正确实现功能 所以我想知道,有没有办法,给定一个函数“foo”,来检查它是否是递归实现的?
如果它封装了一个递归函数并使用它,它也会计数。例如:

const std::string base64_padding[] = {"", "==","="};

std::string base64EncodeText(std::string text) {
    using namespace boost::archive::iterators;
    typedef std::string::const_iterator iterator_type;
    typedef base64_from_binary<transform_width<iterator_type, 6, 8> > base64_enc;
    std::stringstream ss;
    std::copy(base64_enc(text.begin()), base64_enc(text.end()), ostream_iterator<char>(ss));
    ss << base64_padding[text.size() % 3];
    return ss.str();
}

std::string base64EncodeData(std::vector<uint8_t> data) {
    using namespace boost::archive::iterators;
    typedef std::vector<uint8_t>::const_iterator iterator_type;
    typedef base64_from_binary<transform_width<iterator_type, 6, 8> > base64_enc;
    std::stringstream ss;
    std::copy(base64_enc(data.begin()), base64_enc(data.end()), ostream_iterator<char>(ss));
    ss << base64_padding[data.size() % 3];
    return ss.str();
}

这也应该被认为是递归的。
请注意,我想使用外部测试功能来执行此检查。不改变函数的原始代码。

3 个答案:

答案 0 :(得分:6)

解决方案:

from bdb import Bdb
import sys

class RecursionDetected(Exception):
    pass

class RecursionDetector(Bdb):
    def do_clear(self, arg):
        pass

    def __init__(self, *args):
        Bdb.__init__(self, *args)
        self.stack = set()

    def user_call(self, frame, argument_list):
        code = frame.f_code
        if code in self.stack:
            raise RecursionDetected
        self.stack.add(code)

    def user_return(self, frame, return_value):
        self.stack.remove(frame.f_code)

def test_recursion(func):
    detector = RecursionDetector()
    detector.set_trace()
    try:
        func()
    except RecursionDetected:
        return True
    else:
        return False
    finally:
        sys.settrace(None)

示例用法/测试:

def factorial_recursive(x):
    def inner(n):
        if n == 0:
            return 1
        return n * factorial_recursive(n - 1)
    return inner(x)


def factorial_iterative(n):
    product = 1
    for i in xrange(1, n+1):
        product *= i
    return product

assert test_recursion(lambda: factorial_recursive(5))
assert not test_recursion(lambda: factorial_iterative(5))
assert not test_recursion(lambda: map(factorial_iterative, range(5)))
assert factorial_iterative(5) == factorial_recursive(5) == 120

基本上test_recursion接受一个没有参数的callable,调用它,并返回True如果在执行该callable期间的任何时候,相同的代码在堆栈中出现两次,False除此以外。我认为它可能会发现这不是OP想要的。例如,如果相同的代码在特定时刻出现在堆栈中10次,则可以轻松修改它。

答案 1 :(得分:0)

我还没有为亚历克斯(Alex)的答案进行验证(尽管我认为是可行的,并且比我将要提出的要好得多),但是如果您想要比它简单(或更小)的东西,可以简单地使用sys.getrecursionlimit()手动将其出错,然后在函数中进行检查。例如,这是我为自己的递归验证编写的:

import sys

def is_recursive(function, *args):
  try:
    # Calls the function with arguments
    function(sys.getrecursionlimit()+1, *args)
  # Catches RecursionError instances (means function is recursive)
  except RecursionError:
    return True
  # Catches everything else (may not mean function isn't recursive,
  # but it means we probably have a bug somewhere else in the code)
  except:
    return False
  # Return False if it didn't error out (means function isn't recursive)
  return False

尽管它可能不太优雅(在某些情况下更是错误),但它比Alex的代码小了 ,并且在大多数情况下都可以正常工作。这里的主要缺点是,使用这种方法时,您要使计算机执行每次递归操作,直到功能达到递归限制为止。我建议使用sys.setrecursionlimit()临时更改递归限制,同时使用此代码来最大程度地减少处理递归所需的时间,例如:

sys.setrecursionlimit(10)
if is_recursive(my_func, ...):
  # do stuff
else:
  # do other stuff
sys.setrecursionlimit(1000) # 1000 is the default recursion limit

答案 2 :(得分:0)

from inspect import stack

already_called_recursively = False


def test():
    global already_called_recursively
    function_name = stack()[1].function
    if not already_called_recursively:
        already_called_recursively = True
        print(test())  # One recursive call, leads to Recursion Detected!

    if function_name == test.__name__:
        return "Recursion detected!"
    else:
        return "Called from {}".format(function_name)


print(test())  # Not Recursion, "father" name: "<module>"


def xyz():
    print(test())  # Not Recursion, "father" name: "xyz"


xyz()

输出是

Recursion detected!
Called from <module>
Called from xyz

我使用全局变量 already_called_recursively 来确保我只调用它一次,正如你所看到的,在递归时它显示“检测到递归”,因为“父亲”名称与当前名称相同函数,这意味着我从同一个函数又名递归调用它。

其他打印是模块级调用,以及 xyz 内部的调用。

希望有帮助:D