使用转移学习进行物体检测的训练顺序模型期间,训练和验证准确性保持恒定

时间:2019-02-03 09:00:21

标签: python tensorflow keras

我试图训练一个用于二分类的模型,但是在3个时期之后,验证准确性仍然保持在0.5000。

两个类别的数据集均包含1512张图像,因此总共有3024张图像。我使用keras使用VGG16模型进行迁移学习。

from keras import models
from keras import layers
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint

# Stop when we stop learning.
early_stopper = EarlyStopping(patience=10)
# tensorboard
tensorboard = TensorBoard(log_dir='./logs')
model_check_point = ModelCheckpoint('vgg16.h5', save_best_only=True)
train_dir = 'dataset\\training_set'
validation_dir = 'dataset\\validation_set'
image_size_x = 360
image_size_y = 180
#Load the VGG model
vgg_conv = VGG16(weights='imagenet', include_top=False, input_shape=(image_size_x, image_size_y, 3))
# Freeze the layers except the last 4 layers
for layer in vgg_conv.layers[:-4]:
    layer.trainable = False

# Create the model
model = models.Sequential()

# Add the vgg convolutional base model
model.add(vgg_conv)

# Add new layers
model.add(layers.Flatten())
model.add(layers.Dense(1024, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))

train_datagen = ImageDataGenerator(
      rescale=1./255,
      rotation_range=20,
      width_shift_range=0.2,
      height_shift_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')

validation_datagen = ImageDataGenerator(rescale=1./255)

# Change the batchsize according to your system RAM
train_batchsize = 4
val_batchsize = 4

train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(image_size_x, image_size_y),
        batch_size=train_batchsize,
        class_mode='binary')

validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=(image_size_x, image_size_y),
        batch_size=val_batchsize,
        class_mode='binary')
# Compile the model
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.Adam(lr=0.1),
              metrics=['acc'])
# Train the model
history = model.fit_generator(
      train_generator,
      steps_per_epoch=train_generator.samples/train_generator.batch_size ,
      epochs=7,
      validation_data=validation_generator,
      validation_steps=validation_generator.samples/validation_generator.batch_size,
      verbose=1,
      callbacks=[tensorboard, early_stopper, model_check_point])

结果是

Using TensorFlow backend.
2019-02-03 14:46:06.520723: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
Found 2416 images belonging to 2 classes.
Found 608 images belonging to 2 classes.
Epoch 1/7
604/604 [==============================] - 2046s 3s/step - loss: 8.0208 - acc: 0.5008 - val_loss: 8.0590 - val_acc: 0.5000
Epoch 2/7
604/604 [==============================] - 1798s 3s/step - loss: 8.0055 - acc: 0.5033 - val_loss: 8.0590 - val_acc: 0.5000
Epoch 3/7
604/604 [==============================] - 2500s 4s/step - loss: 8.0054 - acc: 0.5033 - val_loss: 8.0590 - val_acc: 0.5000

我尝试使用另一个优化器(RMSprop)将学习率从0.0001提高到0.01,但是验证准确性仍然保持在0.5000。 这是模型摘要

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
inception_v3 (Model)         (None, 9, 4, 2048)        21802784
=================================================================
Total params: 21,802,784
Trainable params: 0
Non-trainable params: 21,802,784
_________________________________________________________________
None

1 个答案:

答案 0 :(得分:0)

VGG期望的预处理与rescale = 1/255.不同。您可以从同一模块导入适当的预处理输入函数,并将其提供给ImageDataGenerator

from keras.applications.vgg16 import preprocess_input
train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

请注意,您应省略rescale参数,因为该参数由预处理功能处理。另请参见this answer

如果那没有帮助,我建议您尝试冻结整个卷积基础,只训练顶级分类器。卷积基数包含预训练的权重,而顶部分类器中的权重是随机初始化的。这可能会导致对预训练权重进行较大的梯度更新,从而使它们失去作用。此外,您没有太多数据,因此冻结卷积堆栈是一个好主意,无论它是否在此处引起问题。有关this answer遇到VGG问题的人的示例,请参见。

除此之外,您的模型定义看起来正确。