张量流图中恢复变量的输出错误

时间:2018-07-02 13:31:28

标签: tensorflow loading

我目前正在研究变量的保存和恢复。为此,我创建了两个脚本。其中一个保存一个简单的图,而另一个保存它。这里是保存图形的测试脚本:

import tensorflow as tf

a = tf.Variable(3.0, name='a')
b = tf.Variable(5.0, name='b')

b = tf.assign_add(b, a)

n_steps = 5

global_step = tf.Variable(0, name='global_step', trainable=False)

saver = tf.train.Saver()

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    for step in range(n_steps):
        print(sess.run(b))

        global_step.assign_add(1).eval()
        print(global_step.eval())

        saver.save(sess, './my_test_model', global_step=global_step)

基本上,我想运行5次循环,每次这样做,我都将a添加到b中。我还想通过global_step跟踪步骤数。这按预期工作。输出为:

8.0     # value of b
1       # step
11.0
2
14.0
3
17.0
4
20.0
5

现在还原变量时,我尝试获取所有三个变量。脚本是:

import tensorflow as tf

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# List ALL tensors.
print_tensors_in_checkpoint_file(tf.train.latest_checkpoint('./'), all_tensors=True, tensor_name='')

tf.reset_default_graph()

a = tf.get_variable('a', shape=[])
b = tf.get_variable('b', shape=[])
global_step = tf.get_variable('global_step', shape=[])

saver = tf.train.Saver()

with tf.Session() as sess:

    ckpt = tf.train.latest_checkpoint('./')
    if ckpt:
        print(ckpt)

        saver.restore(sess, ckpt)

    else:
        print('Nothing restored')

    print(a.eval())
    print(b.eval())
    print(global_step.eval())

它的输出是

tensor_name:  a
3.0
tensor_name:  b
20.0
tensor_name:  global_step
5
./my_test_model-5
INFO:tensorflow:Restoring parameters from ./my_test_model-5
3.0
20.0
7e-45

如何将global_step的值正确存储在检查点中,但是经评估,我得到了这么小的 7e-45 ?另外,还原后,我似乎无法定义任何其他变量,因为它声明无法在检查点中找到该变量。例如,如何定义变量并将其添加到已还原图的b

谢谢您的帮助!

1 个答案:

答案 0 :(得分:2)

TF文档似乎并没有对此做充分记录,但是您应该为global_step变量指定dtype。

不正确

global_step = tf.get_variable('global_step', shape=[], dtype=tf.float32) 结果为global_step=7e-5。默认情况下,该类型假定为dtf.float32。

正确

global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32) 产生global_step=5