将多个MetaGraphDefs导入单个图并还原变量

时间:2019-01-07 20:51:55

标签: python tensorflow

TensorFlow: Restoring Multiple Graphs类似的问题,但使用的是tf.train.import_meta_graph()接口。

我的代码:

    with self.graph.as_default(), tf.device(device):
        with tf.Session(graph=self.graph, config=self.tf_config) as sess:

            # Add inherited graphs to CenterNet's graph.
            self.mm_saver = tf.train.import_meta_graph(self.maskmaker.model_ckpt + ".meta")
            self.dv_saver = tf.train.import_meta_graph(self.deepvar.model_ckpt + ".meta")

            # First saver can restore
            self.mm_saver.restore(sess, self.maskmaker.model_ckpt)
            # Second saver raises an exception
            self.dv_saver.restore(sess, self.deepvar.model_ckpt)

异常(没有回溯,这很长)。

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Key classifier/bias not found in checkpoint
     [[node save/RestoreV2 (defined at /home/markemus/dev/IHC/ihc/neuralnets.py:936)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

看起来dv_saver正在尝试还原图形上的所有变量,而不仅仅是其自身。失败的关键字“分类器/偏见”最初是mm图表的一部分。

如何将其限制为恢复自己的密钥?

1 个答案:

答案 0 :(得分:2)

解决了!保护程序将操作添加到图形中,并且由于两个保护程序都在同一个name_scope中,因此它们会相互干扰。您需要将对import_meta_graph的每个调用包装在其自己的name_scope中:

with self.graph.as_default(), tf.device(device):
    with tf.Session(graph=self.graph, config=self.tf_config) as sess:

        # Add inherited graphs to CenterNet's graph.
        with tf.name_scope(self.maskmaker.name):
            self.mm_saver = tf.train.import_meta_graph(self.maskmaker.model_ckpt + ".meta")
        with tf.name_scope(self.deepvar.name):
            self.dv_saver = tf.train.import_meta_graph(self.deepvar.model_ckpt + ".meta")

        # First saver can restore
        self.mm_saver.restore(sess, self.maskmaker.model_ckpt)
        # Second saver can also restore
        self.dv_saver.restore(sess, self.deepvar.model_ckpt)