scipy稀疏矩阵中每行或每列的Argmax

时间:2015-06-09 20:50:15

标签: python scipy sparse-matrix

在给定轴的情况下,

scipy.sparse.coo_matrix.max返回每行或每列的最大值。我想知道的不是值,而是每行或每列的最大值的索引。我还没有办法以有效的方式做到这一点,所以我很乐意接受任何帮助。

6 个答案:

答案 0 :(得分:3)

我建议研究

的代码
moo._min_or_max_axis

其中moocoo_matrix

mat = mat.tocsc()  # for axis=0
mat.sum_duplicates()

major_index, value = mat._minor_reduce(min_or_max)
not_full = np.diff(mat.indptr)[major_index] < N
value[not_full] = min_or_max(value[not_full], 0)

mask = value != 0
major_index = np.compress(mask, major_index)
value = np.compress(mask, value)
return coo_matrix((value, (np.zeros(len(value)), major_index)),
                      dtype=self.dtype, shape=(1, M))

根据轴的不同,它更喜欢使用csc而不是csr。我没有时间对此进行分析,但我猜测应该可以在计算中加入argmax

此建议可能无效。关键是mat._minor_reduce方法,它有一些改进:

ufunc.reduceat(mat.data, mat.indptr[:-1])

这是将ufunc应用于矩阵data数组的块,使用indptr来定义块。 np.sumnp.maxiumumufunc。我不知道等效的argmax ufunc。

一般情况下,如果你想通过“排”来做事情。对于csr矩阵(或csc的col),您必须迭代行相对昂贵的行,或者使用此ufunc.reduceat在平面mat.data向量上执行相同的操作。

group argmax/argmin over partitioning indices in numpy 尝试执行argmax.reduceat。那里的解决方案可能适用于稀疏矩阵。

答案 1 :(得分:2)

从scipy版本0.19开始,csr_matrixcsc_matrix都支持argmax()argmin()方法。

答案 2 :(得分:1)

如果A是您的scipy.sparse.coo_matrix,那么您将获得最大值的行和列,如下所示:

I=A.data.argmax()
maxrow = A.row[I]
maxcol=A.col[I]

要获得每行的最大值索引,请参阅下面的编辑:

from scipy.sparse import coo_matrix
import numpy as np
row  = np.array([0, 3, 1, 0])
col  = np.array([0, 2, 3, 2])
data = np.array([-3, 4, 11, -7])
A= coo_matrix((data, (row, col)), shape=(4, 4))
print A.toarray()

nrRows=A.shape[0]
maxrowind=[]
for i in range(nrRows):
    r = A.getrow(i)# r is 1xA.shape[1] matrix
    maxrowind.append( r.indices[r.data.argmax()] if r.nnz else 0)
print maxrowind 

r.nnz是显式存储值的计数(即非零值)

答案 3 :(得分:1)

最新版本的numpy_indexed软件包(免责声明:我是其作者)可以高效优雅的方式解决这个问题:

import numpy_indexed as npi
col, argmax = group_by(coo.col).argmax(coo.data)
row = coo.row[argmax]

这里我们按col分组,所以它是列上的argmax;交换行和col将为您提供行上的argmax。

答案 4 :(得分:1)

扩展@hpaulj和@joeln的答案,并按照建议使用group argmax/argmin over partitioning indices in numpy中的代码,此函数将计算CSR上的argmax或CSC上的argmax:

import numpy as np
import scipy.sparse as sp

def csr_csc_argmax(X, axis=None):
    is_csr = isinstance(X, sp.csr_matrix)
    is_csc = isinstance(X, sp.csc_matrix)
    assert( is_csr or is_csc )
    assert( not axis or (is_csr and axis==1) or (is_csc and axis==0) )

    major_size = X.shape[0 if is_csr else 1]
    major_lengths = np.diff(X.indptr) # group_lengths
    major_not_empty = (major_lengths > 0)

    result = -np.ones(shape=(major_size,), dtype=X.indices.dtype)
    split_at = X.indptr[:-1][major_not_empty]
    maxima = np.zeros((major_size,), dtype=X.dtype)
    maxima[major_not_empty] = np.maximum.reduceat(X.data, split_at)
    all_argmax = np.flatnonzero(np.repeat(maxima, major_lengths) == X.data)
    result[major_not_empty] = X.indices[all_argmax[np.searchsorted(all_argmax, split_at)]]
    return result

对于完全稀疏的任何行(CSR)或列(CSC)的argmax,它返回-1(即X.eliminate_zeros()之后完全为零)。

答案 5 :(得分:1)

正如其他人提到的,现在 argmax() 矩阵内置了 scipy.sparse。但是,我发现它对于大型矩阵很慢,所以我查看了 the source code。逻辑非常聪明,但它包含一个 Python 循环,可以减慢速度。以源代码为例,将其减少到每行 argmax(同时为了简单起见牺牲所有通用性、形状检查等)并使用 numba 对其进行修饰可以提供一些不错的速度改进。

功能如下:

import numpy as np
from numba import jit


def argmax_row_numba(X):
    return _argmax_row_numba(X.shape[0], X.indptr, X.data, X.indices)

@jit(nopython=True)
def _argmax_row_numba(shape, indptr, data, indices):
    # prep an array to hold the indices
    ret = np.zeros(shape)
    # figure out which lines actually contain data
    nz_lines, = np.diff(indptr).nonzero()
    # loop through the lines
    for i in nz_lines:
        p, q = indptr[i: i + 2]
        line_data = data[p: q]
        line_indices = indices[p: q]
        am = np.argmax(line_data)
        ret[i] = line_indices[am]

    return ret

生成用于测试的矩阵:


from scipy.sparse import random
size = 10000
m = random(m=size, n=size, density=0.0001, format="csr")
n_vals = m.data.shape[0]
m.data = np.random.random(size=n_vals).astype("float")


# the original scipy implementation reformatted to return a np.array
maxima1 = np.squeeze(np.array(m.argmax(axis=1)))
# calling the numba version
maxima2 = argmax_row_numba(m)

# Check that the results are the same
print(np.allclose(maxima1, maxima2))
# True

计时结果:

%timeit m.argmax(axis=1)
# 30.1 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit argmax_row_numba(m)
# 211 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)