遍历numpy数组以有效测试多个元素

时间:2019-01-30 00:46:25

标签: python numpy

我有以下代码遍历名为“ m”的2d numpy数组。它工作非常慢。如何使用numpy函数转换此代码,从而避免使用for循环?

pairs = []
for i in range(size):
    for j in range(size):
        if(i >= j):
            continue
        if(m[i][j] + m[j][i] >= 0.75):
            pairs.append([i, j, m[i][j] + m[j][i]])

2 个答案:

答案 0 :(得分:4)

一种优化代码的方法是避免比较if (i >= j)。要仅遍历数组的下三角而不进行比较,必须使内部循环以最外部循环的i的值开始。这样,您就可以避免进行size x size if比较。

import numpy as np
size = 5000
m = np.random.rand(size, size)
pairs = []


for i in range(size):
    for j in range(i , size):

        if(m[i][j] + m[j][i] >= 0.75):
            pairs.append([i, j, m[i][j] + m[j][i]])

答案 1 :(得分:4)

您可以使用NumPy使用向量化方法。这个想法是:

  • 首先初始化一个矩阵m,然后创建与m+m.T等效的m[i][j] + m[j][i],其中m.T是矩阵转置并命名为summ
  • np.triu (summ)返回矩阵的上三角部分(这等效于在代码中使用continue来忽略下三角部分)。这样可以避免在代码中使用显式if(i >= j):。在这里,您必须使用k=1来排除对角线元素。默认情况下,k=0也包括对角线元素。
  • 然后使用np.argwhere获得点的索引,其中总和m+m.T等于0.75
  • 然后将这些索引和相应的值存储在列表中,以供以后处理/打印。

可验证的示例(使用小的3x3随机数据集)

import numpy as np

np.random.seed(0)
m = np.random.rand(3,3)
summ = m + m.T

index = np.argwhere(np.triu(summ, k=1)>=0.75)

pairs = [(x,y, summ[x,y]) for x,y in index]
print (pairs)
# # [(0, 1, 1.2600725493693163), (0, 2, 1.0403505873343364), (1, 2, 1.537667113848736)]

进一步的性能改进

我刚刚想出了一种更快的方法来生成最终的pairs列表,从而避免显式的for循环

pairs = list(zip(index[:, 0], index[:, 1], summ[index[:,0], index[:,1]]))
相关问题