我有一个numpy
数组的索引列表,但是使用它们时并没有达到想要的结果。
n = 3
a = np.array([[8, 1, 6],
[3, 5, 7],
[4, 9, 2]])
np.random.seed(7)
idx = np.random.choice(np.arange(n), size=(n, n-1))
# array([[0, 1],
# [2, 0],
# [1, 2]])
在这种情况下,我想要:
我的列表包含n sublists
,并且所有这些列表都具有相同的长度。
我希望每个子列表仅使用一次,而不是用于所有轴。
# Wanted result
# b = array[[8, 1],
# [7, 3],
# [9, 2]])
我可以实现这一目标,但是在进行大量重复和整形时似乎很麻烦。
# Possibility 1
b = a[:, idx]
# array([[[8, 1], | [[3, 5], | [[4, 9],
# [6, 8], | [7, 3], | [2, 4],
# [1, 6]], | [5, 7]], | [9, 2]])
b = b[np.arange(n), np.arange(n), :]
# Possibility 2
b = a[np.repeat(range(n), n-1), idx.ravel()]
# array([8, 1, 7, 3, 9, 2])
b = b.reshape(n, n-1)
有更简单的方法吗?
答案 0 :(得分:2)
您可以在此处使用np.take_along_axis
np.take_along_axis(a, idx, 1)
array([[8, 1],
[7, 3],
[9, 2]])
或使用broadcasting
:
a[np.arange(a.shape[0])[:,None], idx]
array([[8, 1],
[7, 3],
[9, 2]])
请注意,您在此处使用integer array indexing时,需要指定要使用idx
进行索引的轴和行。