排序的numpy数组的交集

时间:2017-10-04 19:01:09

标签: numpy

我有一个排序的numpy数组列表。计算这些数组的排序交集的最有效方法是什么?

在我的应用程序中,我希望数组的数量小于10 ^ 4,我希望各个数组的长度小于10 ^ 7,我希望交集的长度接近p * N,其中N是最大阵列的长度,其中0.99

快速而肮脏的方法是反复调用numpy.intersect1d()。这似乎效率低下,但intersect1d()没有利用数组排序的事实。

2 个答案:

答案 0 :(得分:1)

由于intersect1d每次排序数组,因此效率很低。

在这里,您必须将交叉点和每个样本扫描在一起以构建新的交叉点,这可以在线性时间内完成,从而保持顺序。

这种任务通常必须通过低级别例程手动调整。

这是使用numba

执行此操作的方法
from numba import njit
import numpy as np

@njit
def drop_missing(intersect,sample):
    i=j=k=0
    new_intersect=np.empty_like(intersect)
    while i< intersect.size and j < sample.size:
            if intersect[i]==sample[j]: # the 99% case
                new_intersect[k]=intersect[i]
                k+=1
                i+=1
                j+=1
            elif intersect[i]<sample[j]:
                i+=1
            else : 
                j+=1
    return new_intersect[:k]  

现在的样本:

n=10**7
ref=np.random.randint(0,n,n)  
ref.sort()

def perturbation(sample,k):
    rands=np.random.randint(0,n,k-1)
    rands.sort()
    l=np.split(sample,rands)
    return np.concatenate([a[:-1] for a in l])

samples=[perturbation(ref,100) for  _ in range(10)] #similar samples 

运行10个样本

def find_intersect(samples):
    intersect=samples[0]
    for sample in samples[1:]:
        intersect=drop_missing(intersect,sample)
    return intersect                

In [18]: %time u=find_intersect(samples)
Wall time: 307 ms

In [19]: len(u)
Out[19]: 9999009     

这种方式似乎可以在大约5分钟内完成,超出加载时间。

答案 1 :(得分:0)

几个月前,我为此目的编写了一个基于C ++的python扩展。 package被称为sortednp,可通过pip使用。可以使用

来计算多个排序的numpy数组的交集,例如abc
import sortednp as snp
i = snp.kway_intersect(a, b, c)

默认情况下,它使用exponential search在内部推进数组索引,这在交点较小的情况下非常快。就您而言,如果将algorithm=snp.SIMPLE_SEARCH添加到方法调用中,则可能会更快。

相关问题