查找哪个数组为分类提供了最大值

时间:2010-11-08 23:36:59

标签: python numpy

给定2个2x2 Numpy数组,每个元素的值介于0和1之间,我想找到具有最大值的2中的一个数组,并按元素进行比较。例如,给定:

A = [[.6 .2] [.3 .4]]B = [[.4 .5] [.7 .1]],我想要回复一下:[[A B] [A B]]。理想情况下,输出将是一些数字[[1 2] [1 2]],其中1表示A,2表示B.这样,如果我比较10个数组,输出将具有1到10之间的整数作为每个元素,可以很容易地绘制在pcolor图中。

如果我只是将这些数组合并为一个2x2x2并执行np.amax(combined_array,axis = 0),我会得到最大值,但不知道它来自哪个数组。

所有这一切的目的是每个数组代表一个类别,并包含该类别发生的概率。我想知道每个元素位置[0] [0],[0] [1],[1] [0]和[1] [1],哪个类别是在该位置发生的最可能的类别。

1 个答案:

答案 0 :(得分:3)

如果你有一个10个2x2矩阵的组合数组,比如

随机生成的那个
a = numpy.random.randn(10, 2, 2)

你可以通过

获得所需的指数
a.argmax(axis=0)