讨厌的生成器错误

时间:2015-08-18 17:19:52

标签: python generator

此错误的原始上下文是一段太大而无法在此类问题中发布的代码。我不得不将这段代码缩小到仍然显示错误的最小片段。这就是为什么下面显示的代码看起来有点奇怪。

在下面的代码中,类Foo可能被视为一种令人费解的方式来获取类似xrange的内容。

class Foo(object):
    def __init__(self, n):
        self.generator = (x for x in range(n))

    def __iter__(self):
        for e in self.generator:
            yield e

事实上,Foo似乎与xrange非常相似:

for c in Foo(3):
    print c
# 0
# 1
# 2

print list(Foo(3))
# [0, 1, 2]

现在,Bar的子类Foo只添加__len__方法:

class Bar(Foo):
    def __len__(self):
        return sum(1 for _ in self.generator)
Bar循环中使用时,

Foo的行为与for相似:

for c in Bar(3):
    print c
# 0
# 1
# 2

BUT:

print list(Bar(3))
# []

我的猜测是,在list(Bar(3))的评估中,__len__的{​​{1}}方法被调用,从而耗尽了生成器。

(如果此猜测正确,则无需调用Bar(3);毕竟Bar(3).__len__会生成正确的结果,即使list(Foo(3))没有Foo方法。)

这种情况很烦人:__len__list(Foo(3))没有充分理由产生不同的结果。

是否可以修复list(Bar(3))(当然,没有摆脱其Bar方法)__len__返回list(Bar(3))

2 个答案:

答案 0 :(得分:6)

你的问题是Foo的行为与xrange的行为不同:xrange会在你每次询问iter方法时给你一个新的迭代器,而Foo总是给你一个相同的意思,这意味着一旦它耗尽了对象也是:

>>> a = Foo(3)
>>> list(a)
[0, 1, 2]
>>> list(a)
[]
>>> a = range(3)
>>> list(a)
[0, 1, 2]
>>> list(a)
[0, 1, 2]

我可以通过向您的方法添加间谍来轻松确认__len__调用list方法:

class Bar(Foo):
    def __len__(self):
        print "LEN"
        return sum(1 for _ in self.generator)

(我在print "ITERATOR"中添加了Foo.__iter__。它产生:

>>> list(Bar(3))
LEN
ITERATOR
[]

我只能想象两个解决方法:

  1. 我的首选:在__iter__Foo每次调用时返回一个新的迭代器,以模仿xrange

    class Foo(object):
        def __init__(self, n):
            self.n = n
    
        def __iter__(self):
            print "ITERATOR"
            return ( x for x in range(self.n))
    
    class Bar(Foo):
        def __len__(self):
            print "LEN"
            return sum(1 for _ in self.generator)
    

    我们得到了正确的答案:

    >>> list(Bar(3))
    ITERATOR
    LEN
    ITERATOR
    [0, 1, 2]
    
  2. 替代方法:将len更改为不调用迭代器并让Foo不变:

    class Bar(Foo):
        def __init__(self, n):
            self.len  = n
            super(Bar, self).__init__(n)
        def __len__(self):
            print "LEN"
            return self.len
    

    我们再次得到:

    >>> list(Bar(3))
    LEN
    ITERATOR
    [0, 1, 2]
    

    但是,一旦第一个迭代器到达终点,Foo和Bar对象就会耗尽。

  3. 但我必须承认,我不知道你真正班级的背景......

答案 1 :(得分:2)

这种行为可能很烦人,但实际上它是可以理解的。在内部,list只是一个数组,数组是固定大小的数据结构。这样做的结果是,如果您的list大小为n并且您想要添加额外的项目以达到n+1,则必须创建一个全新的数组并完全复制旧的一个到新的。实际上,您的list.append(x)现在是O(n)操作,而不是常规O(1)

为防止这种情况发生,list()会尝试获取输入的大小,以便猜出数组需要的大小。

因此,针对此问题的一个解决方案是使用iter强制猜测:

list(iter(Bar(3)))