确定哪些行(或列)在稀疏矩阵中具有值

时间:2017-05-14 20:19:00

标签: julia sparse-matrix

我需要识别在大型稀疏布尔矩阵中定义了值的行(/列)。我想用这个来1.通过那些行/列切片(实际上是view)Matrix; 2.切片(/ view)矢量和矩阵,其尺寸与Matrix的边距相同。即结果可能应该是索引/ Bools的Vector或(最好是)迭代器。

我已经尝试过显而易见的事了:

a = sprand(10000, 10000, 0.01)
cols = unique(a.colptr)
rows = unique(a.rowvals)

但是我的机器上每个都需要20毫秒,可能是因为它们分配了大约1MB(至少它们分配了colsrows)。这是一个性能关键的功能,所以我希望优化代码。基本代码似乎有稀疏矩阵的nzrange迭代器,但我不容易看到如何将它应用于我的情况。

有建议的方法吗?

第二个问题:我还需要对稀疏矩阵的视图执行此操作 - 这类似于x = view(a,:,:); cols = unique(x.parent.colptr[x.indices[:,2]])还是有专门的功能?稀疏矩阵的视图似乎很棘手(参考https://discourse.julialang.org/t/slow-arithmetic-on-views-of-sparse-matrices/3644 - 不是交叉发布)

非常感谢!

2 个答案:

答案 0 :(得分:6)

关于获取稀疏矩阵的非零行和列,以下函数应该非常有效:

nzcols(a::SparseMatrixCSC) = collect(i 
  for i in 1:a.n if a.colptr[i]<a.colptr[i+1])

function nzrows(a::SparseMatrixCSC)
    active = falses(a.m)
    for r in a.rowval
        active[r] = true
    end
    return find(active)
end

对于具有0.1密度的10_000x10_000矩阵,cols和row分别需要0.2ms和2.9ms。它也应该比有问题的方法更快(除了正确性问题)。

关于稀疏矩阵的视图,快速解决方案是将视图转换为稀疏矩阵(例如,使用b = sparse(view(a,100:199,100:199)))并使用上述函数。在代码中:

nzcols(b::SubArray{T,2,P}) where {T,P<:AbstractSparseArray} = nzcols(sparse(b))
nzrows(b::SubArray{T,2,P}) where {T,P<:AbstractSparseArray} = nzrows(sparse(b))

更好的解决方案是根据视图自定义功能。例如,当视图对行和列使用UnitRanges时:

# utility predicate returning true if element of sorted v in range r
inrange(v,r) = searchsortedlast(v,last(r))>=searchsortedfirst(v,first(r))

function nzcols(b::SubArray{T,2,P,Tuple{UnitRange{Int64},UnitRange{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    return collect(i+1-start(b.indexes[2]) 
      for i in b.indexes[2]
      if b.parent.colptr[i]<b.parent.colptr[i+1] && 
        inrange(b.parent.rowval[nzrange(b.parent,i)],b.indexes[1]))
end

function nzrows(b::SubArray{T,2,P,Tuple{UnitRange{Int64},UnitRange{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    active = falses(length(b.indexes[1]))
    for c in b.indexes[2]
        for r in nzrange(b.parent,c)
            if b.parent.rowval[r] in b.indexes[1]
                active[b.parent.rowval[r]+1-start(b.indexes[1])] = true
            end
        end
    end
    return find(active)
end

比完整矩阵的版本工作得更快(对于100x100以上的10,000x10,000矩阵列和行的子矩阵,在我的机器上分别需要16μs和12μs,但这些都是不稳定的结果)。

适当的基准测试将使用固定矩阵(或至少修复随机种子)。如果我这样做的话,我会用这样的基准编辑这行。

答案 1 :(得分:2)

如果索引不是范围,则转换为稀疏矩阵的回退有效,但这里是索引的版本,即Vector。如果索引是混合的,则需要另一组版本。相当重复,但这是Julia的优势,当版本完成后,代码将使用调用者中的类型正确选择优化方法而不需要太多努力。

function sortedintersecting(v1, v2)
    i,j = start(v1), start(v2)
    while i <= length(v1) && j <= length(v2)
        if v1[i] == v2[j] return true
        elseif v1[i] > v2[j] j += 1
        else i += 1
        end
    end
    return false
end

function nzcols(b::SubArray{T,2,P,Tuple{Vector{Int64},Vector{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    brows = sort(unique(b.indexes[1]))
    return [k 
      for (k,i) in enumerate(b.indexes[2])
      if b.parent.colptr[i]<b.parent.colptr[i+1] && 
        sortedintersecting(brows,b.parent.rowval[nzrange(b.parent,i)])]
end

function nzrows(b::SubArray{T,2,P,Tuple{Vector{Int64},Vector{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    active = falses(length(b.indexes[1]))
    for c in b.indexes[2]
      active[findin(b.indexes[1],b.parent.rowval[nzrange(b.parent,c)])] = true
    end
    return find(active)
end

- ADDENDUM -

由于注意到nzrows对于Vector {Int}索引有点慢,这是尝试通过将findin替换为利用排序的版本来提高其速度:

function findin2(inds,v,w)
    i,j = start(v),start(w)
    res = Vector{Int}()
    while i<=length(v) && j<=length(w)
        if v[i]==w[j]
            push!(res,inds[i])
            i += 1
        elseif (v[i]<w[j]) i += 1
        else j += 1
        end
    end
    return res
end

function nzrows(b::SubArray{T,2,P,Tuple{Vector{Int64},Vector{Int64}}}
  ) where {T,P<:SparseMatrixCSC}
    active = falses(length(b.indexes[1]))
    inds = sortperm(b.indexes[1])
    brows = (b.indexes[1])[inds] 
    for c in b.indexes[2]
      active[findin2(inds,brows,b.parent.rowval[nzrange(b.parent,c)])] = true
    end
    return find(active)
end
相关问题