从行中元素满足条件的数组中获取索引

时间:2018-06-29 12:29:33

标签: python arrays list numpy conditional-statements

我想找到满足条件的数组的索引。

我有一个numpy.ndarray B: (m =行数= 8, 3列)

array([[ 0.,  0.,  0.],
   [ 0.,  0.,  0.],
   [ 0.,  0.,  1.],
   [ 0.,  1.,  1.],
   [ 0.,  1.,  0.],
   [ 1.,  1.,  0.],
   [ 1.,  1.,  1.],
   [ 1.,  0.,  1.],
   [ 1.,  0.,  0.]])

对于每一列,我想找到元素满足以下条件的行的索引: 对于列中的col: 对于所有行= 1,2,..,m-1,B(row,col)= 1和B(row + 1,col)= 1;对于行= 0和m,B(row,col)= 1。

所以期望的结果是:

Sets1=[[5, 6, 7, 8], [3, 4, 5], [2, 6]]

到目前为止,我已经尝试过:

Sets1=[]
for j in range(3):
    Sets1.append([i for i, x in enumerate(K[1:-1]) if B[x,j]==1 and B[x+1,j]==1])

但这只是条件的第一部分,并给出以下错误输出,因为它采用了新集合的索引。因此,它实际上应为加1。

Sets1= [[4, 5, 6], [2, 3, 4], [1, 5]]

条件的第二部分也适用于索引0和m。还没包括。.

编辑:我通过写i + 1固定了加号1部分,并通过添加以下if语句尝试了条件的第二部分:

Sets1=[]
for j in range(3):
    Sets1.append([i+1 for i, xp in enumerate(K[1:-1]) if B[xp,j]==1 and B[xp+1,j]==1])
    if B[0,j]==1: Sets1[j].append(0)
    if B[(x-1),j]==1: Sets1[j].append(x-1)

这确实起作用,因为它提供了以下输出:

Sets1= [[5, 6, 7, 8], [3, 4, 5], [2, 6]]

所以现在我只需要在条件的第一部分(在if语句之前)向列表的元素添加+1 ...

我非常感谢您的帮助!

2 个答案:

答案 0 :(得分:1)

有一种使用numpy的矢量化方法。首先,我们为a等于1的地方创建一个掩码:

mask=a.T==1.0

第二个遮罩将判断下一个元素是否也等于1。由于我们只希望满足两个条件的元素,所以我们将两个遮罩相乘:

mask_next=np.ones_like(mask).astype(bool)
mask_next[:,:-1]=mask[:,1:]
fin_mask=mask*mask_next

获取索引:

idx=np.where(fin_mask)

第一个索引将告诉我们将idx行拆分到哪里:

split=np.where(np.diff(idx[0]))[0]+1
out=np.split(idx[1],split)

out产生期望的结果。如果我理解正确,您希望元素的索引等于1,而下一个(列向)元素也是1?

答案 1 :(得分:1)

您可以使用布尔掩码和 np.where

来完成此操作

首先,面具:

c1 = (x==1)
c2 = (np.roll(x, -1, axis=0) != 0)
c3 = (x[-1] == 1)

c1 & (c2 | c3)

array([[False, False, False],
       [False, False, False],
       [False, False,  True],
       [False,  True, False],
       [False,  True, False],
       [ True,  True, False],
       [ True, False,  True],
       [ True, False, False],
       [ True, False, False]])

np.where 获取索引:

>>> np.where(c1 & (c2 | c3))

(array([2, 3, 4, 5, 5, 6, 6, 7, 8], dtype=int64),
 array([2, 1, 1, 0, 1, 0, 2, 0, 0], dtype=int64))

如果您确实希望将结果作为输出中的列表:

s = np.where(c1 & (c2 | c3))
[list(s[0][s[1]==i]) for i in range(x.shape[1])]

# [[5, 6, 7, 8], [3, 4, 5], [2, 6]]
相关问题