Keras transfer learning with ResNet50 and fit_generator: high acc but very low val_acc

Date: 2018-11-03 18:43:42

Tags: python keras resnet transfer-learning

I am doing transfer learning with a ResNet50 model on 100,000 images across 20 scene categories (from the MIT Places365 dataset). I only train the last 160 layers (due to memory limits). The problem is that my training accuracy is high while the validation accuracy is extremely low. I suspect overfitting, but I don't know how to address it. Any advice on fixing the low val_acc would be much appreciated. My code is below:

import random
import numpy as np
import keras
from keras.layers import Flatten, Dense, Activation, Dropout
from keras.models import Model
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint

V1 = np.load("C:/Users/Desktop/numpydataKeras_20_val/imgonehot_val_500.npy")
V2 = np.load("C:/Users/Desktop/numpydataKeras_20_val/labelonehot_val_500.npy")


net = keras.applications.resnet50.ResNet50(include_top=False, weights='imagenet', input_tensor=None, input_shape=(224, 224, 3))

x = net.output
x = Flatten()(x)
x = Dense(128)(x)
x = Activation('relu')(x)
x = Dropout(0.5)(x)
output_layer = Dense(20, activation='softmax', name='softmax')(x)
net_final = Model(inputs=net.input, outputs=output_layer)

for layer in net_final.layers[:-160]:
    layer.trainable = False
for layer in net_final.layers[-160:]:
    layer.trainable = True

net_final.compile(Adam(lr=.00002122), loss='categorical_crossentropy', metrics=['accuracy'])

def data_generator():
    # 1000 .npy files of 100 images each (100,000 images total)
    arr = np.arange(1000)
    np.random.shuffle(arr)
    while True:
        for i in arr:
            # Shuffle images and labels with the same seed so pairs stay aligned
            seed01 = random.randint(0, 1000000)

            X_batch = np.load("C:/Users/Desktop/numpydataKeras/imgonehot_" + str((i+1)*100) + ".npy")
            np.random.seed(seed01)
            np.random.shuffle(X_batch)

            y_batch = np.load("C:/Users/Desktop/numpydataKeras/labelonehot_" + str((i+1)*100) + ".npy")
            np.random.seed(seed01)
            np.random.shuffle(y_batch)

            yield X_batch, y_batch

weights_file = 'C:/Users/Desktop/Transfer_learning_resnet50_fit_generator_02s.h5'
early_stopping = EarlyStopping(monitor='val_acc', patience=5, mode='auto', verbose=2)
model_checkpoint = ModelCheckpoint(weights_file, monitor='val_acc', save_best_only=True, verbose=2)
callbacks = [early_stopping, model_checkpoint]

model_fit = net_final.fit_generator(
    data_generator(),
    steps_per_epoch=1000,
    epochs=5,
    validation_data=(V1, V2),
    callbacks=callbacks,
    verbose=1,
    use_multiprocessing=False)

Here is the printed output:

Epoch 1/5
1000/1000 [==============================] - 3481s 3s/step - loss: 1.7917 - acc: 0.4757 - val_loss: 3.5872 - val_acc: 0.0560

Epoch 00001: val_acc improved from -inf to 0.05600, saving model to C:/Users/Desktop/Transfer_learning_resnet50_fit_generator_02s.h5
Epoch 2/5
1000/1000 [==============================] - 4884s 5s/step - loss: 1.1287 - acc: 0.6595 - val_loss: 4.2113 - val_acc: 0.0520

Epoch 00002: val_acc did not improve from 0.05600
Epoch 3/5
1000/1000 [==============================] - 4964s 5s/step - loss: 0.8033 - acc: 0.7464 - val_loss: 4.9595 - val_acc: 0.0520

Epoch 00003: val_acc did not improve from 0.05600
Epoch 4/5
1000/1000 [==============================] - 4961s 5s/step - loss: 0.5677 - acc: 0.8143 - val_loss: 4.5484 - val_acc: 0.0520

Epoch 00004: val_acc did not improve from 0.05600
Epoch 5/5
1000/1000 [==============================] - 4928s 5s/step - loss: 0.3999 - acc: 0.8672 - val_loss: 4.6155 - val_acc: 0.0400

Epoch 00005: val_acc did not improve from 0.05600

1 Answer:

Answer 0 (score: 0)

Following https://github.com/keras-team/keras/issues/9214#issuecomment-397916155, the batch normalization layers should remain trainable.

Per that comment, you should change the loop where you set layers trainable/untrainable so that the BatchNormalization layers stay trainable.
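A minimal sketch of what that replacement loop could look like (this is my reading of the linked comment, not code taken from it; with the question's actual ResNet50 model you would keep the `-160` cutoff instead of the tiny stand-in model used here):

```python
from keras.layers import Input, Dense, BatchNormalization
from keras.models import Model

# Stand-in for net_final: a tiny model containing BatchNormalization
# layers (with the real ResNet50 model, use the question's -160 cutoff).
inp = Input(shape=(8,))
x = BatchNormalization()(inp)
x = Dense(16, activation='relu')(x)
x = BatchNormalization()(x)
out = Dense(20, activation='softmax')(x)
net_final = Model(inputs=inp, outputs=out)

cutoff = 2  # stands in for 160 in the question's setup
for layer in net_final.layers[:-cutoff]:
    # Frozen region: only BatchNormalization layers stay trainable,
    # so their statistics can adapt to the new dataset.
    layer.trainable = isinstance(layer, BatchNormalization)
for layer in net_final.layers[-cutoff:]:
    layer.trainable = True
```

Remember to recompile the model after changing `trainable` flags, or the change has no effect on training.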

With my own data, I also needed to reduce the batch size to avoid OOM errors.

Warning: this may affect accuracy, so you will have to freeze the model (before inference) to avoid strange predictions. But it seems to be the only approach that worked for me.

Another comment, https://github.com/keras-team/keras/issues/9214#issuecomment-422490253, only checks the layer name to set it trainable when it is a batch normalization layer, but that made no difference for me. Maybe it will help with your dataset.
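For reference, my paraphrase of that name-based variant is sketched below (hypothetical; the stand-in model and its layer names are mine — in Keras's stock ResNet50 the batch-norm layers have names like `bn_conv1`, so the check is on the name string rather than the class):

```python
from keras.layers import Input, Dense, BatchNormalization
from keras.models import Model

# Stand-in model; the real net_final comes from the question's code.
inp = Input(shape=(8,))
x = BatchNormalization(name='bn_conv1')(inp)
out = Dense(20, activation='softmax', name='fc20')(x)
net_final = Model(inputs=inp, outputs=out)

# Name-based check: a layer is trainable only when its name marks it
# as a batch-normalization layer.
for layer in net_final.layers:
    layer.trainable = ('bn' in layer.name
                       or 'batch_normalization' in layer.name)
```

Note that as written this also freezes the new classifier head, so in practice you would combine it with the question's loop for the last layers rather than use it on its own.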