Clustering data based on a distance threshold

Date: 2018-04-21 08:24:59

Tags: python performance pandas numpy

I want to remove points that lie within 10 cm of an earlier point in the data.

This is what I have, but it takes a huge amount of computation time because my dataset is very large:

for i in range(len(data)):
     for j in range(i, len(data)):
          if (i == j):
               continue
          elif np.sqrt((data[i, 0]-data[j, 0])**2 + (data[i, 1]-data[j, 1])**2) <= 0.1:
               data[j, 0] = np.nan
data = data[~np.isnan(data).any(axis=1)]

Is there a Pythonic way to do this?

2 answers:

Answer 0 (score: 3)

Here is an approach using a KDTree:
import numpy as np
from scipy.spatial import cKDTree as KDTree

def cluster_data_KDTree(a, thr=0.1):
    t = KDTree(a)
    mask = np.ones(a.shape[:1], bool)
    idx = 0
    nxt = 1
    while nxt:
        # mark every point within thr of the current keeper for removal
        # (this momentarily clears the keeper itself, too)
        mask[t.query_ball_point(a[idx], thr)] = False
        # offset of the next point that is still unclaimed
        nxt = mask[idx:].argmax()
        # restore the keeper
        mask[idx] = True
        idx += nxt
    return a[mask]
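
The loop hinges on query_ball_point, which returns the indices of all points within the given radius of the query point (including the query point itself). A minimal sketch with a made-up three-point array:

import numpy as np
from scipy.spatial import cKDTree as KDTree

pts = np.array([[0.0, 0.0], [0.05, 0.0], [0.5, 0.5]])
t = KDTree(pts)
# indices of all points within 0.1 of pts[0]; note pts[0] itself is included
print(t.query_ball_point(pts[0], 0.1))  # [0, 1]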

Borrowing the test case from @Divakar, we see that this gives another ~100x speedup on top of the 400x Divakar reports. Compared to the OP's code, that extrapolates to a ridiculous 40,000x:

np.random.seed(0)
data1 = np.random.rand(10000,2)
data2 = data1.copy()

from timeit import timeit
kwds = dict(globals=globals(), number=10)

print(timeit("cluster_data_KDTree(data1)", **kwds))
print(timeit("cluster_data_pdist_v1(data2)", **kwds))

np.random.seed(0)
data1 = np.random.rand(10000,2)
data2 = data1.copy()

out1 = cluster_data_KDTree(data1, thr=0.1)
out2 = cluster_data_pdist_v1(data2, dist_thresh = 0.1)
print(np.allclose(out1, out2))

Sample output:

0.05073001119308174
5.646531613077968
True

It turns out this test case happens to be very favorable for my method, because there are very few clusters and hence very few iterations.

If we drastically increase the number of clusters to roughly 3800 by changing the threshold to 0.01, KDTree still wins, but the speedup shrinks from 100x to 15x.
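
To reproduce that comparison, the same timing harness from above can be rerun with the smaller threshold (a sketch reusing data1, data2 and kwds; absolute timings will of course vary by machine):

print(timeit("cluster_data_KDTree(data1, thr=0.01)", **kwds))
print(timeit("cluster_data_pdist_v1(data2, dist_thresh=0.01)", **kwds))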

Answer 1 (score: 2)

We can use pdist with one loop:
from scipy.spatial.distance import pdist

def cluster_data_pdist_v1(a, dist_thresh = 0.1):
    d = pdist(a)
    mask = d <= dist_thresh

    n = len(a)
    # start/stop offsets of each point's block in the condensed distance array
    idx = np.concatenate(( [0], np.arange(n-1,0,-1).cumsum() ))
    start, stop = idx[:-1], idx[1:]
    idx_out = np.zeros(mask.sum(), dtype=int) # use np.empty for a bit more speedup
    cur_start = 0
    for iterID,(i,j) in enumerate(zip(start, stop)):
        # only expand from points that are not already marked for removal
        if iterID not in idx_out[:cur_start]:
            # indices of all later points within dist_thresh of point iterID
            rm_idx = np.flatnonzero(mask[i:j])+iterID+1
            L = len(rm_idx)
            idx_out[cur_start:cur_start+L] = rm_idx
            cur_start += L

    return np.delete(a, idx_out[:cur_start], axis=0)
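
The idx bookkeeping exploits the layout of pdist's condensed output: for n points, the distances for pairs (0,1)..(0,n-1) come first, then (1,2)..(1,n-1), and so on, so the block for point i spans idx[i]:idx[i+1]. A small sketch with a made-up 4-point array:

import numpy as np
from scipy.spatial.distance import pdist

a = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
n = len(a)
idx = np.concatenate(( [0], np.arange(n-1,0,-1).cumsum() ))
print(idx)                      # [0 3 5 6]
print(pdist(a)[idx[1]:idx[2]])  # distances from point 1 to points 2 and 3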

Benchmarks

Original approach:

def cluster_data_org(data, dist_thresh = 0.1):
    for i in range(len(data)):
         for j in range(i, len(data)):
              if (i == j):
                   continue
              elif np.sqrt((data[i, 0]-data[j, 0])**2 +
                           (data[i, 1]-data[j, 1])**2) <= dist_thresh:
                   data[j, 0] = np.nan
    return data[~np.isnan(data).any(axis=1)]

Runtime test and verification on random data with 10,000 points in the range [0,1):

In [207]: np.random.seed(0)
     ...: data1 = np.random.rand(10000,2)
     ...: data2 = data1.copy()
     ...: 
     ...: out1 = cluster_data_org(data1, dist_thresh = 0.1)
     ...: out2 = cluster_data_pdist_v1(data2, dist_thresh = 0.1)
     ...: print(np.allclose(out1, out2))
True

In [208]: np.random.seed(0)
     ...: data1 = np.random.rand(10000,2)
     ...: data2 = data1.copy()

In [209]: %timeit cluster_data_org(data1, dist_thresh = 0.1)
1 loop, best of 3: 1min 50s per loop

In [210]: %timeit cluster_data_pdist_v1(data2, dist_thresh = 0.1)
1 loop, best of 3: 287 ms per loop

Around a 400x speedup for such a setup!