识别欧氏距离最小的点

时间:2011-02-25 16:17:18

标签: python algorithm numpy nearest-neighbor euclidean-distance

我有一个n维点的集合,我想找到哪两个是最接近的。我能想出的最好的两个方面是:

from numpy import *
myArr = array( [[1, 2],
                [3, 4],
                [5, 6],
                [7, 8]] )

n = myArr.shape[0]
cross = [[sum( ( myArr[i] - myArr[j] ) ** 2 ), i, j]
         for i in xrange( n )
         for j in xrange( n )
         if i != j
         ]

print min( cross )

给出了

[8, 0, 1]

但这对大型阵列来说太慢了。我可以应用什么样的优化?

相关:


Euclidean distance between points in two different Numpy arrays, not within

7 个答案:

答案 0 :(得分:11)

试试scipy.spatial.distance.pdist(myArr)。这将为您提供精简距离矩阵。您可以在其上使用argmin并找到最小值的索引。这可以转换为配对信息。

答案 1 :(得分:9)

有关于此问题的整个维基百科页面,请参阅:http://en.wikipedia.org/wiki/Closest_pair_of_points

执行摘要:您可以使用递归分治算法实现O(n log n)(在上面的Wiki页面上概述)。

答案 2 :(得分:6)

您可以利用最新版本的SciPy(v0.9)Delaunay三角测量工具。您可以确定最接近的两个点将是三角测量中单形的边缘,这是一个比每个组合更小的对的子集。

这是代码(针对一般N-D更新):

import numpy
from scipy import spatial

def closest_pts(pts):
    # set up the triangluataion
    # let Delaunay do the heavy lifting
    mesh = spatial.Delaunay(pts)

    # TODO: eliminate reduncant edges (numpy.unique?)
    edges = numpy.vstack((mesh.vertices[:,:dim], mesh.vertices[:,-dim:]))

    # the rest is easy
    x = mesh.points[edges[:,0]]
    y = mesh.points[edges[:,1]]

    dists = numpy.sum((x-y)**2, 1)
    idx = numpy.argmin(dists)

    return edges[idx]
    #print 'distance: ', dists[idx]
    #print 'coords:\n', pts[closest_verts]

dim = 3
N = 1000*dim
pts = numpy.random.random(N).reshape(N/dim, dim)

似乎非常接近O(n):

enter image description here

答案 3 :(得分:2)

有一个scipy函数pdist可以以相当有效的方式获得数组中各点之间的成对距离:

http://docs.scipy.org/doc/scipy/reference/spatial.distance.html

输出N *(N-1)/ 2个唯一对(因为r_ij == r_ji)。然后,您可以搜索最小值并避免代码中的整个循环混乱。

答案 4 :(得分:1)

也许你可以沿着这些方向前进:

In []: from scipy.spatial.distance import pdist as pd, squareform as sf
In []: m= 1234
In []: n= 123
In []: p= randn(m, n)
In []: d= sf(pd(p))
In []: a= arange(m)
In []: d[a, a]= d.max()
In []: where(d< d.min()+ 1e-9)
Out[]: (array([701, 730]), array([730, 701]))

要获得更多分数,您需要能够以某种方式利用群集的层次结构。

答案 5 :(得分:0)

与仅执行嵌套循环并跟踪最短的对相比,它有多快?我认为创建一个巨大的交叉阵列可能会伤害到你。如果你只做二维点,即使O(n ^ 2)仍然很快。

答案 6 :(得分:0)

对于小型数据集,已接受的答案是可以的,但其执行时间会缩放为n**2。但是,正如@payne所指出的,最佳解决方案可以实现n*log(n)计算时间缩放。

可以使用sklearn.neighbors.BallTree获得此可选解决方案,如下所示。

import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import BallTree as tree

n = 10
dim = 2
xy = np.random.uniform(size=[n, dim])

# This solution is optimal when xy is very large
res = tree(xy)
dist, ids = res.query(xy, 2)
mindist = dist[:, 1]  # second nearest neighbour
minid = np.argmin(mindist)

plt.plot(*xy.T, 'o')
plt.plot(*xy[ids[minid]].T, '-o')

此过程适用于非常大的xy值集,甚至适用于大尺寸dim(尽管示例说明了案例dim=2)。结果输出如下所示

The nearest pair of points is connected by an orange line

使用scipy.spatial.cKDTree,可以使用以下Scipy替换sklearn导入,从而获得相同的解决方案。但请注意,与cKDTree不同,BallTree不适合高维度

from scipy.spatial import cKDTree as tree