Tensorflow

时间:2017-01-10 20:03:12

标签: python tensorflow

def biLSTM(data, n_steps):


    n_hidden= 24
    data = tf.transpose(data, [1, 0, 2])
    # Reshape to (n_steps*batch_size, n_input)
    data = tf.reshape(data, [-1, 300])
    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    data = tf.split(0, n_steps, data)    

    lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)

    outputs, _, _ = tf.nn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, data, dtype=tf.float32)


    return outputs, n_hidden

在我的代码中,我调用此函数两次以创建2个双向LSTM。然后我遇到了重用变量的问题。

  

ValueError:变量lstm / BiRNN_FW / BasicLSTMCell / Linear / Matrix   已经存在,不允许。你的意思是设置reuse = True in   VarScope?

为了解决这个问题,我在with tf.variable_scope('lstm', reuse=True) as scope:

中的函数中添加了LSTM定义

这导致了一个新问题

  

ValueError:变量lstm / BiRNN_FW / BasicLSTMCell / Linear / Matrix   不存在,不允许。你的意思是在VarScope中设置reuse = None吗?

请帮助解决此问题。

2 个答案:

答案 0 :(得分:7)

创建BasicLSTMCell()时,它会创建所有必需的权重和偏差,以实现LSTM单元格。所有这些变量都自动分配名称。如果您在同一范围内多次调用该函数,则会得到错误。由于您的问题似乎表明您要创建两个单独的LSTM单元格,因此您不希望重用这些变量,但您确实希望在不同的范围内创建它们。您可以通过两种不同的方式执行此操作(我实际上没有尝试运行此代码,但它应该可以工作)。您可以在一个独特的范围内调用您的函数

def biLSTM(data, n_steps):
    ... blah ...

with tf.variable_scope('LSTM1'):
    outputs, hidden = biLSTM(data, steps)

with tf.variable_scope('LSTM2'):
    outputs, hidden = biLSTM(data, steps)

或者您可以将唯一的范围名称传递给函数并使用

中的范围
def biLSTM(data, n_steps, layer_name):
    ... blah...
    with tf.variable_scope(layer_name) as scope:
        lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
        lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
        outputs, _, _ = tf.nn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, data, dtype=tf.float32)
    return outputs, n_hidden

l1 = biLSTM(data, steps, 'layer1')
l2 = biLSTM(data, steps, 'layer2')

由您的编码敏感度决定哪种方法可供选择,它们在功能上几乎相同。

答案 1 :(得分:0)

我也有类似的问题。但是我使用了预先训练的Resnet50模型的keras实现。

当我使用以下命令更新tensorflow版本时,它对我有用:

conda update -f -c conda-forge tensorflow

并使用

from keras import backend as K
K.clear_session
相关问题