每N个周期结束时保存模型权重

时间:2018-07-05 08:12:22

标签: python tensorflow callback keras

我正在训练神经网络,并希望每N个epoch节省模型权重以进行预测。我提出了此代码草案,它受@grovina的回复here的启发。你能提出建议吗? 预先感谢。

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

然后将其添加到fit调用中:每5个时代保存一次权重:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])

2 个答案:

答案 0 :(得分:7)

您不需要为回调传递模型。它已经可以通过它的超级访问模型了。因此,删除__init__(..., model, ...)自变量和self.model = model。无论如何,您都应该能够通过self.model访问当前模型。您还将它保存在每个批次的末尾,这不是您想要的,您可能希望它为on_epoch_end

但是无论如何,您可以通过幼稚的modelcheckpoint callback完成操作。您无需编写自定义代码。您可以按以下方式使用它;

mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', 
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])

答案 1 :(得分:1)

您应该在on_epoch_end上实现,而不是在on_batch_end上实现。并且将模型作为__init__的参数传递也是多余的。

from keras.callbacks import Callback
class WeightsSaver(Callback):
  def __init__(self, N):
    self.N = N
    self.epoch = 0

  def on_epoch_end(self, epoch, logs={}):
    if self.epoch % self.N == 0:
      name = 'weights%08d.h5' % self.epoch
      self.model.save_weights(name)
    self.epoch += 1