model.fit_generator()形状错误

时间:2016-11-13 13:32:31

标签: python keras

import os
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense

img_width, img_height = 64, 64

train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
nb_train_samples = sum([len(files) for files in os.walk(train_data_dir)])
nb_validation_samples = sum([len(files) for files in os.walk(validation_data_dir)])
nb_epoch = 10


model = Sequential()
model.add(Dense(4096, input_dim = 4096, init='normal', activation='relu'))
model.add(Dense(4,init='normal',activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])


train_datagen = ImageDataGenerator(
        rescale=1./255,
        )


test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        color_mode="grayscale",
        target_size=(img_width, img_height),
        batch_size=1,
        class_mode=None)

validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        color_mode="grayscale",
        target_size=(img_width, img_height),
        batch_size=1,
        class_mode=None)

model.fit_generator(
        train_generator,
        samples_per_epoch=nb_train_samples,
        nb_epoch=nb_epoch,
        validation_data=validation_generator,
        nb_val_samples=nb_validation_samples)

一切正常,直到上面编码中的model.fit_generator()。然后它会弹出如下所示的错误。

Traceback (most recent call last):
  File "C:/Users/Sam/PycharmProjects/MLP/Testing Code without CNN.py", line 55, in <module>
    nb_val_samples=nb_validation_samples)
  File "C:\Python27\lib\site-packages\keras\models.py", line 874, in fit_generator
    pickle_safe=pickle_safe)
  File "C:\Python27\lib\site-packages\keras\engine\training.py", line 1427, in fit_generator
    'or (x, y). Found: ' + str(generator_output))
Exception: output of generator should be a tuple (x, y, sample_weight) or (x, y). Found: [[[[ 0.19215688]

2 个答案:

答案 0 :(得分:1)

问题应该由数据维度不匹配引起。 ImageDataGenerator实际上加载了图像文件,并以(num_image_channel,image_height,image_width)的形状放入numpy数组中。但是,您的第一层是密集连接的层,它正在寻找1D阵列形状的输入数据,或具有多个样本的2D阵列。所以基本上你错过了输入层,它输入的形状正确。

更改以下代码行

model.add(Dense(4096, input_dim = 4096, init='normal', activation='relu'))

model.add(Reshape((img_width*img_height*img_channel), input_shape=(img_channel, img_height, img_width)))
model.add(Dense(4096, init='normal', activation='relu'))

您必须定义img_channel,这是图片中的频道数。上述代码还假定您使用dim_ordering th。如果您使用tf输入维度排序,则必须将输入重塑图层更改为

model.add(Reshape((img_width*img_height*img_channel), input_shape=(img_height, img_width, img_channel)))

---老答案 -

您可能已将培训数据和验证数据放入trainvalidation下的子文件夹中,而Keras不支持这些子文件夹。所有培训数据应位于一个文件夹中,验证数据应相同。

有关详细信息,请参阅this Keras tutorial

答案 1 :(得分:0)

我不是100%确定您要实现的目标,但如果您尝试对图片进行二元分类,请尝试将var resultReceived = false; async.until(function(){ console.log("Checking result : "+resultReceived); return resultReceived; }, function(callback){ try{ //do something resultReceived = true; }catch(err){ resultReceived = false; } }, function(result){ console.log("===================="); console.log(result); callback(result); }); 设置为class_mode。来自documentation

  

class_mode:“分类”,“二进制”,“稀疏”或“无”之一。默认:   “类别”。确定返回的标签数组的类型:   “分类”将是2D单热编码标签,“二进制”将是1D   二进制标签,“稀疏”将是1D整数标签。

错误消息有点令人困惑,但如果你看一下source code,它会变得更加清晰:

binary