加速指数“恢复”

时间:2018-05-17 10:45:34

标签: python arrays numpy indexing

我有一个形状为a的numpy数组(n, 3),其中包含从0m的整数。 mn都可能相当大。众所周知,从0m的每个整数有时只出现一次,但大多数在a的某个地方出现两次。连续没有加倍的索引。

我现在想构建“反向”索引,即形状b_row的两个数组b_col(m, 2),每行包含(一个或两个)行/ a中的列索引row_idx中显示a

这有效:

import numpy

a = numpy.array([
    [0, 1, 2],
    [0, 1, 3],
    [2, 3, 4],
    [4, 5, 6],
    # ...
    ])

print(a)

b_row = -numpy.ones((7, 2), dtype=int)
b_col = -numpy.ones((7, 2), dtype=int)
count = numpy.zeros(7, dtype=int)
for k, row in enumerate(a):
    i = count[row]
    b_row[row, i] = k
    b_col[row, i] = [0, 1, 2]
    count[row] += 1

print(b_row)
print(b_col)
[[0 1 2]
 [0 1 3]
 [2 3 4]
 [4 5 6]]

[[ 0  1]
 [ 0  1]
 [ 0  2]
 [ 1  2]
 [ 2  3]
 [ 3 -1]
 [ 3 -1]]

[[ 0  0]
 [ 1  1]
 [ 2  0]
 [ 2  1]
 [ 2  0]
 [ 1 -1]
 [ 2 -1]]

但由于a上的显式循环而缓慢。

有关如何提高速度的任何提示?

2 个答案:

答案 0 :(得分:2)

这是一个解决方案:

import numpy as np

m = 7
a = np.array([
    [0, 1, 2],
    [0, 1, 3],
    [2, 3, 4],
    [4, 5, 6],
    # ...
    ])

print('a:')
print(a)

a_flat = a.flatten()  # Or a.ravel() if can modify original array
v1, idx1 = np.unique(a_flat, return_index=True)
a_flat[idx1] = -1
v2, idx2 = np.unique(a_flat, return_index=True)
v2, idx2 = v2[1:], idx2[1:]
rows1, cols1 = np.unravel_index(idx1, a.shape)
rows2, cols2 = np.unravel_index(idx2, a.shape)
b_row = -np.ones((m, 2), dtype=int)
b_col = -np.ones((m, 2), dtype=int)
b_row[v1, 0] = rows1
b_col[v1, 0] = cols1
b_row[v2, 1] = rows2
b_col[v2, 1] = cols2

print('b_row:')
print(b_row)
print('b_col:')
print(b_col)

输出:

a:
[[0 1 2]
 [0 1 3]
 [2 3 4]
 [4 5 6]]
b_row:
[[ 0  1]
 [ 0  1]
 [ 0  2]
 [ 1  2]
 [ 2  3]
 [ 3 -1]
 [ 3 -1]]
b_col:
[[ 0  0]
 [ 1  1]
 [ 2  0]
 [ 2  1]
 [ 2  0]
 [ 1 -1]
 [ 2 -1]]

编辑:

IPython中的一个小基准用于比较。正如@eozd所示,由于{(1)}在O(n)中运行,算法复杂度原则上更高,但对于实际大小,矢量化解决方案似乎仍然快得多:

np.unique

答案 1 :(得分:1)

这是一个仅使用一个argsort和一系列轻量级索引操作的解决方案:

def grp_start_len(a):
    # https://stackoverflow.com/a/50394587/353337
    m = numpy.concatenate([[True], a[:-1] != a[1:], [True]])
    idx = numpy.flatnonzero(m)
    return idx[:-1], numpy.diff(idx)


a_flat = a.flatten()

idx_sort = numpy.argsort(a_flat)

idx_start, count = grp_start_len(a_flat[idx_sort])

res1 = idx_sort[idx_start[count==1]][:, numpy.newaxis]
res1 // 3
res1 % 3

idx = idx_start[count==2]
res2 = numpy.column_stack([idx_sort[idx], idx_sort[idx + 1]])
res2 // 3
res2 % 3

基本思想是,在a被展平和排序后,所有信息都可以从a_flat_sorted中的起始索引和整数块的长度中提取。