使用索引列表有效地从数组中提取值

时间:2018-04-15 21:03:20

标签: python python-2.7 performance list numpy

给定2D NumPy数组a和存储在index中的索引列表,必须有一种非常有效地提取列表值的方法。使用如下的for循环需要大约5 ms,这对于提取的2000个元素来说似乎非常慢:

import numpy as np
import time

# generate dummy array 
a = np.arange(4000).reshape(1000, 4) 
# generate dummy list of indices
r1 = np.random.randint(1000, size=2000)
r2 = np.random.randint(3, size=2000)
index = np.concatenate([[r1], [r2]]).T

start = time.time()
result = [a[i, j] for [i, j] in index]
print time.time() - start

如何提高提取速度? np.take在这里似乎不合适,因为它会返回一个2D数组而不是一维数组。

2 个答案:

答案 0 :(得分:2)

您可以使用advanced indexing,这基本上是指从index数组中提取行索引和列索引,然后使用它从a中提取值,即a[index[:,0], index[:,1]] - < / p>

%timeit a[index[:,0], index[:,1]]
# 12.1 µs ± 368 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit [a[i, j] for [i, j] in index]
# 2.22 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

答案 1 :(得分:2)

另一个选项是numpy.ravel_multi_index,它可以让您避免手动编制索引。

np.ravel_multi_index(index.T, a.shape)