Keras:从ImageDataGenerator或predict_generator获取True标签(y_test)

时间:2017-07-31 10:43:06

标签: keras

我正在使用ImageDataGenerator().flow_from_directory(...)从目录生成批量数据。

模型构建成功后,我想获得一个True和Predicted类标签的两列数组。使用model.predict_generator(validation_generator, steps=NUM_STEPS),我可以获得一系列预测类。是否可以让predict_generator输出相应的True类标签?

要添加:validation_generator.classes确实打印True标签,但按照从目录中检索它们的顺序,它不会通过扩充来考虑批处理或样本扩展。

2 个答案:

答案 0 :(得分:3)

您可以通过以下方式获取预测标签:

 y_pred = numpy.rint(predictions)

你可以通过以下方式获得真正的标签:

y_true = validation_generator.classes

在此之前,您应该在验证生成器中设置shuffle=False

最后,您可以通过

打印混淆矩阵

print confusion_matrix(y_true, y_pred)

答案 1 :(得分:0)

还有另一种稍微“hackier”的方式来检索真实标签。 请注意,这种方法可以在您的生成器中设置 shuffle=True 时处理(通常来说,混洗数据是个好主意 - 如果您在存储数据的位置手动执行此操作,或者通过生成器,这可能更容易)。不过,您将需要您的模型来使用这种方法。

# Create lists for storing the predictions and labels
predictions = []
labels = []

# Get the total number of labels in generator 
# (i.e. the length of the dataset where the generator generates batches from)
n = len(generator.labels)

# Loop over the generator
for data, label in generator:
    # Make predictions on data using the model. Store the results.
    predictions.extend(model.predict(data).flatten())

    # Store corresponding labels
    labels.extend(label)

    # We have to break out from the generator when we've processed 
    # the entire once (otherwise we would end up with duplicates). 
    if (len(label) < generator.batch_size) and (len(predictions) == n):
        break

您的预测和相应的标签现在应该分别存储在 predictionslabels 中。

最后,请记住,我们不应在验证和测试集/生成器上添加数据增强。

相关问题