Keras conv2d输入形状错误

时间:2018-03-07 05:32:14

标签: python-3.x keras convolution mnist

无法解决keras输入形状错误的问题。如何在conv2d层中实际指定尺寸尚不清楚。试过不同的方法,仍然无法让它发挥作用。下面提到的是我尝试使用keras为mnist数据集实现conv网络的代码。该阵列最初为1x784。但我把它重新塑造成28x28,它仍然没有用。任何人都可以告诉我,我做错了什么。谢谢!

batch_size = 32
epochs = 20
number_of_classes = 10

def build_brain():
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=(28,28,1)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax'))
    model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
    return model

def shared_dataset(data_xy):
    data_x, data_y = data_xy
    shared_x = np.asarray(data_x)
    shared_y = np.asarray(data_y)
    return shared_x, shared_y

f = gzip.open('mnist.pkl.gz', 'rb')
train_set, valid_set, test_set = pickle.load(f,encoding='latin1')
f.close()

test_set_x, test_set_y = shared_dataset(test_set)
valid_set_x, valid_set_y = shared_dataset(valid_set)
train_set_x, train_set_y = shared_dataset(train_set)


train_set_x = train_set_x.reshape(-1,28,28)
valid_set_x = valid_set_x.reshape(-1,28,28)
test_set_x = test_set_x.reshape(-1,28,28)


brain = build_brain()
asd = brain.fit(train_set_x, train_set_y, epochs=30, validation_data = (valid_set_x, valid_set_y), batch_size=32)
score = brain.evaluate(test_set_x, test_set_y, batch_size = 32)
print('Test score:', score[0])
print('Test accuracy:', score[1])

1 个答案:

答案 0 :(得分:0)

这是你的问题:

train_set_x = train_set_x.reshape(-1,28,28)
valid_set_x = valid_set_x.reshape(-1,28,28)
test_set_x = test_set_x.reshape(-1,28,28)

网络期望训练输入具有形状(样本,28,28,1),因为您使用input_shape =(28,28,1)。所以正确的重塑是:

train_set_x = train_set_x.reshape(-1, 28, 28, 1)
valid_set_x = valid_set_x.reshape(-1, 28, 28, 1)
test_set_x = test_set_x.reshape(-1, 28, 28, 1)