Keras:如何保存模型或权重?

时间:2019-07-22 19:59:38

标签: python tensorflow keras save

很抱歉,这个问题似乎很简单。但是请阅读Keras的保存和恢复帮助页面:

https://www.tensorflow.org/beta/tutorials/keras/save_and_restore_models

我不知道如何在训练期间使用“ ModelCheckpoint”进行保存。帮助文件提到应该提供3个文件,我只能看到一个文件,MODEL.ckpt。

这是我的代码:

checkpoint_dir = FolderName + "/tmp/model.ckpt"
cp_callback = k.callbacks.ModelCheckpoint(checkpoint_dir,verbose=1,save_weights_only=True)    
parallel_model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),loss=my_cost_MSE, metrics=['accuracy])
    parallel _model.fit(image, annotation, epochs=epoch, 
    batch_size=batch_size, steps_per_epoch=10,
                                 validation_data=(image_val,annotation_val),validation_steps=num_batch_val,callbacks=callbacks_list)

此外,当我想通过以下方法训练负重时:

model = k.models.load_model(file_checkpoint)

我得到了错误:

"raise ValueError('Unknown ' + printable_module_name + ':' + object_name) 
ValueError: Unknown loss function:my_cost_MSE"

my-cost_MSE是我在培训中使用的成本函数。

2 个答案:

答案 0 :(得分:1)

keras有一个save命令。它保存了重建模型所需的所有细节。

(来自keras docs

from keras.models import load_model
model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
del model  # deletes the existing model

# returns am identical compiled model
model = load_model('my_model.h5')

答案 1 :(得分:0)

首先,看起来您正在使用tf.keras(来自tensorflow)实现,而不是keras(来自keras-team / keras存储库)。请注意,在这种情况下,如tf.keras guide中所述:

  

保存模型的权重时,tf.keras默认为检查点   格式。传递save_format ='h5'以使用HDF5。

另一方面,请注意,添加回调ModelCheckpoint通常大致相当于调用model.save(...),因此这就是为什么您希望保存三个文件的原因(根据{{3 }}。

之所以不这样做,是因为通过使用选项save_weights_only=True,您仅节省了权重。大致等效于在每个时期结束时将对model.save的调用model.save_weights替换为model.load_weights。因此,您唯一要保存的文件就是具有权重的文件。

请注意,如果您只想存储权重,则需要预先加载模型(例如结构),然后调用model = MyModel(...) # Your model definition as used in training model.load_weights(file_checkpoint)

my_cost_MSE

请注意,在这种情况下,自定义定义(cp_callback = k.callbacks.ModelCheckpoint(checkpoint_dir,verbose=1,save_weights_only=False) parallel_model.compile( optimizer=tf.keras.optimizers.Adam(lr=learning_rate), loss=my_cost_MSE, metrics=['accuracy] ) # Training code here )不会有问题,因为您只是在加载模型权重。

另一种进行方法是存储整个模型并相应地加载它:

model = k.models.load_model(file_checkpoint, custom_objects={"my_cost_MSE": my_cost_MSE})

然后您可以通过以下方式加载它:

custom_objects

请注意,在后一种情况下,您需要指定swagger-blocks,因为需要对其定义进行反序列化。

相关问题