凝聚矩阵函数找到对

时间:2011-03-16 10:16:58

标签: python algorithm math statistics scipy

对于一组观察:

[a1,a2,a3,a4,a5]

他们的成对距离

d=[[0,a12,a13,a14,a15]
   [a21,0,a23,a24,a25]
   [a31,a32,0,a34,a35]
   [a41,a42,a43,0,a45]
   [a51,a52,a53,a54,0]]

以浓缩矩阵形式给出(上面的上三角形,从scipy.spatial.distance.pdist计算得出):

c=[a12,a13,a14,a15,a23,a24,a25,a34,a35,a45]

问题是,假设我在压缩矩阵中有索引,那么有一个函数(最好是在python中) f 来快速给出哪两个观察结果来计算它们?

f(c,0)=(1,2)
f(c,5)=(2,4)
f(c,9)=(4,5)
...

我尝试了一些解决方案,但没有一个值得一提:(

7 个答案:

答案 0 :(得分:24)

浓缩矩阵的索引公式为

index = d*(d-1)/2 - (d-i)*(d-i-1)/2 + j - i - 1

其中i是行索引,j是列索引,d是原始(d X d)上三角矩阵的行长度。

考虑索引引用原始矩阵中某行的最左边非零条目的情况。对于所有最左边的索引,

j == i + 1

所以

index = d*(d-1)/2 - (d-i)*(d-i-1)/2 + i + 1 - i - 1
index = d*(d-1)/2 - (d-i)*(d-i-1)/2

使用某些代数,我们可以将其重写为

i**2 + (1 - 2d)*i + 2*index == 0

然后我们可以使用二次公式来找到方程的根,我们只是去 关心积极的根源。

如果此索引确实对应于最左边的非零单元格,那么我们得到一个正整数作为解决方案 对应于行号。然后,找到列号只是算术。

j = index - d*(d-1)/2 + (d-i)*(d-i-1)/2 + i + 1

如果索引不对应最左边的非零单元格,那么我们将找不到整数根,但我们可以将正根的最低位作为行号。

def row_col_from_condensed_index(d,i):
    b = 1 -2*d 
    x = math.floor((-b - math.sqrt(b**2 - 8*i))/2)
    y = i + x*(b + x + 2)/2 + 1
    return (x,y)  

如果你不知道d,你可以从浓缩矩阵的长度来计算它。

((d-1)*d)/2 == len(condensed_matrix)
d = (1 + math.sqrt(1 + 8*len(condensed_matrix)))/2 

答案 1 :(得分:4)

您可能会发现triu_indices有用。像,

In []: ti= triu_indices(5, 1)
In []: r, c= ti[0][5], ti[1][5]
In []: r, c
Out[]: (1, 3)

请注意,指数从0开始。您可以根据需要进行调整,例如:

In []: def f(n, c):
   ..:     n= ceil(sqrt(2* n))
   ..:     ti= triu_indices(n, 1)
   ..:     return ti[0][c]+ 1, ti[1][c]+ 1
   ..:
In []: f(len(c), 5)
Out[]: (2, 4)

答案 2 :(得分:2)

Cleary,你要搜索的函数,需要第二个参数:矩阵的维度 - 在你的情况下:5

首先尝试:

def f(dim,i): 
  d = dim-1 ; s = d
  while i<s: 
    s+=d ; d-=1
  return (dim-d, i-s+d)

答案 3 :(得分:0)

这是phynfo和您的评论提供的答案的补充。从压缩矩阵的长度推断矩阵的维数对我来说感觉不是一个干净的设计。也就是说,这是你如何计算它:

from math import sqrt, ceil

for i in range(1,10):
   thelen = (i * (i+1)) / 2
   thedim = sqrt(2*thelen + ceil(sqrt(2*thelen)))
   print "compressed array of length %d has dimension %d" % (thelen, thedim)

外部平方根的参数应该始终是一个方形整数,但sqrt会返回一个浮点数,因此在使用它时需要注意。

答案 4 :(得分:0)

要完成此问题的答案列表:fgreggs答案的快速矢量化版本(如David Marx所建议的)可能如下所示:

def vec_row_col(d,i):                                                                
    i = np.array(i)                                                                 
    b = 1 - 2 * d                                                                   
    x = (np.floor((-b - np.sqrt(b**2 - 8*i))/2).astype(int)                                      
    y = (i + x*(b + x + 2)/2 + 1).astype(int)                                                    
    if i.shape:                                                                     
        return zip(x,y)                                                             
    else:                                                                           
        return (x,y) 

我需要对大型数组进行这些计算,与非向量化版本(https://stackoverflow.com/a/14839010/3631440)相比,加速比(通常情况下)非常令人印象深刻(使用IPython%timeit):

import numpy as np
from scipy.spatial import distance

test = np.random.rand(1000,1000)
condense = distance.pdist(test)
sample = np.random.randint(0,len(condense), 1000)

%timeit res = vec_row_col(1000, sample)
10000 loops, best of 3: 156 µs per loop

res = []
%timeit for i in sample: res.append(row_col_from_condensed_index(1000, i))
100 loops, best of 3: 5.87 ms per loop

在此示例中, 37 的速度更快!

答案 5 :(得分:-1)

这是另一种解决方案:

import numpy as np

def f(c,n):
    tt = np.zeros_like(c)
    tt[n] = 1
    return tuple(np.nonzero(squareform(tt))[0])

答案 6 :(得分:-1)

使用numpy.triu_indices来提高效率 用这个:

def PdistIndices(n,I):
    '''idx = {} indices for pdist results'''
    idx = numpy.array(numpy.triu_indices(n,1)).T[I]
    return idx

所以I是一系列索引。

然而更好的解决方案是在Fortran中实施优化的强力搜索:

function PdistIndices(n,indices,m) result(IJ)
    !IJ = {} indices for pdist[python] selected results[indices]
    implicit none
    integer:: i,j,m,n,k,w,indices(0:m-1),IJ(0:m-1,2)
    logical:: finished
    k = 0; w = 0; finished = .false.
    do i=0,n-2
        do j=i+1,n-1
            if (k==indices(w)) then
                IJ(w,:) = [i,j]
                w = w+1
                if (w==m) then
                    finished = .true.
                    exit
                endif
            endif
            k = k+1
        enddo
        if (finished) then
            exit
        endif
    enddo
end function

然后使用F2PY进行编译,享受无与伦比的性能。 ;)

相关问题