# Plot setup and a helper (Display_images) for showing a grid of images
import matplotlib.pyplot as plt
%matplotlib inline
# Default size, in inches, for every figure created after this point.
plt.rcParams.update({'figure.figsize': (10.0, 10.0)})
def Display_images(title, img, labels, nrows, ncols):
    """Plot a grid of images, one per subplot, each titled with its label.

    Args:
        title: Prefix for every subplot title; ``str(labels[i])`` is appended.
        img: Indexable collection of images. ``img[i]`` is either a colour
            image whose last axis has size 3 (RGB) or 4 (RGBA), or a
            single-channel image: 2-D, or 3-D with a trailing channel axis,
            in which case channel 0 is shown.
        labels: Indexable collection; ``labels[i]`` titles ``img[i]``.
        nrows, ncols: Grid shape passed to ``plt.subplots``.

    Single-channel images are drawn with the ``'gray'`` colormap. Grid cells
    beyond ``len(img)`` are left blank instead of raising ``IndexError``.
    """
    # squeeze=False always returns a 2-D array of Axes, so ``.flat`` also
    # works for a 1x1 grid (where plt.subplots would return a bare Axes).
    fig, axes = plt.subplots(nrows, ncols, squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.set_xticks([])
        ax.set_yticks([])
        if i >= len(img):
            # Fewer images than grid cells: hide the unused subplot.
            ax.axis('off')
            continue
        image = img[i]
        if image.ndim == 3 and image.shape[-1] in (3, 4):
            # Colour image: imshow renders the channels directly.
            ax.imshow(image)
        else:
            # Single channel: drop a trailing channel axis if present.
            if image.ndim == 3:
                image = image[:, :, 0]
            # Without an explicit cmap, imshow maps 2-D data through the
            # default colormap (rcParams['image.cmap']), so a grayscale
            # image would be rendered in colour.
            ax.imshow(image, cmap='gray')
        ax.set_title(title + str(labels[i]))
import numpy as np
def rgb2gray(images):
    """Convert RGB image(s) to single-channel luma, keeping a channel axis.

    Computes 0.2990*R + 0.5870*G + 0.1140*B (ITU-R BT.601 luma weights)
    over the last axis, then re-inserts a length-1 axis at the end. An input
    of shape (..., 3) becomes (..., 1). This works for a batch (N, H, W, 3)
    as well as for a single image (H, W, 3).

    The result dtype follows NumPy promotion with the float weights (float64
    for integer input); cast afterwards if a different dtype is needed.
    """
    # axis=-1 instead of a fixed 3 keeps the output rank equal to the input
    # rank for any number of leading dimensions (identical for 4-D input).
    return np.expand_dims(np.dot(images, [0.2990, 0.5870, 0.1140]), axis=-1)
# Convert both splits to grayscale: float32 with a trailing channel axis of size 1.
x_train_grayscale = rgb2gray(x_train).astype(np.float32)
x_test_grayscale = rgb2gray(x_test).astype(np.float32)
print("Training Set", x_train_grayscale.shape)
print("Test Set", x_test_grayscale.shape)
print(x_train_grayscale[0].shape)
# Show the first 10 images of each split in a 1x10 grid.
Display_images('Train', x_train_grayscale, y_train, 1, 10)
# The "Test" row must use the test split, not the training arrays again.
# NOTE(review): assumes y_test is defined alongside x_test (not visible here) — confirm.
Display_images('Test', x_test_grayscale, y_test, 1, 10)
**我正在尝试显示 10 张图像。在上面的代码中，当我尝试显示灰度图像时，得到的却是彩色图像。如何解决这个问题？**