对于数组中的每个点,在第二个数组中找到与其最接近的点并输出该索引

时间:2016-12-12 14:05:26

标签: python arrays numpy enumerate

如果我有两个阵列:

X = np.random.rand(10000,2)
Y = np.random.rand(10000,2)

对于X中的每个点,我怎样才能找出Y中最接近它的点?所以最后我有一个数组显示:

x1_index   y_index_of_closest
   1               7
   2               54
   3               3
  ...             ...

我想对X中的两个列执行此操作,并将每个列与Y

中的每个列和值进行比较

2 个答案:

答案 0 :(得分:2)

这个问题很受欢迎。由于类似的问题在这里一直处于封闭和联系状态,因此我认为值得指出的是,即使对于数千个数据点而言,现有的答案是相当快的,但此后它们就开始分解。我的马铃薯断面每个阵列中有1万个物品。

其他答案的潜在问题是算法复杂性。他们将X中的所有内容与Y中的所有内容进行比较。为了解决这个问题,至少在平均水平上,我们需要一种更好的策略来排除Y中的某些问题。

在一个维度上这很容易-只需对所有内容进行排序并开始弹出最近的邻居即可。在两个维度上,有各种策略,但是KD树相当流行,并且已经在scipy堆栈中实现。在我的计算机上,XY中的每一个都有6k的东西周围的各种方法之间存在交叉。

from scipy.spatial import KDTree

tree = KDTree(X)
neighbor_dists, neighbor_indices = tree.query(Y)

scipy的KDTree实现的极差性能一直是我的痛处,尤其是在其基础上构建了许多东西之后。可能有一些数据集表现良好,但是我还没有看到。

如果您不介意额外的依赖关系,只需切换KDTree库即可获得 1000倍的速度提升。软件包pykdtree是可点安装的,我几乎可以保证conda软件包也能正常工作。通过这种方法,我用过的预算有限的Chromebook可以在短短30秒内处理XY并获得1000万分。胜过一时的万分秒节错;)

from pykdtree.kdtree import KDTree

tree = KDTree(X)
neighbor_dists, neighbor_indices = tree.query(Y)

答案 1 :(得分:1)

这一直是问题最多的问题(我在上周已经两次自己回答过),但是因为它可以用百万种方式表达:

import numpy as np
import scipy.spatial.distance.cdist as cdist

def withScipy(X,Y):  # faster
    return np.argmin(cdist(X,Y,'sqeuclidean'),axis=0)

def withoutScipy(X,Y): #slower, using broadcasting
    return np.argmin(np.sum((X[None,:,:]-Y[:,None,:])**2,axis=-1), axis=0)

还有一种仅使用What is DDL and DML的numpy方法比我的功能更快(但不是cdist)但我不太了解它解释一下。

EDIT + = 21个月:

通过算法进行此操作的最佳方法是使用KDTree。

from sklearn.neighbors import KDTree 
# since the sklearn implementation allows return_distance = False, saving memory

y_tree = KDTree(Y)
y_index_of_closest = y_tree.query(X, k = 1, return_distance = False)

@HansMusgrave对KDTree的速度非常快。

为了完成起见,np.einsum答案,我现在明白了:

np.argmin(                                      #  (X - Y) ** 2 
    np.einsum('ij, ij ->i', X, X)[:, None] +    # = X ** 2        \
    np.einsum('ij, ij ->i', Y, Y)          -    # + Y ** 2        \
    2 * X.dot(Y.T),                             # - 2 * X * Y
    axis = 1)

@Divakar在他的软件包einsumwiki page上很好地解释了这个方法

相关问题