TF如何从同一变量恢复两个变量

时间:2020-02-12 03:33:04

标签: tensorflow restore

我保存了一个模型,现在我试图在两个分支中还原它,如下所示:

enter image description here

我写了这段代码,它引发了ValueError: The same saveable will be restored with two names。 如何从同一个变量恢复两个变量?

restore_variables = {}
for varr in tf.global_variables()
    if varr.op.name in checkpoint_variables:
        restore_variables[varr.op.name.split("_red")[0]] = varr           
        restore_variables[varr.op.name.split("_blue")[0]] = varr
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)

2 个答案:

答案 0 :(得分:1)

在TF 1.15上测试

基本上,该错误是说它正在restore_variables字典中找到对同一变量的多个引用。解决方法很简单。使用tf.Variable(varr)为其中一个引用创建变量的副本,如下所示。

我认为可以安全地假设您不是要在此处查找对同一变量的多个引用,而是要查找两个单独的变量。 (我假设这样做是因为,如果您想多次使用同一变量,则可以多次使用单个变量。)

with tf.Session() as sess:
    saver.restore(sess, './vars/vars.ckpt-0')
    restore_variables = {}
    checkpoint_variables=['b']
    for varr in tf.global_variables():
        if varr.op.name in checkpoint_variables:
            restore_variables[varr.op.name.split("_red")[0]] = varr           
            restore_variables[varr.op.name.split("_blue")[0]] = tf.Variable(varr)
    print(restore_variables)
    init_saver = tf.train.Saver(restore_variables, max_to_keep=0)

在下面,您可以找到一个完整的代码,使用一个玩具示例来复制问题。本质上,我们有两个变量ab,除此之外,我们正在创建b_redb_blue变量。

# Saving the variables

import tensorflow as tf
import numpy as np
a = tf.placeholder(shape=[None, 3], dtype=tf.float64)
w1 = tf.Variable(np.random.normal(size=[3,2]), name='a')
out = tf.matmul(a, w1)
w2 = tf.Variable(np.random.normal(size=[2,3]), name='b')
out = tf.matmul(out, w2)

saver = tf.train.Saver([w1, w2])

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  saved_path = saver.save(sess, './vars/vars.ckpt', global_step=0)
# Restoring the variables

with tf.Session() as sess:
    saver.restore(sess, './vars/vars.ckpt-0')
    restore_variables = {}
    checkpoint_variables=['b']
    for varr in tf.global_variables():
        if varr.op.name in checkpoint_variables:
            restore_variables[varr.op.name+"_red"] = varr  
            # Fixing the issue: Instead of varr, do tf.Variable(varr)
            restore_variables[varr.op.name+"_blue"] = varr
    print(restore_variables)
    init_saver = tf.train.Saver(restore_variables, max_to_keep=0)

答案 1 :(得分:0)

我可能无法正确理解问题,但是您不能只创建两个保护程序对象吗?像这样:

import tensorflow as tf

# Make checkpoint
with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.Variable([1., 2.], name='a')
    sess.run(a.initializer)
    b = tf.Variable([3., 4., 5.], name='b')
    sess.run(b.initializer)
    saver = tf.train.Saver([a, b])
    saver.save(sess, 'tmp/vars.ckpt')

# Restore checkpoint
with tf.Graph().as_default(), tf.Session() as sess:
    # Red
    a_red = tf.Variable([0., 0.], name='a_red')
    b_red = tf.Variable([0., 0., 0.], name='b_red')
    saver_red = tf.train.Saver({'a': a_red, 'b': b_red})
    saver_red.restore(sess, 'tmp1/vars.ckpt')
    print(a_red.eval())
    # [1. 2.]
    print(b_red.eval())
    # [3. 4. 5.]

    # Blue
    a_blue = tf.Variable([0., 0.], name='a_blue')
    b_blue = tf.Variable([0., 0., 0.], name='b_blue')
    saver_blue = tf.train.Saver({'a': a_blue, 'b': b_blue})
    saver_blue.restore(sess, 'tmp/vars.ckpt')
    print(a_blue.eval())
    # [1. 2.]
    print(b_blue.eval())
    # [3. 4. 5.]
相关问题