Tensorflow RNN-LSTM - 重置隐藏状态

时间:2017-08-01 06:55:26

标签: tensorflow stateful rnn

我正在建立一个用于语言识别的有状态LSTM。 作为状态良好,我可以使用较小的文件训练网络,新的批次将像讨论中的下一句话。 但是,为了对网络进行适当的训练,我需要在一些批次之间重置LSTM的隐藏状态。

我使用变量存储LSTM的hidden_​​state以提高性能:

    with tf.variable_scope('Hidden_state'):
        hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
                                       tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
        # Arrange it to a tuple of LSTMStateTuple as needed
        l = tf.unstack(hidden_state, axis=0)
        rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
                                for idx in range(self.num_layers)])

    # Build the RNN
    with tf.name_scope('LSTM'):
        rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
                                          initial_state=rnn_tuple_state, time_major=True)

现在我对如何重置隐藏状态感到困惑。我尝试了两种解决方案,但它不起作用:

第一个解决方案

重置" hidden_​​state"变量:

rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))

它确实有效,我认为这是因为拆散和元组结构没有重新播放"运行rnn_state_zero_op操作后进入图形。

第二个解决方案

关注LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow我尝试使用以下方法重置单元格状态:

rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)

它似乎也无效。

问题

我想到了另一个解决方案,但它充其量只是猜测:我没有保持tf.nn.dynamic_rnn返回的状态,我已经想到了它但我得到了一个元组,我无法找到一种方法来构建一个重置元组的操作。

此时我承认我并不完全理解张量流的内部工作,如果它甚至可以做我想做的事情。 有没有正确的方法呢?

谢谢!

2 个答案:

答案 0 :(得分:4)

感谢this answer to another question我能够找到一种方法来完全控制是否(以及何时)将RNN的内部状态重置为0。

首先,你需要定义一些变量来存储RNN的状态,这样你就可以控制它了:

with tf.variable_scope('Hidden_state'):
    state_variables = []
    for state_c, state_h in cell.zero_state(self.batch_size, tf.float32):
        state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    rnn_tuple_state = tuple(state_variables)

请注意,此版本直接定义了LSTM使用的变量,这比我的问题中的版本要好得多,因为您不必卸载并构建元组,这会在图形中添加一些操作无法明确运行。

其次构建RNN并检索最终状态:

# Build the RNN
with tf.name_scope('LSTM'):
    rnn_output, new_states = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                               sequence_length=input_seq_lengths,
                                               initial_state=rnn_tuple_state,
                                               time_major=True)

所以现在你有了新的RNN内部状态。您可以定义两个操作来管理它。

第一个将更新下一批的变量。所以在下一批中" initial_state"将向RNN提供上一批的最终状态:

# Define an op to keep the hidden state between batches
update_ops = []
for state_variable, new_state in zip(rnn_tuple_state, new_states):
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(new_state[0]),
                       state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_keep_state_op = tf.tuple(update_ops)

您应该在想要运行批处理并保持内部状态的任何时候将此操作添加到会话中。

小心:如果您使用此op操作批处理1,则批处理2将从批处理1最终状态开始,但如果您在运行批处理2时再次调用它,则批处理3也将从批处理1最终状态开始。我的建议是每次运行RNN时添加此操作。

第二个操作将用于将RNN的内部状态重置为零:

# Define an op to reset the hidden state to zeros
update_ops = []
for state_variable in rnn_tuple_state:
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(tf.zeros_like(state_variable[0])),
                       state_variable[1].assign(tf.zeros_like(state_variable[1]))])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_state_zero_op = tf.tuple(update_ops)

只要您想重置内部状态,就可以调用此操作。

答案 1 :(得分:0)

一个LSTM图层的简化版AMairesse帖子:

zero_state = tf.zeros(shape=[1, units[-1]])
self.c_state = tf.Variable(zero_state, trainable=False)
self.h_state = tf.Variable(zero_state, trainable=False)
self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state)

self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(cell_encoder, layer, initial_state=self.init_encoder)

# save or reset states
self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)]
self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)]

或者您可以使用替换init_encoder来重置步骤== 0的状态(您需要将self.step_tf作为占位符传递给session.run()):

self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64), shape=[], name="step")

self.init_encoder = tf.cond(tf.equal(self.step_tf, 0),
  true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state),
  false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))