会话恢复后,get_variable()不起作用

时间:2017-03-13 17:08:14

标签: tensorflow

我尝试恢复会话并调用get_variable()来获取类型的对象 tf.Variable(根据this answer)。 它无法找到变量。重现案例的最小例子是 如下。

首先,创建一个变量并保存会话。

import tensorflow as tf

var = tf.Variable(101)

with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])

    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])

    assert scoped_var is new_scoped_var
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    print(sess.run(scoped_var))
    saver.save(sess, 'data/sess')

get_variables范围内的reuse=True工作正常。 然后,从文件中恢复会话并尝试获取变量。

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    for v in tf.get_collection('variables'):
        print(v.name)

    print(tf.get_collection(("__variable_store",)))
    # Oops, it's empty!

    with tf.variable_scope('', reuse=True):
        # the next line fails
        new_scoped_var = tf.get_variable('scoped_var', [])

    print("new_scoped_var: ", new_scoped_var)

输出:

Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

我们可以看到,get_variable()无法找到变量。和 ("__variable_store",)内部使用的get_variable()集合, 是空的。

为什么get_variable会失败?

1 个答案:

答案 0 :(得分:1)

您可以试试这个,而不是处理元图(如果你想修改图表以及它是如何加载的话,这可能会有所帮助)。

import tensorflow as tf

with tf.Session() as sess:
  with tf.variable_scope(''):
    scoped_var = tf.get_variable('scoped_var', [])

  with tf.variable_scope('', reuse=True):
    new_scoped_var = tf.get_variable('scoped_var', [])

  assert scoped_var is new_scoped_var
  saver = tf.train.Saver()
  path = tf.train.get_checkpoint_state('data/sess')
  if path is not None:
    saver.restore(sess, path.model_checkpoint_path)
  else:
    sess.run(tf.global_variables_initializer())

  print(sess.run(scoped_var))
  saver.save(sess, 'data/sess')

  #now continue to use as you normally would with a restored model

主要区别在于您在调用saver.restore

之前设置了模型