Tensorflow:如何创建混淆矩阵

时间:2017-11-14 07:45:01

标签: python python-2.7 tensorflow tensorboard

我是tensorflow的新手,我使用了本教程:

https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/

我在包含3个标签的新数据集上训练了相同的模型。我正在尝试创建混淆矩阵。

tf.confusion_matrix函数非常混乱。

有人可以帮助使用相同的代码示例。

1 个答案:

答案 0 :(得分:2)

你有3个标签(比如0,1,2)。假设您有一个大小为10的测试集,并且您获得以下张量: 真相:[0,0,0,0,1,1,2,2,2,2] 预测:[2,0,0,1,1,1,2,1,2,2] 然后就可以了,

>>> import tensorflow as tf
>>> truth = [0,0,0,0,1,1,2,2,2,2]
>>> prediction = [2,0,0,1,1,1,2,1,2,2]
>>> cm = tf.contrib.metrics.confusion_matrix(truth, prediction)
>>> with tf.Session() as sess:
...     sess.run(cm)
... 
array([[2, 1, 1],
       [0, 2, 0],
       [0, 1, 3]], dtype=int32)

请注意以下事项: 结果是3x3矩阵。第一行表示正确地预测了2次标签0,一旦被误认为是标签1,一旦被误认为是标签2。