恢复Tensorflow模型并查看变量值

时间:2017-03-25 08:36:10

标签: tensorflow

我声明了一些表示权重和偏差的Tensorflow变量,并在保存它们之前更新了它们的值,如图所示:

#                # 5 x 5 x 5 patches, 1 channel, 32 features to compute.
weights = {'W_conv1':tf.Variable(tf.random_normal([3,3,3,1,32]), name='w_conv1'),
           #       5 x 5 x 5 patches, 32 channels, 64 features to compute.
           'W_conv2':tf.Variable(tf.random_normal([3,3,3,32,64]), name='w_conv2'),
           #                                  64 features
           'W_fc':tf.Variable(tf.random_normal([32448,1024]), name='w_fc'), #54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(10/2/2) * 64
           #'W_fc':tf.Variable(tf.random_normal([54080,1024]), name='W_fc'), #54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(20/2/2) * 64
           'out':tf.Variable(tf.random_normal([1024, n_classes]), name='w_out')}

biases = {'b_conv1':tf.Variable(tf.random_normal([32]), name='b_conv1'),
           'b_conv2':tf.Variable(tf.random_normal([64]), name='b_conv2'),
           'b_fc':tf.Variable(tf.random_normal([1024]), name='b_fc'),
           'out':tf.Variable(tf.random_normal([n_classes]), name='b_out')}

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    #some training code

    saver = tf.train.Saver()
    saver.save(sess, 'my-save-dir/my-model-10')

然后,我尝试恢复模型并访问变量,如下所示:

weights = {'W_conv1':tf.Variable(-1.0, validate_shape=False, name='w_conv1'),
           #       5 x 5 x 5 patches, 32 channels, 64 features to compute.
           'W_conv2':tf.Variable(-1.0, validate_shape=False, name='w_conv2'),
           #                                  64 features
           'W_fc':tf.Variable(-1.0, validate_shape=False, name='w_fc'), #54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(10/2/2) * 64
           #'W_fc':tf.Variable(tf.random_normal([54080,1024]), name='W_fc'), #54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(20/2/2) * 64
           'out':tf.Variable(-1.0, validate_shape=False, name='w_out')}

biases = {'b_conv1':tf.Variable(-1.0, validate_shape=False, name='b_conv1'),
           'b_conv2':tf.Variable(-1.0, validate_shape=False, name='b_conv2'),
           'b_fc':tf.Variable(-1.0, validate_shape=False, name='b_fc'),
           'out':tf.Variable(-1.0, validate_shape=False, name='b_out')}

with tf.Session() as sess:
    model_saver = tf.train.import_meta_graph('my-save-dir/my-model-10.meta')
    model_saver.restore(sess, "my-save-dir/my-model-10")
    print("Model restored.") 
    print('Initialized')
    print(sess.run(weights['W_conv1']))

然而,我得到了一个" FailedPreconditionError:尝试使用未初始化的值w_conv1"。请协助。

1 个答案:

答案 0 :(得分:1)

以下是您的第二个代码段中发生的情况:首先创建所有变量w_conv1b_out,因此默认图表会填充相应的节点。然后,您再次调用import_meta_graph(..),其中默认图表将填充您存储在第一个代码段中的模型中的所有节点。但是,对于它尝试加载的每个节点,已经存在另一个具有相同名称的节点(因为您创建了它"手工"就在之前)。我不知道在这种情况下内部会发生什么,但是在调用tf.global_variables()之后查看import_meta_graph(..)的输出显示现在每个节点都存在两次,名称完全相同。因此,恢复可能是未定义的,它可能只恢复一半变量,这就是您看到此错误的原因。

所以,你有两个可能来解决这个问题:

1)不要使用import_from_metagraph

weights = {'W_conv1':tf.Variable(tf.random_normal([3,3,3,1,32]), name='w_conv1'),
           #       5 x 5 x 5 patches, 32 channels, 64 features to compute.
           'W_conv2':tf.Variable(tf.random_normal([3,3,3,32,64]), name='w_conv2'),
           #                                  64 features
           'W_fc':tf.Variable(tf.random_normal([32448,1024]), name='w_fc'), #54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(10/2/2) * 64
           #'W_fc':tf.Variable(tf.random_normal([54080,1024]), name='W_fc'), #54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(20/2/2) * 64
           'out':tf.Variable(tf.random_normal([1024, n_classes]), name='w_out')}

biases = {'b_conv1':tf.Variable(tf.random_normal([32]), name='b_conv1'),
           'b_conv2':tf.Variable(tf.random_normal([64]), name='b_conv2'),
           'b_fc':tf.Variable(tf.random_normal([1024]), name='b_fc'),
           'out':tf.Variable(tf.random_normal([n_classes]), name='b_out')}

with tf.Session() as sess:
    model_saver = tf.train.Saver()
    model_saver.restore(sess, "my-save-dir/my-model-10")
    print("Model restored.")
    print('Initialized')
    print(sess.run(weights['W_conv1']))

2)使用import_from_metagraph但不要手动重新创建图表

所以,就这样:

with tf.Session() as sess:
    model_saver = tf.train.import_meta_graph('my-save-dir/my-model-10.meta')
    model_saver.restore(sess, "my-save-dir/my-model-10")
    print("Model restored.") 
    print('Initialized')
    print(sess.run(tf.get_default_graph().get_tensor_by_name('w_conv1:0')))

请注意,在这种情况下,您需要更改在' w_conv1'中检索值的方式。 (最后一行)。您可以使用get_tensor_by_name()而不是调用tf.get_variable(),但为了实现此目的,您必须使用tf.get_variable()创建变量。查看此帖子了解详情:TensorFlow: getting variable by name