加快阵列中所有可能对之间的距离

时间:2016-02-26 19:13:18

标签: python python-2.7 numpy while-loop

我有几个(~10 ^ 10)点的x,y,z坐标数组(这里只显示了5个)

a= [[ 34.45  14.13   2.17]
    [ 32.38  24.43  23.12]
    [ 33.19   3.28  39.02]
    [ 36.34  27.17  31.61]
    [ 37.81  29.17  29.94]]

我想创建一个新数组,其中只包含距离列表中所有其他点至少距离d的点。我使用while循环编写了一个代码,

 import numpy as np
 from scipy.spatial import distance 

 d=0.1 #or some distance 
 i=0
 selected_points=[]
 while i < len(a):
          interdist=[]  
          j=i+1
          while j<len(a):
              interdist.append(distance.euclidean(a[i],a[j]))
              j+=1

          if all(dis >= d for dis in interdist):
              np.array(selected_points.append(a[i]))
          i+=1

这样可行,但执行此计算需要很长时间。我在某处读到while循环非常慢。

我想知道是否有人对如何加快计算有任何建议。

编辑:虽然我的目标是找到距离所有其他距离至少有一段距离的粒子,但我只是意识到我的代码中有一个严重的缺陷,让我们#39 ;假设我有3个粒子,我的代码执行以下操作,对于i的第一次迭代,它计算距离1->21->3,让我们说{{1小于阈值距离1->2,因此代码会抛弃粒子d。对于1的下一次迭代,它只会i,并且让我们发现它大于2->3,因此它会保留粒子d,但这是错的!因为2也应该与粒子2一起丢弃。 @svohara的解决方案是正确的!

4 个答案:

答案 0 :(得分:5)

对于大数据集和低维点(例如三维数据),有时使用空间索引方法有很大好处。低维数据的一个流行选择是k-d树。

策略是索引数据集。然后使用相同的数据集查询索引,以返回每个点的2个最近邻居。第一个最近邻居总是点本身(dist = 0),所以我们真的想知道下一个最近点(第二个最近邻居)有多远。对于那些2-NN> 1的点。阈值,你有结果。

from scipy.spatial import cKDTree as KDTree
import numpy as np

#a is the big data as numpy array N rows by 3 cols
a = np.random.randn(10**8, 3).astype('float32')

# This will create the index, prepare to wait...
# NOTE: took 7 minutes on my mac laptop with 10^8 rand 3-d numbers
#  there are some parameters that could be tweaked for faster indexing,
#  and there are implementations (not in scipy) that can construct
#  the kd-tree using parallel computing strategies (GPUs, e.g.)
k = KDTree(a)

#ask for the 2-nearest neighbors by querying the index with the
# same points
(dists, idxs) = k.query(a, 2)
# (dists, idxs) = k.query(a, 2, n_jobs=4)  # to use more CPUs on query...

#Note: 9 minutes for query on my laptop, 2 minutes with n_jobs=6
# So less than 10 minutes total for 10^8 points.

# If the second NN is > thresh distance, then there is no other point
# in the data set closer.
thresh_d = 0.1   #some threshold, equiv to 'd' in O.P.'s code
d_slice = dists[:, 1]  #distances to second NN for each point
res = np.flatnonzero( d_slice >= thresh_d )

答案 1 :(得分:2)

这是使用distance.pdist -

的矢量化方法
# Store number of pts (number of rows in a)
m = a.shape[0]

# Get the first of pairwise indices formed with the pairs of rows from a
# Simpler version, but a bit slow : idx1,_ = np.triu_indices(m,1)
shifts_arr = np.zeros(m*(m-1)/2,dtype=int)
shifts_arr[np.arange(m-1,1,-1).cumsum()] = 1
idx1 = shifts_arr.cumsum()

# Get the IDs of pairs of rows that are more than "d" apart and thus select 
# the rest of the rows using a boolean mask created with np.in1d for the 
# entire range of number of rows in a. Index into a to get the selected points.
selected_pts = a[~np.in1d(np.arange(m),idx1[distance.pdist(a) < d])] 

对于像10e10这样的庞大数据集,我们可能必须根据可用的系统内存以块的形式执行操作。

答案 2 :(得分:0)

  1. 删除追加,它一定很慢。您可以使用静态矢量距离并使用[]将数字放在正确的位置。

  2. 使用min而不是all。您只需要检查最小距离是否大于x。

  3. 实际上,你可以在发现距离小于限制的那一刻打破你的追尾,然后你就可以退出两个点。这样你甚至不必保存任何距离(除非你以后需要它们)。

    1. 由于d(a,b)= d(b,a),您只能对以下几点进行内部循环,忘记已计算的距离。如果你需要它们,你可以从阵列中选择更快。
  4. 如果你没有重复的观点,我会相信你的评论。

    selected_points = []
    for p1 in a:
        save_point = True
        for p2 in a:
            if p1!=p2 and distance.euclidean(p1,p2)<d:
                save_point = False
                break
        if save_point:
            selected_points.append(p1)
    
    return selected_points
    

    最后,我检查a,b和b,a,因为你不应该在处理它时修改列表,但你可以更聪明地使用一些附加变量。

答案 3 :(得分:0)

你的算法是二次的(10 ^ 20次运算),如果分布几乎是随机的,这是一个线性方法。 将您的空间拆分为大小为d/sqrt(3)^3的框。将每个点放在其框中。

然后为每个方框

  • 如果只有一个点,你只需要在一个小街区计算距离。

  • 否则无事可做。