重新加载张量流模型

时间:2016-04-23 18:26:23

标签: tensorflow

我有两个单独的张量流程,一个用于训练模型并用tensorflow.python.client.graph_util.convert_variables_to_constants写出graph_def,另一个用tensorflow.import_graph_def读取graph_def。我希望第二个进程在第一个进程更新时定期重新加载graph_def。不幸的是,即使我关闭当前会话并创建一个新会话,似乎每次我读取graph_def仍会使用旧的。我也尝试用import_graph_def call包裹sess.graph.as_default(),但无济于事。这是我当前的graph_def加载代码:

if self.sess is not None:
    self.sess.close()
self.sess = tf.Session()

graph_def = tf.GraphDef()
with open(self.graph_path, 'rb') as f:
    graph_def.ParseFromString(f.read())
with self.sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')

1 个答案:

答案 0 :(得分:5)

这里的问题是,当您创建一个没有参数的tf.Session时,它会使用当前默认图。假设您没有在代码中的任何其他地方创建tf.Graph,您将获得在流程启动时创建的全局默认图,并在所有会话之间共享。因此,with self.sess.graph.as_default():无效。

很难从您在问题中显示的代码段推荐一个新结构 - 特别是,我不知道您是如何创建上一个图表,或者类结构是什么 - 但是一个可能性是用以下内容替换self.sess = tf.Session()

self.sess = tf.Session(graph=tf.Graph())  # Creates a new graph for the session.

现在with self.sess.graph.as_default():将使用为会话创建的图表,您的程序应具有预期的效果。

稍微优先(对我而言)替代方案是明确构建图形:

with tf.Graph().as_default() as imported_graph:
    tf.import_graph_def(graph_def, ...)

sess = tf.Session(graph=imported_graph)