在Tensorflow

时间:2018-02-12 23:54:09

标签: python tensorflow

到目前为止,我在Tensorflow中使用了保存和加载检查点,仅用于加载最后一个检查点。通常我使用的代码是这样的:

ckpt = tf.train.get_checkpoint_state(load_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(session, ckpt.model_checkpoint_path)
else:
    tf.gfile.DeleteRecursively(load_dir)
    tf.gfile.MakeDirs(load_dir)

但是,在我最近的实验中,我每1000次迭代都会保存一个检查点,并且我想在所有检查点上运行评估脚本,例如:显示不同的验证指标如何进展。有没有简单的方法来获取Tensorflow中的所有检查点,或者我只需要使用os相应地循环所有名称?

1 个答案:

答案 0 :(得分:3)

代码段中的ckpt对象是CheckpointState协议缓冲区。您可以使用以下内容迭代所有这些路径,而不是访问最新的模型路径(ckpt.model_checkpoint_path):

for model_path in ckpt.all_model_checkpoint_paths:
    saver.restore(session, model_path)
    # Do the evaluation using the restored model
相关问题