Numpy矩阵乘法与自定义点积

时间:2013-10-09 17:02:48

标签: numpy vectorization matrix-multiplication dot-product

默认矩阵乘法计算为

c[i,j] = sum(a[i,k] * b[k,j])

我正在尝试使用自定义公式代替点积来获取

c[i,j] = sum(a[i,k] == b[k,j])

在numpy中有一种有效的方法吗?

1 个答案:

答案 0 :(得分:4)

您可以使用广播:

c = sum(a[...,np.newaxis]*b[np.newaxis,...],axis=1)  # == np.dot(a,b)

c = sum(a[...,np.newaxis]==b[np.newaxis,...],axis=1)

我在newaxis中添加了b,只是明确了该数组的扩展方式。还有其他方法可以为数组添加维度(重塑,重复等),但效果是一样的。将ab展开为相同的形状,以逐元素(或==)进行,然后在正确的轴上求和。

相关问题