比较(断言相等)两个在unittest中包含numpy数组的复杂数据结构

时间:2013-01-09 21:38:40

标签: python unit-testing numpy

我使用Python的unittest模块,并想检查两个复杂的数据结构是否相等。对象可以是具有各种值的dicts列表:数字,字符串,Python容器(列表/元组/ dicts)和numpy数组。后者是提出问题的原因,因为我不能只做

self.assertEqual(big_struct1, big_struct2)

因为它产生了

ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

我想我需要为此编写自己的相等测试。它应该适用于任意结构。我目前的想法是一个递归函数:

  • 尝试直接将arg1的当前“节点”与arg2的相应节点进行比较;
  • 如果没有引发异常,则继续(此处也处理“终端”节点/叶子);
  • 如果发现ValueError,则会更深入,直至找到numpy.array;
  • 比较数组(例如like this)。

看起来有点问题的是跟踪两个结构的“对应”节点,但在这里我可能需要zip

问题是:这种方法是否有更好(更简单)的替代方案?也许numpy为此提供了一些工具?如果没有建议的替代方案,我将实施这个想法(除非我有一个更好的想法)并发布作为答案。

P.S。我有一种模糊的感觉,我可能已经看到了解决这个问题的问题,但我现在找不到它。

P.P.S。另一种方法是遍历结构并将所有numpy.array转换为列表的函数,但是这更容易实现吗?对我来说似乎也一样。


编辑:子类化numpy.ndarray听起来非常有前景,但显然我没有将比较的两个方面硬编码到测试中。但其中一个确实是硬编码的,所以我可以:

  • 使用numpy.array;
  • 的自定义子类填充它
  • jterrace's answer;
  • 中将isinstance(other, SaneEqualityArray)更改为isinstance(other, np.ndarray)
  • 在比较中始终将其用作LHS。

我在这方面的问题是:

  1. 它会起作用吗(我的意思是,这对我来说听起来不错,但也许一些棘手的边缘案例无法正确处理)?我的自定义对象在递归等式检查中是否总是以LHS结束?正如我所期望的那样?
  2. 再次,是否有更好的方法(假设我至少得到一个具有真实numpy数组的结构)。

  3. 编辑2 :我试了一下,(貌似)工作实现显示在this answer

7 个答案:

答案 0 :(得分:12)

会有评论,但它太长了......

有趣的是,您无法使用==来测试数组是否相同我建议您使用np.testing.assert_array_equal

  1. 检查dtype,shape等,
  2. 对于(float('nan') == float('nan')) == False(正常的python序列==的整齐的小数学运算没有失败,有一种更有趣的方法来忽略这个有时,因为它使用{ {1}}执行(对于NaNs不正确)PyObject_RichCompareBool快速检查(当然测试是完美的)...
  3. 还有is因为如果你进行实际计算并且你通常希望几乎相同的值,浮点相等可能变得非常棘手,因为这些值可以变为硬件依赖或可能是随机的取决于你对它们做了什么。
  4. 我几乎建议尝试使用pickle进行序列化,如果你想要这种疯狂嵌套的东西,但这是非常严格的(当然,第3点完全被破坏),例如你的数组的内存布局并不重要,但是对其序列化很重要。

答案 1 :(得分:8)

assertEqual函数将调用对象的__eq__方法,这些方法应该针对复杂数据类型进行递归。例外是numpy,它没有合理的__eq__方法。使用numpy subclass from this question,您可以恢复相等行为的健全性:

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

此测试通过。

答案 2 :(得分:7)

所以jterrace所说明的想法似乎对我有所改变:

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        return (isinstance(other, np.ndarray) and self.shape == other.shape and 
            np.allclose(self, other))

就像我说的那样,带有这些对象的容器应该在等式检查的左侧。我从现有的SaneEqualityArray创建了numpy.ndarray个对象:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

根据ndarray构造函数签名:

ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

此类在测试套件中定义,仅用于测试目的。等式检查的RHS是被测试函数返回的实际对象,包含真实的numpy.ndarray对象。

P.S。感谢到目前为止发布的两个答案的作者,他们都非常有帮助。如果有人发现这种方法存在任何问题,我将非常感谢您的反馈。

答案 3 :(得分:2)

我将定义自己的assertNumpyArraysEqual()方法,该方法显式地进行您要使用的比较。这样,您的生产代码保持不变,但您仍然可以在单元测试中做出合理的断言。确保在包含__unittest = True的模块中定义它,以便它不会包含在堆栈跟踪中:

import numpy
__unittest = True

def assertNumpyArraysEqual(self, other):
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other)
        raise AssertionError("Elements don't match!")

答案 4 :(得分:1)

检查@Autowired "raises an AssertionError if two items are not equal up to desired precision",例如:

numpy.testing.assert_almost_equal

答案 5 :(得分:1)

我遇到了同样的问题,并开发了一个函数来比较基于为对象创建固定哈希的相等性。这样做的另一个好处是,您可以通过将对象的哈希与代码中已修复的哈希值进行比较来测试对象是否符合预期。

代码(一个独立的python文件,is here)。有两个函数:fixed_hash_eq,它解决了你的问题,compute_fixed_hash,它从结构中产生一个哈希。 Tests are here

这是一个测试:

obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)

答案 6 :(得分:0)

在@dbw(感谢)的基础上,在test-case子类中插入的以下方法对我来说效果很好:

 def assertNumpyArraysEqual(self,this,that,msg=''):
    '''
    modified from http://stackoverflow.com/a/15399475/5459638
    '''
    if this.shape != that.shape:
        raise AssertionError("Shapes don't match")
    if not np.allclose(this,that):
        raise AssertionError("Elements don't match!")

我在测试用例方法中将其称为self.assertNumpyArraysEqual(this,that),并且像魅力一样工作。

相关问题