使用np.where在2D数组中查找元素的索引给出ValueError

时间:2019-03-25 15:40:39

标签: python numpy

我正在尝试使用np.where在数组中查找元素的索引,特别是行号

我有一个大小为1000 x 6的数组,称为“表”。每行的第一个元素是2 x 2的字符串数组,其余元素为0。例如。 “表格”中元素的5 x 6示例:

    [['s',' ']   0 0 0 0 0
     [' ',' ']]
    [[' ',' ']   0 0 0 0 0
     [' ','a']]
    [[' ',' ']   0 0 0 0 0
     [' ',' ']]         
    [['p',' ']   0 0 0 0 0
     [' ',' ']]
    [[' ',' ']   0 0 0 0 0
     ['b',' ']]  

2x2数组都不同,我想获取大表中包含特定2x2的数组的索引,特别是行号。

例如。说我有

    grid = [['s',' ']   
            [' ',' ']]

我希望我的代码返回[0] [0]

我已经尝试过:

    i,j = np.where(table == grid)

还有

    i,j = np.where(np.all(table == grid))

我得到以下错误:

    ValueError: not enough values to unpack (expected 2, got 1)

例如使用单个值。

    index = np.where(table == grid) 

不会导致错误,但是print(index)将输出一个空数组:

    (array([], dtype=int64),)

从关于Stack Overflow的类似问题中,我似乎无法弄清楚此错误如何适用于我,而且我已经盯着它看了很久了

任何帮助将不胜感激

2 个答案:

答案 0 :(得分:2)

设置:

b = np.array([['s','t'],['q','r']])
c = np.array([['s',' '],[' ',' ']])
a = np.array([[c,0,0,0,0,0],
              [c,0,0,0,0,0],
              [c,0,0,0,0,0],
              [c,0,0,0,0,0],
              [b,0,0,0,0,0],
              [c,0,0,0,0,0],
              [c,0,0,0,0,0],
              [c,0,0,0,0,0],
              [c,0,0,0,0,0]])

假设您对的唯一兴趣为零;编写一个函数来测试一维数组中的每个项目。并将其应用于零列

def f(args):
    return [np.all(thing==b) for thing in args]

>>> np.apply_along_axis(f,0,a[:,0])
array([False, False, False, False,  True, False, False, False, False])
>>> 

在结果上使用np.where

>>> np.where(np.apply_along_axis(f,0,a[:,0]))
(array([4], dtype=int64),)

或遵循numpy.where文档中的注释:

>>> np.asarray(np.apply_along_axis(f,0,a[:,0])).nonzero()
(array([4], dtype=int64),)

@hpaulj指出np.apply_along_axis是不必要的。所以

>>> [np.all(thing == b) for thing in a[:,0]]
[False, False, False, False, True, False, False, False, False]

>>> np.asarray([np.all(thing == b) for thing in a[:,0]]).nonzero()
(array([4], dtype=int64),)

并且没有Python迭代:

>>> (np.stack(a[:,0])==b).all(axis=(1,2))
array([False, False, False, False,  True, False, False, False, False])

>>> (np.stack(a[:,0])==b).all(axis=(1,2)).nonzero()
(array([4], dtype=int64),)

答案 1 :(得分:1)

这是使用矢量化的解决方案

a = np.array( [np.array([['s',' '],[' ',' ']])  , 0, 0, 0, 0, 0 ])
grid = np.array([['s',' '],[' ',' ']]) 

vfunc = np.vectorize(lambda x: np.all(grid == x))
np.argwhere(vfunc(a))
相关问题