如何比较两个ctypes对象的相等性?

时间:2014-06-19 12:48:06

标签: python ctypes

import ctypes as ct

class Point(ct.Structure):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
    ]

p1 = Point(10, 10)
p2 = Point(10, 10)

print p1 == p2 # => False

等值运算符' == '在上面的小案例中给出 False 。有没有直截了当的方法?

编辑:

这是一个稍微改进的版本(基于接受的答案),它也可以处理嵌套数组:

import ctypes as ct

class CtStruct(ct.Structure):

    def __eq__(self, other):
        for field in self._fields_:
            attr_name = field[0]
            a, b = getattr(self, attr_name), getattr(other, attr_name)
            is_array = isinstance(a, ct.Array)
            if is_array and a[:] != b[:] or not is_array and a != b:
                return False
        return True

    def __ne__(self, other):
        for field in self._fields_:
            attr_name = field[0]
            a, b = getattr(self, attr_name), getattr(other, attr_name)
            is_array = isinstance(a, ct.Array)
            if is_array and a[:] != b[:] or not is_array and a != b:
                return True
        return False

class Point(CtStruct):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
        ('arr', ct.c_int * 2),
    ]

p1 = Point(10, 20, (30, 40))
p2 = Point(10, 20, (30, 40))

print p1 == p2 # True

2 个答案:

答案 0 :(得分:4)

创建一个MyCtStructure类,然后它的所有子类都不需要实现__eq__& __ne__。 在你的案例中定义eq将不再是一个单调乏味的工作。

import ctypes as ct
class MyCtStructure(ct.Structure):

    def __eq__(self, other):
        for fld in self._fields_:
            if getattr(self, fld[0]) != getattr(other, fld[0]):
                return False
        return True

    def __ne__(self, other):
        for fld in self._fields_:
            if getattr(self, fld[0]) != getattr(other, fld[0]):
                return True
        return False

class Point(MyCtStructure):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
    ]


p1 = Point(10, 11)
p2 = Point(10, 11)

print p1 == p2

答案 1 :(得分:2)

p1.x == p2.x and p1.y = p2.y将适用于这个微不足道的案例。

您还可以在__eq__()课程中实施__ne__()Point方法:

class Point(ct.Structure):
    _fields_ = [
        ('x', ct.c_int),
        ('y', ct.c_int),
    ]
    def __eq__(self, other):
        return (self.x == other.x) and (self.y == other.y)
    def __ne__(self, other):
        return not self.__eq__(other)

>>> p1 = Point(10, 10)
>>> p2 = Point(10, 10)
>>> p3 = Point(10, 66)
>>> p1 == p2
True
>>> p1 != p2
False
>>> p1 == p3
False
>>> p1 != p3
True