除了功能和标签

时间:2019-07-16 19:56:23

标签: python-3.x tensorflow

我正在使用tf.contrib.data.make_csv_dataset将CSV数据转换为可以很好地为CSV中的指定列提供功能和标签的数据集。

如何从CSV指定额外的列,这些列可能希望在模型测试期间可用,但不能用于训练和模型计算?例如,在评估测试准确性时,我想知道CSV数据集中哪些特定行的预测有误。有没有一种方法可以提供一个额外的参数,我可以利用它来找出模型究竟出了什么问题?

现在代码看起来像这样(基于Tensorflow示例页面):

test_dataset = tf.contrib.data.make_csv_dataset(
    CSV_file,
    BATCH_TEST_SIZE,
    column_names=column_names,
    select_columns=column_select,
    label_name=label_name,
    num_epochs=1,
    shuffle=False)\
        .map(pack_features_vector)

然后在测试过程中,代码执行此操作:

for (x, y) in test_dataset:
    logits = model(x)
    prediction = tf.argmax(logits, axis=1, output_type=tf.int32)

    print('Act\t{}\nPred\t{}\n\n'.format(y, prediction))

由于生成器功能仅提供xy值,我该如何特别针对原始CSV文件中的哪一行进行预测?

我该怎么做

for (x, y, z) in test_dataset:
print(z[x])

z是该附加列,然后我可以对其进行检查?

1 个答案:

答案 0 :(得分:0)

感谢您澄清您的问题。我相信您要查找的答案是为了查看错误地预测了哪些行,方法是在keras中使用model.predict_classes()。以下代码应为您提供模型所猜测的数组:

predictionArr = model.predict_classes(testData).reshape(-1)

这将为您提供测试数据集长度的数组,您可以在它们之间进行比较。

希望这对您有所帮助,并回答您的问题!