在Python中使用乘法运算符的Matrix-matrix-multiplication:如何?

时间:2019-06-10 13:16:14

标签: python operator-overloading matrix-multiplication

我想在Python中重载乘法运算符。我真正想做的是在Python中使用*运算符执行矩阵-矩阵乘法。不允许使用Numpy。

import math
class Vec4():
    def __init__(self, x = 0, y = 0, z = 0, w = 0):
        """Constructor for Vec4
        DO NOT MODIFY THIS METHOD"""
        self.values = [x,y,z,w]

    def __str__(self):
        """Returns the vector as a string representation
        DO NOT MODIFY THIS METHOD"""
        toReturn = ''
        if self is None: return '0.00 0.00 0.00 0.00'
        for c in range(0,4):
                toReturn += "%.2f" % self.values[c]
                if c != 3:
                    toReturn += ' '
        return toReturn

class Matrix4():
    def __init__(self, row1=None, row2=None, row3=None, row4=None):
        """Constructor for Matrix4
        DO NOT MODIFY THIS METHOD"""
        if row1 is None: row1 = Vec4()
        if row2 is None: row2 = Vec4()
        if row3 is None: row3 = Vec4()
        if row4 is None: row4 = Vec4()
        self.m_values = [row1,row2,row3,row4]

    def __str__(self):
        """Returns a string representation of the matrix
        DO NOT MODIFY THIS METHOD"""
        toReturn = ''
        if self is None: return '0.00 0.00 0.00 0.00\n0.00 0.00 0.00 0.00\n0.00 0.00 0.00 0.00\n0.00 0.00 0.00 0.00'
        for r in range(0,4):
            for c in range(0,4):
                toReturn += "%.2f" % self.m_values[r].values[c]
                if c != 3:
                    toReturn += ' '
            toReturn += '\n'
        return toReturn

   def __matmul__(self, m):
        x = self.m_values[0].values[0]*m.m_values[0].values[0]+self.m_values[0].values[1]*m.m_values[1]*values[0]+self.m_values[0].values[2]*m.m_values[2].values[0]+self.m_values[0].values[3]*m.m_values[3].values[0]
        y = self.m_values[1].values[0]*m.m_values[0].values[1]+self.m_values[1].values[1]*m.m_values[1]*values[1]+self.m_values[1].values[2]*m.m_values[2].values[1]+self.m_values[1].values[3]*m.m_values[3].values[1]
        z = self.m_values[2].values[0]*m.m_values[0].values[2]+self.m_values[2].values[1]*m.m_values[1]*values[2]+self.m_values[2].values[2]*m.m_values[2].values[2]+self.m_values[2].values[3]*m.m_values[3].values[2]
        w = self.m_values[3].values[0]*m.m_values[0].values[3]+self.m_values[3].values[1]*m.m_values[1]*values[3]+self.m_values[3].values[2]*m.m_values[2].values[3]+self.m_values[3].values[3]*m.m_values[3].values[3]
        return Matrix4()

没有得到像下面这样的结果:

A = Matrix4(Vec4(1, 0, 0, 0),
            Vec4(0, 1, 0, 0),
            Vec4(0, 0, 1, 0),
            Vec4(0, 0, 0, 1))

B = Matrix4(Vec4(1,2,3,4),
            Vec4(1,2,3,4),
            Vec4(1,2,3,4),
            Vec4(1,2,3,4))

print(A * B)

输出应为:

1.00 2.00 3.00 4.00
1.00 2.00 3.00 4.00
1.00 2.00 3.00 4.00
1.00 2.00 3.00 4.00

但是在我的情况下,它会导致错误:

Traceback (most recent call last):
  File "<pyshell#14>", line 1, in <module>
    print(A*B)
  File "C:\Users\xxx\Downloads\Download-Stuff\Gmail\TransformMatrix.py", line 45, in __mul__
    x = self.m_values[0].values[0]*v.values[0]+self.m_values[1].values[0]*v.values[1]+self.m_values[2].values[0]*v.values[2]+self.m_values[3].values[0]*v.values[3]
AttributeError: 'Matrix4' object has no attribute 'values'

我在做什么错了?

感谢您的提前帮助。

2 个答案:

答案 0 :(得分:1)

您必须通过定义def __mul__(self, m):

重载

答案 1 :(得分:1)

咨询https://en.wikipedia.org/wiki/Matrix_multiplication之后, 我首先实现了dot_product(),并试图获得更通用的解决方案:

import math
class Vec4():
    def __init__(self, x = 0, y = 0, z = 0, w = 0):
        """Constructor for Vec4
        DO NOT MODIFY THIS METHOD"""
        self.values = [x,y,z,w]

    def __str__(self):
        """Returns the vector as a string representation
        DO NOT MODIFY THIS METHOD"""
        toReturn = ''
        if self is None: return '0.00 0.00 0.00 0.00'
        for c in range(0,4):
                toReturn += "%.2f" % self.values[c]
                if c != 3:
                    toReturn += ' '
        return toReturn

class Matrix4():
    def __init__(self, row1=None, row2=None, row3=None, row4=None):
        """Constructor for Matrix4
        DO NOT MODIFY THIS METHOD"""
        if row1 is None: row1 = Vec4()
        if row2 is None: row2 = Vec4()
        if row3 is None: row3 = Vec4()
        if row4 is None: row4 = Vec4()
        self.m_values = [row1,row2,row3,row4]

    def __str__(self):
        """Returns a string representation of the matrix
        DO NOT MODIFY THIS METHOD"""
        toReturn = ''
        if self is None: return '0.00 0.00 0.00 0.00\n0.00 0.00 0.00 0.00\n0.00 0.00 0.00 0.00\n0.00 0.00 0.00 0.00'
        for r in range(0,4):
            for c in range(0,4):
                toReturn += "%.2f" % self.m_values[r].values[c]
                if c != 3:
                    toReturn += ' '
            toReturn += '\n'
        return toReturn

    def get_column(self, j):
        return [vec.values[j] for vec in self.m_values]

    def get_row(self, i):
        return self.m_values[i].values

    def dot_product(self, m, i, j):
        return sum([x * y for x, y in zip(self.get_row(i), \
                                      m.get_column(j))])

    def shape(self):
        return len(self.m_values), len(self.m_values[0].values)

    def __mul__(self, mat):
        # m = len(self.m_values[0].values)
        n = self.shape()[0]
        p = mat.shape()[1]
        return Matrix4(*[Vec4(*[self.dot_product(mat, i, j)  for j in range(p)])  for i in range(n)])


A = Matrix4(Vec4(1, 0, 0, 0),
            Vec4(0, 1, 0, 0),
            Vec4(0, 0, 1, 0),
            Vec4(0, 0, 0, 1))

B = Matrix4(Vec4(1,2,3,4),
            Vec4(1,2,3,4),
            Vec4(1,2,3,4),
            Vec4(1,2,3,4))

print(A * B)

# 1.00 2.00 3.00 4.00
# 1.00 2.00 3.00 4.00
# 1.00 2.00 3.00 4.00
# 1.00 2.00 3.00 4.00

常规解决方案

用于任意大小的矩阵。 并且使用__repr__(),因此不必总是打印print()来查看 字符串表示形式。

class Vec4():
    def __init__(self, *args):
        """Generalized constructor for Vec4"""
        self.values = args

    def __str__(self):
        """Returns the vector as a string representation"""
        if self.values == []: 
            return "Empy Vector of class Vec4"
        else:
            return ' '.join(["{0:.2f}".format(c) for c in self.values])

    def __repr__(self):
        return self.__str__()

class Matrix4():
    def __init__(self, *args):
        """Constructor for Matrix4"""
        self.values = args

    def __str__(self):
        """Returns a string representation of the matrix"""
        if self.values == []:
            return "Empty Matrix of class Matrix4"
        else:
            return '\n'.join([str(v) for v in self.values])

    def __repr__(self):
        return self.__str__()

    def get_column(self, j):
        return [vec.values[j] for vec in self.values]

    def get_row(self, i):
        return self.values[i].values

    def dot_product(self, m, i, j):
        return sum([x * y for x, y in zip(self.get_row(i), \
                                      m.get_column(j))])

    def shape(self):
        return len(self.values), len(self.values[0].values)

    def __mul__(self, mat):
        # m = len(self.values[0].values)
        n = self.shape()[0]
        p = mat.shape()[1]
        return Matrix4(*[Vec4(*[self.dot_product(mat, i, j)  for j in range(p)])  for i in range(n)])
相关问题