TensorFlow variable reuse

Date: 2017-07-23 09:43:26

Tags: python variables tensorflow

I have built an LSTM model. Ideally, I would like to later define a test LSTM model that reuses the variables of the training model.

with tf.variable_scope('lstm_model') as scope:
    # Define LSTM Model
    lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                            training_seq_len, vocab_size)
    scope.reuse_variables()
    test_lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                                 training_seq_len, vocab_size, infer=True)

The code above gives me the following error:

Variable lstm_model/lstm_vars/W already exists, disallowed. Did you mean to set reuse=True in VarScope? 

If I set reuse=True as in the code block below:

with tf.variable_scope('lstm_model', reuse=True) as scope:

I get a different error:

Variable lstm_model/lstm_model/lstm_vars/W/Adam/ does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

For reference, I have attached the relevant model code below. In the corresponding part of the LSTM model, I have the weights:

with tf.variable_scope('lstm_vars'):
    # Softmax Output Weights
    W = tf.get_variable('W', [self.rnn_size, self.vocab_size], tf.float32, tf.random_normal_initializer())

And I have the corresponding part for the Adam optimizer:

optimizer = tf.train.AdamOptimizer(self.learning_rate)

2 Answers:

Answer 0 (score: 7)

It seems that instead of this:

with tf.variable_scope('lstm_model') as scope:
    # Define LSTM Model
    lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                            training_seq_len, vocab_size)
    scope.reuse_variables()
    test_lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                                 training_seq_len, vocab_size, infer_sample=True)

the following solved the problem:

# Define LSTM Model
lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                        training_seq_len, vocab_size)

# Tell TensorFlow we are reusing the scope for the testing
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    test_lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                                 training_seq_len, vocab_size, infer_sample=True)
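
The difference is that the AdamOptimizer is created while the enclosing scope is still non-reusing, so it can create its slot variables (the .../W/Adam variable named in the second error); only the test model is built inside a reusing scope. Below is a minimal, self-contained sketch of this pattern for TensorFlow 1.x, where the toy build_model function, shapes, and placeholder names are hypothetical stand-ins for LSTM_Model:

import tensorflow as tf

def build_model(inputs, rnn_size=4, vocab_size=10):
    # Toy stand-in for LSTM_Model: one softmax output weight, as in the question.
    with tf.variable_scope('lstm_vars'):
        W = tf.get_variable('W', [rnn_size, vocab_size], tf.float32,
                            tf.random_normal_initializer())
    return tf.matmul(inputs, W)

inputs = tf.placeholder(tf.float32, [None, 4])
targets = tf.placeholder(tf.float32, [None, 10])

# Training graph: the model variables and Adam's slot variables are created here.
train_logits = build_model(inputs)
loss = tf.reduce_mean(tf.squared_difference(train_logits, targets))
train_op = tf.train.AdamOptimizer(0.001).minimize(loss)

# Testing graph: reuse the existing variables instead of creating new ones.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    test_logits = build_model(inputs)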

Answer 1 (score: 4)

If you use a variable two (or more) times, you should use with tf.variable_scope('scope_name', reuse=False): the first time and with tf.variable_scope('scope_name', reuse=True): every time after that.
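
For example, a minimal sketch of that pattern (TensorFlow 1.x; the scope and variable names are placeholders):

with tf.variable_scope('scope_name', reuse=False):
    v = tf.get_variable('v', [1])   # first use: the variable is created

with tf.variable_scope('scope_name', reuse=True):
    v1 = tf.get_variable('v', [1])  # later uses: the existing variable is returned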

Or you can use the tf.variable_scope.reuse_variables() method:

with tf.variable_scope("foo") as scope:
    v = tf.get_variable("v", [1])
    scope.reuse_variables()
    v1 = tf.get_variable("v", [1])

In the code above, v and v1 are the same variable.
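
A quick way to verify this, assuming the snippet above has been run:

print(v is v1)  # True: both names refer to the same underlying variable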