解释numpy.where结果

时间:2016-02-14 18:07:19

标签: python arrays numpy

我对numpy.where的结果意味着什么感到困惑,以及如何使用它来索引数组。

看看下面的代码示例:

import numpy as np
a = np.random.randn(10,10,2)
indices = np.where(a[:,:,0] > 0.5)

我希望indices数组为2-dim并包含条件为true的索引。我们可以通过

看到
indices = np.array(indices)
indices.shape  # (2,120)

所以看起来indices正在对某种扁平化阵列进行操作,但我无法弄清楚具体如何。更令人困惑的是,

a.shape  # (20,20,2)
a[indices].shape # (2,120,20,2)

问题:

如何使用np.where的输出索引我的数组实际上增长数组的大小?这里发生了什么?

1 个答案:

答案 0 :(得分:4)

您的索引基于错误的假设:np.where返回一些可以立即用于高级索引的内容(它是np.ndarrays的元组)。但是你把它转换成一个numpy数组(所以它现在是np.ndarray的{​​{1}}。

所以

np.ndarrays

为您提供import numpy as np a = np.random.randn(10,10,2) indices = np.where(a[:,:,0] > 0.5) a[:,:,0][indices] # If you do a[indices] the result would be different, I'm not sure what # you intended. 找到的元素。如果您将np.where转换为indices,则会触发另一种形式的索引(see this section of the numpy docs),并且docs中的警告消息非常重要。这就是它增加数组总大小的原因。

有关np.array含义的一些其他信息:您将获得包含np.where数组的元组。 n是输入数组的维数。因此,满足条件的第一个元素具有索引n而不是[0][0], [1][0], ... [n][0]。因此,在您的情况下,您有(2,120)意味着您有2个维度和120个找到的点。