在ndarray的每个切片中获取最小值的索引

时间:2018-10-19 08:27:06

标签: python numpy-ndarray

我试图做一些应该很简单并且可以在 for-loop 中完成的事情,但是我试图避免这种情况。

我想沿numpy.ndarray a 的某个轴获取每个切片中最小值的索引。我对索引比值本身更感兴趣。我使用索引从另一个2D数组中获取一个值,该2D数组的形状等于 a 的前两个维度。

这是使用 for循环的天真的实现:

a = np.random.randint(0, 10, 60).reshape(3, 4, 5)
print(a)
for i in range(a.shape[-1]):
    idx = a[..., i].argmin()
    print('Slice:', i, '| Index:', idx, '| min value:',
          a[..., i].flat[idx])

出局:

[[[1 9 4 0 7]
  [6 3 1 6 8]
  [7 8 2 0 2]
  [8 6 1 6 5]]

 [[8 7 0 6 9]
  [7 2 6 4 5]
  [3 4 9 2 9]
  [1 4 8 0 7]]

 [[1 4 6 6 2]
  [9 9 5 6 7]
  [6 2 8 9 9]
  [3 9 8 5 4]]]
Slice: 0 | Index: 0 | min value: 1
Slice: 1 | Index: 5 | min value: 2
Slice: 2 | Index: 4 | min value: 0
Slice: 3 | Index: 0 | min value: 0
Slice: 4 | Index: 2 | min value: 2

我意识到我可以将axis关键字参数传递给argmin,但这不会产生我想要的结果。

1 个答案:

答案 0 :(得分:2)

对于问题中给出的特定情况,您可以调整数组的形状,然后使用argmin

>>> import numpy as np
>>> a = np.array([[[1, 9, 4, 0, 7],
...   [6, 3, 1, 6, 8],
...   [7, 8, 2, 0, 2],
...   [8, 6, 1, 6, 5]],
...
...  [[8, 7, 0, 6, 9],
...   [7, 2, 6, 4, 5],
...   [3, 4, 9, 2, 9],
...   [1, 4, 8, 0, 7]],
...
...  [[1, 4, 6, 6, 2],
...   [9, 9, 5, 6, 7],
...   [6, 2, 8, 9, 9],
...   [3, 9, 8, 5, 4]]])
>>> a.reshape(-1, a.shape[2]).min(axis=0)
array([1, 2, 0, 0, 2])
>>> a.reshape(-1, a.shape[2]).argmin(axis=0)
array([0, 5, 4, 0, 2])
>>>

shape[2]来自以下事实:这是维度(在这种情况下是内部维度或行),您不想要在其中计算最小值:您正在计算前两个维度的最小值。

您还需要切片编号:基本上只是元素的第二个索引。这很容易,因为那是顺序的,只是:

slices = np.arange(a.shape[2])