Tensorflow:批次中2x2矩阵的最大值指数

时间:2018-03-14 11:19:29

标签: python matrix tensorflow max argmax

如果我有一批矩阵使得我的矩阵是形状(?,600,600),我将如何检索批次中每个矩阵中最大值的行和col索引?这样我的行和列返回矩阵都是形状(?)(行返回矩阵具有批处理中每个示例的最大行的索引,并且类似于col返回矩阵)。

谢谢!

1 个答案:

答案 0 :(得分:1)

你可以重塑+ argmax。类似的东西:

x = tf.reshape(matrix, [tf.shape(matrix, 0), -1])
indices = tf.argmax(x, axis=1)  # this gives you indices from 0 to 600^2
col_indices = indices / 600
row_indices = indices % 600
final_indices = tf.transpose(tf.stack(col_indices, row_indices))