使用NumPy缩短一个衬垫以获取索引

时间:2015-02-05 22:39:43

标签: numpy

我有以下代码行:

idxs = [i for i,x in enumerate(labels) if x==lbl]
  • 标签是一个numpy int of int
  • lbl是一个int

idxs = indices s.t.标签的相应元素具有值lbl

问题:是否有较短的单行?

谢谢!

2 个答案:

答案 0 :(得分:3)

您可以使用numpy.where的单参数形式:

idxs = np.where(labels == lbl)[0]

或等效地使用numpy.nonzero

idxs = np.nonzero(labels == lbl)[0]

或者,为了更好的可读性(谢谢,乔!),

idxs = np.flatnonzero(labels == lbl)

例如,

In [332]: np.random.seed(1)

In [333]: labels = np.random.randint(5, size=10)

In [334]: labels
Out[334]: array([3, 4, 0, 1, 3, 0, 0, 1, 4, 4])

In [335]: [i for i,x in enumerate(labels) if x==lbl]
Out[335]: [3, 7]

In [336]: np.where(labels == lbl)[0]
Out[336]: array([3, 7])

使用np.where比大型数组的列表理解要快得多:

In [339]: labels = np.tile(labels, 1000)

In [340]: labels.shape
Out[340]: (10000,)

In [341]: %timeit np.where(labels == lbl)[0]
10000 loops, best of 3: 45.9 µs per loop

In [342]: %timeit [i for i,x in enumerate(labels) if x==lbl]
100 loops, best of 3: 5.31 ms per loop

In [343]: 5310/45.9
Out[343]: 115.68627450980392

答案 1 :(得分:1)

我没有代表。评论答案......但是,请记住,当使用numpy.where时,“标签”必须是一个numpy数组。

Codelifting unutbu的回答:

idxs = np.where(np.array(labels) == lbl)[0]

只是说清楚:正确答案是由unutbu做出的。