EarlyStopping does not stop the model despite loading the best weights

Date: 2019-07-31 16:25:36

Tags: python tensorflow keras callback neural-network

I am running an image classification program with tf.keras and trying to determine the error curves for accuracy and val_accuracy. However, when I add an early stopping callback, the model does not stop training even after the patience threshold has been passed.

I tried changing the value that early stopping monitors, but found that it is already monitoring the correct value, because I reach this point in the tensorflow.python.keras.callbacks.py file:

    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        if self.restore_best_weights:
          if self.verbose > 0:
            print('Restoring model weights from the end of the best epoch.')
          self.model.set_weights(self.best_weights)

My output shows that print line, so I am clearly hitting the self.model.stop_training = True line, but my model keeps going anyway. Here is an example run that continues even though the early stopping point has been reached. You can see that at the end of epoch 9 it prints "Restoring model weights from the end of the best epoch." However, it then goes on to run epoch 10.

Epoch 1/10
 9/10 [==========================>...] - ETA: 1s - loss: 1.1147 - categorical_accuracy: 0.6058
Epoch 00001: val_categorical_accuracy improved from -inf to 0.25000, saving model to /home/chale/ml_classify/data/best.weights.hdf5
10/10 [==============================] - 29s 3s/step - loss: 1.0876 - categorical_accuracy: 0.6013 - val_loss: 60.9186 - val_categorical_accuracy: 0.2500
Epoch 2/10
 9/10 [==========================>...] - ETA: 0s - loss: 1.2638 - categorical_accuracy: 0.5694
Epoch 00002: val_categorical_accuracy did not improve from 0.25000
10/10 [==============================] - 7s 747ms/step - loss: 1.2278 - categorical_accuracy: 0.5750 - val_loss: 147.1493 - val_categorical_accuracy: 0.2396
Epoch 3/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.5760 - categorical_accuracy: 0.8321
Epoch 00003: val_categorical_accuracy improved from 0.25000 to 0.26042, saving model to /home/chale/ml_classify/data/best.weights.hdf5
10/10 [==============================] - 10s 972ms/step - loss: 0.5569 - categorical_accuracy: 0.8288 - val_loss: 21.9862 - val_categorical_accuracy: 0.2604
Epoch 4/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.4401 - categorical_accuracy: 0.8681
Epoch 00004: val_categorical_accuracy improved from 0.26042 to 0.30208, saving model to /home/chale/ml_classify/data/best.weights.hdf5
10/10 [==============================] - 9s 897ms/step - loss: 0.4383 - categorical_accuracy: 0.8687 - val_loss: 146.7307 - val_categorical_accuracy: 0.3021
Epoch 5/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.4499 - categorical_accuracy: 0.8394
Epoch 00005: val_categorical_accuracy did not improve from 0.30208
10/10 [==============================] - 7s 714ms/step - loss: 0.4218 - categorical_accuracy: 0.8493 - val_loss: 71.2797 - val_categorical_accuracy: 0.1354
Epoch 6/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.5760 - categorical_accuracy: 0.8194
Epoch 00006: val_categorical_accuracy improved from 0.30208 to 0.38542, saving model to /home/chale/ml_classify/data/best.weights.hdf5
10/10 [==============================] - 10s 974ms/step - loss: 0.5342 - categorical_accuracy: 0.8313 - val_loss: 13.7430 - val_categorical_accuracy: 0.3854
Epoch 7/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.3852 - categorical_accuracy: 0.9000
Epoch 00007: val_categorical_accuracy did not improve from 0.38542
10/10 [==============================] - 6s 619ms/step - loss: 0.4190 - categorical_accuracy: 0.8973 - val_loss: 164.1882 - val_categorical_accuracy: 0.2708
Epoch 8/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.3401 - categorical_accuracy: 0.8905
Epoch 00008: val_categorical_accuracy did not improve from 0.38542
10/10 [==============================] - 7s 723ms/step - loss: 0.3745 - categorical_accuracy: 0.8889 - val_loss: 315.0913 - val_categorical_accuracy: 0.2708
Epoch 9/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.2713 - categorical_accuracy: 0.8958
Epoch 00009: val_categorical_accuracy did not improve from 0.38542
Restoring model weights from the end of the best epoch.
10/10 [==============================] - 9s 853ms/step - loss: 0.2550 - categorical_accuracy: 0.9062 - val_loss: 340.6383 - val_categorical_accuracy: 0.2708
Epoch 10/10
 9/10 [==========================>...] - ETA: 0s - loss: 0.4282 - categorical_accuracy: 0.8759
Epoch 00010: val_categorical_accuracy did not improve from 0.38542
Restoring model weights from the end of the best epoch.
10/10 [==============================] - 8s 795ms/step - loss: 0.4260 - categorical_accuracy: 0.8758 - val_loss: 4.5791 - val_categorical_accuracy: 0.2500
Epoch 00010: early stopping

Here is the main code in question:

        if loss == 'categorical_crossentropy':
            monitor = 'val_categorical_accuracy'
        else:
            monitor = 'val_binary_accuracy'

        early_stop = EarlyStopping(monitor=monitor, patience=3, verbose=1, restore_best_weights=True)

        checkpoint_path = '{}/best.weights.hdf5'.format(output_dir)
        best_model = ModelCheckpoint(checkpoint_path, monitor=monitor, verbose=1, save_best_only=True, mode='max')

        # reduce_lr = tensorflow.python.keras.callbacks.ReduceLROnPlateau()

        m = Metrics(labels=labels, val_data=validation_generator, batch_size=batch_size)
        history = model.fit_generator(train_generator,
                                           steps_per_epoch=steps_per_epoch,
                                           epochs=epochs,
                                           use_multiprocessing=True,
                                           validation_data=validation_generator,
                                           validation_steps=validation_steps,
                                           callbacks=[tensorboard, best_model, early_stop,
                                                      WandbCallback(data_type="image", 
                                                                    validation_data=validation_generator,
                                                                    labels=labels)])# , schedule])


        return history

The full code is at https://github.com/AtlasHale/ml_classify

I expect that once the early stopping patience runs out, the remaining epochs are not run, and that if early stopping triggers, the returned model is the one with the best weights. However, the model that comes back is not the best model but the last one. I want the best model to be returned and the remaining epochs to be skipped once early stopping has fired.
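As a workaround sketch (not from the original post), the best weights saved by the ModelCheckpoint callback could be reloaded after training, assuming `model`, `output_dir`, and the checkpoint path format used in the code above:

    # Hypothetical workaround: after fit_generator() returns, reload the best
    # weights that ModelCheckpoint saved, so the model handed back is the best
    # one rather than the last one. Path format matches the code above.
    checkpoint_path = '{}/best.weights.hdf5'.format(output_dir)
    model.load_weights(checkpoint_path)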

Edit: after copying the EarlyStopping class and adding some prints to it, I found the following:

Epoch 6/10
8/9 [=========================>....] - ETA: 0s - loss: 0.5594 - categorical_accuracy: 0.9062
Epoch 00006: val_categorical_accuracy did not improve from 0.27083
3 epochs since improvement to val_categorical_accuracy
Model stop_training state previously: False
Model stop_training state now: True
Restoring model weights from the end of the best epoch.
9/9 [==============================] - 8s 855ms/step - loss: 0.5511 - categorical_accuracy: 0.8889 - val_loss: 466.1678 - val_categorical_accuracy: 0.2292
Epoch 7/10
8/9 [=========================>....] - ETA: 0s - loss: 0.3544 - categorical_accuracy: 0.8992
Epoch 00007: val_categorical_accuracy did not improve from 0.27083
4 epochs since improvement to val_categorical_accuracy
Model stop_training state previously: False
Model stop_training state now: True
Restoring model weights from the end of the best epoch.

When self.model.stop_training is set to True, it does not seem to persist to the end of the next epoch. So it seems that what happens inside the callback is not being applied to the model? I am not sure. Any insight is welcome.
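One way to check whether the flag actually persists into the next epoch is a small diagnostic callback (a sketch using the standard tf.keras Callback API, not code from the original post) that prints model.stop_training at epoch boundaries:

    import tensorflow as tf

    class StopFlagLogger(tf.keras.callbacks.Callback):
        """Diagnostic sketch: log model.stop_training at epoch boundaries to
        see whether the flag set by EarlyStopping survives into the next epoch."""

        def on_epoch_begin(self, epoch, logs=None):
            print('Epoch {} begin: stop_training = {}'.format(epoch, self.model.stop_training))

        def on_epoch_end(self, epoch, logs=None):
            print('Epoch {} end: stop_training = {}'.format(epoch, self.model.stop_training))

Adding an instance of this to the callbacks list would show whether stop_training is still True when the next epoch begins.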

1 Answer:

Answer 0 (score: 0):

I had the same problem. I set epochs=100 and patience=5, but every time the training ran for all 100 epochs.

I found that EarlyStopping behaves correctly in this tutorial: https://lambdalabs.com/blog/tensorflow-2-0-tutorial-04-early-stopping/

Main tip: use the min_delta parameter. With it, training stops if the previous best result has not improved by at least min_delta within the number of epochs set by the patience parameter.
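For illustration, a minimal sketch of such an EarlyStopping configuration, with example values for min_delta and patience (not the values from the question):

    from tensorflow.keras.callbacks import EarlyStopping

    # Example only: an epoch counts as an improvement only if the monitored
    # metric improves by at least min_delta; otherwise the patience counter
    # keeps ticking and training stops once it is exhausted.
    early_stop = EarlyStopping(monitor='val_categorical_accuracy',
                               min_delta=0.001,
                               patience=5,
                               mode='max',
                               verbose=1,
                               restore_best_weights=True)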