如何过滤数组以仅保留其重复元素?

时间:2017-12-24 01:02:51

标签: python arrays numpy

我有一个NumPy数组。我想从中创建一个新的,只包含重复的元素。例如,在数组看起来像

之前
[[  3.   0.   1.   0.  12.   1.]
 [ 14.   0.   2.   2.   0.   3.]
 [  3.   0.   1.   2.   0.   3.]
 [ 12.   0.  14.   0.  12.   1.]
 [ 14.   0.   2.  12.   0.  14.]
 [ 15.   4.  13.  13.  14.  15.]
 [ 14.   2.  15.  13.  14.  15.]]

并且在操作之后我希望它看起来像

[[ 1.   0.  ]
 [ 0.   2.  ]
 [  3.  0.  ]
 [ 12.  0.  ]
 [ 14.  0.  ]
 [ 15.  13. ]
 [ 14.  15. ]]

现在,我会使用for循环来做,但也许你们中的某个人知道更顺畅,更快捷的方式。

1 个答案:

答案 0 :(得分:1)

您无法在一个简单的步骤中执行此操作,因为重复项的长度可能会在一行之间发生变化。

我建议你做下面的事。

定义一个函数来查找重复项:

def dups(a):
    uniques, counts = np.unique(a, return_counts=True)
    return uniques[np.where(counts > 1)]

然后将其应用于数组的每一行:

ans = [dups(row) for row in arr]

对于所有行具有相同重复次数的情况,您可以使用ans制作一个numpy数组:

ans = np.stack(ans)

对于您的示例案例,它会打印:

[[  0.   1.]
 [  0.   2.]
 [  0.   3.]
 [  0.  12.]
 [  0.  14.]
 [ 13.  15.]
 [ 14.  15.]]