行numpy非零行?

时间:2014-07-11 17:45:08

标签: numpy

我有一个2d布尔数组,我试图从中提取真值的索引。 Numpy的非零函数将我的2d数组分解为x和y的位置列表,这是有问题的。

是否可以在保留行顺序的同时找到true元素的列索引?

列中的每个真值都在同一行中相互关联,因此将它们分成(行索引,列索引)对是没有用的。这可能吗?

我在想,也许np.apply_along_axis可能有用。

2 个答案:

答案 0 :(得分:8)

我不太明白你想要什么(也许一个例子会有所帮助),但两个猜测:

如果你想查看一行中是否有任何特鲁斯,那么:

np.any(a, axis=1)

将为您提供每行的布尔值数组。

或者如果你想逐行获取True的索引,那么

testarray = np.array([
    [True, False, True],
    [True, True, False],
    [False, False, False],
    [False, True, False]])

collists = [ np.nonzero(t)[0] for t in testarray ]

这给出了:

>>> collists
[array([0, 2]), array([0, 1]), array([], dtype=int64), array([1])]

如果你想知道第3行True列的索引,那么:

>>> collists[3]
array([1])  

没有基于数组的纯粹方法来实现这一点,因为每行上的项目数量各不相同。这就是我们需要这些清单的原因。另一方面,性能很不错,我尝试使用10000 x 10000随机布尔数组,完成任务需要774毫秒。

答案 1 :(得分:0)

您可以使用熊猫来进行此操作。下面的示例使用矢量化运算为每一行提供非零元素的索引-输入数据中的每列数为一个。

import numpy as np
import pandas as pd

np.random.seed(0)

size = int(1e4), 5
d1 = pd.DataFrame(np.random.randint(5, size=size))

print(d1)

nz = pd.Series(np.count_nonzero(d1, axis=1))

max_nz = nz.max()

dfs = []
for _nz, nzdf in d1.groupby(nz, sort=False):

    nz = np.apply_along_axis(lambda r: np.nonzero(r)[0], 1, nzdf)

    mock_result = pd.DataFrame(np.ones(shape=(len(nzdf), max_nz)) - 2, index=nzdf.index)

    for i in range(nz.shape[1]):
        mock_result.iloc[:, i] = nz[:, i]

    dfs.append(mock_result)

result = pd.concat(dfs).sort_index()
print(result)

它将打印

      0  1  2  3  4
0     4  0  3  3  3
1     1  3  2  4  0
2     0  4  2  1  0
3     1  1  0  1  4
4     3  0  3  0  2
...  .. .. .. .. ..
9995  0  2  3  1  3
9996  3  3  2  3  1
9997  4  0  3  4  3
9998  4  2  4  0  0
9999  0  3  4  1  2

[10000 rows x 5 columns]
        0    1    2    3    4
0     0.0  2.0  3.0  4.0 -1.0
1     0.0  1.0  2.0  3.0 -1.0
2     1.0  2.0  3.0 -1.0 -1.0
3     0.0  1.0  3.0  4.0 -1.0
4     0.0  2.0  4.0 -1.0 -1.0
...   ...  ...  ...  ...  ...
9995  1.0  2.0  3.0  4.0 -1.0
9996  0.0  1.0  2.0  3.0  4.0
9997  0.0  2.0  3.0  4.0 -1.0
9998  0.0  1.0  2.0 -1.0 -1.0
9999  1.0  2.0  3.0  4.0 -1.0

[10000 rows x 5 columns]

使用这种技术,我可以大大减少基于行的scipy.stats.rankdata版本的运行时间。