将logit与标签进行比较的最佳方法是什么?

时间:2019-02-18 23:57:45

标签: python tensorflow

我正在将logit与循环中的标签进行比较:

  for r in range(logits.shape[0]):
    if labels[r] == np.argmax(logits[r]):
      guessed += 1.0

其中labels是一维整数标签数组,logits是2D数组,第二维是标签的概率。

以上解决方案是效率不高的Python循环。应该有一个常用的numpytensorflow快捷方式来做到这一点。你能建议一个吗?

1 个答案:

答案 0 :(得分:1)

您可以通过np.argmax(logits,axis=1)一次获得所有的最大值。以下可以替换for循环以获取猜测的总数:

guessed = np.sum(labels == np.argmax(logits,axis=1))