我正在尝试为图像分类建立模型,但无法弄清楚如何按照本指南中的预测类(和概率)绘制验证图像:
https://www.tensorflow.org/tutorials/images/hub_with_keras#check_the_predictions
当我使用ImageDataGenerator时,我可以获取有关预测类别和概率的信息,但我无法获取图像本身。
plt.figure(figsize=(10,9))
for n in range(30):
plt.subplot(6,5,n+1)
plt.imshow(image_batch[n])
plt.title(labels_batch[n])
plt.axis('off')
_ = plt.suptitle("Model predictions")
我收到ValueError:无法将输入数组从形状(400,80,80,3)广播到形状(400)
我知道问题出在-> image_batch [n]
请告诉我如何将有关图像的信息传递给imshow()。
P.S。 我使用的图像生成器代码:
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_batch = img_generator.flow_from_directory(
path,
target_size=(size, size),
color_mode='rgb',
class_mode='categorical',
batch_size=batch_size
)