在大型numpy数组

时间:2018-06-08 15:06:59

标签: python-2.7 performance numpy scipy

我有一个非常大的2D numpy数组(~5e8值)。我使用scipy.ndimage.label标记了该数组然后我想找到包含每个标签的扁平数组的随机索引。我可以这样做:

import numpy as np
from scipy.ndimage import label

base_array = np.random.randint(0, 5, (100000, 5000))
labeled_array, nlabels = label(base_array)
for label_num in xrange(1, nlabels+1):
    indices = np.where(labeled_array.flat == label_num)[0]
    index = np.random.choice(indices)

但是,这个数组的数据很慢。我还尝试将np.where替换为:

indices = np.argwhere(labeled_array.flat == label).squeeze()

发现它变慢了。我怀疑布尔掩码是缓慢的部分。无论如何要加快速度,或者更好的方法来做到这一点。我将在我的实际应用程序中说,数组相当稀疏,填充量约为25%,但我没有使用scipy的稀疏数组函数。

1 个答案:

答案 0 :(得分:1)

您怀疑为每个标签单独屏蔽是否昂贵是正确的,因为无论您如何操作,屏蔽将始终为O(n)。

我们可以通过标签进行调整,然后从每个相同标签的块中随机挑选来规避这一点。

由于标签是整数范围,我们可以通过使用scipy中提供的一些稀疏矩阵机制来使argsort比np.argsort便宜。

由于我的机器没有大量的ram,我不得不缩小你的例子(因子4)。然后它会在大约5秒内运行。

import numpy as np
from scipy.ndimage import label
from scipy import sparse

def multi_randint(bins):
    """draw one random int from each range(bins[i], bins[i+1])"""
    high = np.diff(bins)
    n = high.size
    pick = np.random.randint(0, 1<<30, (n,))
    reject = np.flatnonzero(pick + (1<<30) % high >= (1<<30))
    while reject.size:
        npick = np.random.randint(0, 1<<30, (reject.size,))
        rejrej = npick + (1<<30) % sizes[reject] >= (1<<30)
        pick[reject] = npick
        reject = reject[rejrej]
    return bins[:-1] + pick % high

# build mock data, note that I had to shrink by 4x b/c memory
base_array = np.random.randint(0, 5, (50000, 2500), dtype=np.int8)
labeled_array, nlabels = label(base_array)

# build auxiliary sparse matrix
h = sparse.csr_matrix(
    (np.ones(labeled_array.size, bool), labeled_array.ravel(),
     np.arange(labeled_array.size+1, dtype=np.int32)),
    (labeled_array.size, nlabels+1))
# conversion to csc argsorts the labels (but cheaper than argsort)
h = h.tocsc()
# draw
result = h.indices[multi_randint(h.indptr)]

# check result
assert len(set(labeled_array.ravel()[result])) == nlabels+1
相关问题