恢复变量时出错

时间:2018-07-03 16:01:21

标签: tensorflow loading

我偶然发现了一个我无法解决的错误。我正在尝试做以下事情:

我想训练一个(虚拟)模型,该模型在每次迭代时将a添加到b。完成后,我想将变量保存为检查点。我第一次运行它,它将从头开始构建模型。每次我重新运行模型时,都应从最后一个检查点开始,然后再次进行添加。因此,我从.meta文件加载了完整的图形。全局step变量用于跟踪我已训练的步骤总数。

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# List ALL tensors.
print_tensors_in_checkpoint_file(tf.train.latest_checkpoint('./'), all_tensors=True, tensor_name='')

tf.reset_default_graph()

global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, initializer=tf.constant_initializer(0), trainable=False)

def model(a, b):
    b = tf.assign_add(b, a)
    return b

with tf.Session() as sess:

    ckpt = tf.train.latest_checkpoint('./')
    if ckpt:
        saver = tf.train.import_meta_graph('./my_test_model-1.meta')
        saver.restore(sess, ckpt)

    else:
        a = tf.Variable(3.0, name='a')
        b = tf.Variable(5.0, name='b')

        b = model(a, b)

        ### before EDIT
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        ###

        ### after EDIT
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ###

    for step in range(5):
        global_step.assign_add(1).eval()
        print(global_step.eval())
        print(b.eval())

        saver.save(sess, './my_test_model', global_step=global_step)

脚本首次运行正常,输出如下:

1        # step
8.0      # value of b
2
11.0
3
14.0
4
17.0
5
20.0

第二次运行程序时,得到以下输出:

tensor_name:  a
3.0
tensor_name:  b
20.0
tensor_name:  global_step
0
tensor_name:  global_step_1
5

INFO:tensorflow:Restoring parameters from ./my_test_model-5
  

回溯(最近一次通话最近):... FailedPreconditionError:   尝试使用未初始化的值global_step [[Node:   AssignAdd_2 = AssignAdd [T = DT_INT32,use_locking = false,   _device =“ / job:localhost / replica:0 / task:0 / device:CPU:0”](global_step,AssignAdd_2 / value)]] ...

第一次,很明显,当我为所有变量运行初始化程序时,它不会引发错误。但是我认为恢复模型算是某种初始化吗?我真的不能为这个概念而烦恼。我还尝试在定义global_stepa之后定义b,但这在首次加载时导致另一个错误:

  

ValueError:无法使用默认会话评估张量:   张量图与会话图不同。通过显式   与eval(session=sess)的会话。   错误是指增加global_stepglobal_step.assign_add(1).eval())的行。

我在做什么错?我应该在哪里定义变量?

感谢您对这个问题的任何帮助!感谢您阅读本文。

编辑: 感谢@Diana,前提条件错误消失了。不幸的是,发生了另一个错误。每当运行带有加载检查点的脚本时,都会引发名称错误:

  

NameError:未定义名称“ global_step”。

变量“ b”也会发生这种情况。恢复检查点时不应该加载该名称吗?当我检查检查点文件中的张量时,这些张量似乎具有正确的名称和值。

1 个答案:

答案 0 :(得分:0)

您应在初始化后声明保护程序。否则,您将无法保存任何价值。由于保护者不知道。