BasicLSTMCell

Time: 2017-06-07 18:44:33

Tags: tensorflow lstm

I have the following code:

def dense_layers(pool3):
    with tf.variable_scope('local1') as scope:
        # Flatten everything into depth so we can perform a single matrix multiply.
        shape_d = pool3.get_shape()
        shape = shape_d[1] * shape_d[2] * shape_d[3]
        # tf_shape = tf.stack(shape)
        tf_shape = 1024

        print("shape:", shape, shape_d[1], shape_d[2], shape_d[3])

        # tf_shape = 1024, i.e. 1024 features are fed into the network, and the
        # batch size is also 1024. The aim is to split batch_size into num_steps.
        reshape = tf.reshape(pool3, [-1, tf_shape])
        # Reshape/divide the batch_size into num_steps so that we feed a sequence.
        # Importantly, batch_partition_length must come before step_size in the shape list.
        lstm_inputs = tf.reshape(reshape, [batch_partition_length, step_size, tf_shape])

        # print('RNN inputs shape: ', lstm_inputs.get_shape())  # -> (128, 8, 1024).

        # Note that state_size is the number of hidden units.
        lstm = tf.contrib.rnn.BasicLSTMCell(state_size)
        lstm_outputs, final_state = tf.nn.dynamic_rnn(cell=lstm, inputs=lstm_inputs, initial_state=init_state)
        tf.assign(init_state, final_state)

So, I am taking the output of the pooling layer and trying to feed it into an LSTM in the network.

Initially I declared the following:

state_size = 16
step_size = 8

batch_partition_length = int(batch_size / step_size)

init_state = tf.Variable(tf.zeros([batch_partition_length, state_size]))    # -> [128, 16].
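The reshape arithmetic above can be checked with a small NumPy sketch (shapes taken from the comments in the question; the array contents are placeholders):

```python
import numpy as np

batch_size, step_size, num_features = 1024, 8, 1024
batch_partition_length = batch_size // step_size   # -> 128

# Stand-in for the flattened pool3 output, shape (batch_size, num_features).
pool3_flat = np.zeros((batch_size, num_features))

# Split the batch into sequences of step_size: (128, 8, 1024),
# matching the shape printed in the question.
lstm_inputs = pool3_flat.reshape(batch_partition_length, step_size, num_features)
```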

Then I get the following error on this line:

lstm_outputs, final_state = tf.nn.dynamic_rnn(cell=lstm, inputs=lstm_inputs, initial_state=init_state)

Traceback (most recent call last):
  File "C:/Users/user/PycharmProjects/AffectiveComputing/Brady_with_LSTM.py", line 197, in <module>
    predictions = dense_layers(conv_nets_output)
  File "C:/Users/user/PycharmProjects/AffectiveComputing/Brady_with_LSTM.py", line 162, in dense_layers
    lstm_outputs, final_state = tf.nn.dynamic_rnn(cell=lstm, inputs=lstm_inputs, initial_state=init_state)
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 553, in dynamic_rnn
    dtype=dtype)
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 720, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2623, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2456, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2406, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 705, in _time_step
    (output, new_state) = call_cell()
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 691, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\contrib\rnn\python\ops\core_rnn_cell_impl.py", line 238, in __call__
    c, h = state
  File "C:\Users\user\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 504, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

Any help is much appreciated!

1 Answer:

Answer 0 (score: 1)

The state of an LSTM actually consists of two parts:

  1. The state of the cell(s)
  2. The previous output

This is mentioned for BasicLSTMCell in the docs. This paper has a good explanation of how an LSTM works, which will help you understand why you need to keep two sets of state in an LSTM implementation. The reason the error is thrown is that you need to supply a tuple of tensors for the initial state.

    That said, you have two options:

    1. Supply an initial state consisting of two tensors.
    2. Let the RNN cell generate its own initial state.

    You would usually only do 1 if you need to override the default behavior. In this case you are using the default (zero) initial state, so you can do 2:

      lstm_outputs, final_state = tf.nn.dynamic_rnn(cell=lstm, inputs=lstm_inputs, dtype=tf.float32)
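To see why the state must be a pair of tensors, here is a minimal NumPy sketch of a single LSTM step. The weights are random placeholders and the gate ordering is illustrative (not necessarily BasicLSTMCell's internal layout); the point is the `c, h = state` unpacking, which is exactly the line that raised the TypeError when `state` was a single tensor:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, state, W, b):
    """One LSTM step. `state` is a (c, h) pair: cell state and previous output."""
    c, h = state                           # fails if state is a single tensor
    z = np.concatenate([x, h]) @ W + b     # all four gates in one matmul
    i, g, f, o = np.split(z, 4)            # input, candidate, forget, output
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, (c_new, h_new)           # the new state is again a pair

state_size, input_size = 16, 1024
rng = np.random.default_rng(0)
W = rng.normal(size=(input_size + state_size, 4 * state_size)) * 0.01
b = np.zeros(4 * state_size)

state = (np.zeros(state_size), np.zeros(state_size))   # zero (c, h) init
out, state = lstm_step(rng.normal(size=input_size), state, W, b)
```

This is what `dynamic_rnn` does internally at each time step, which is why a single `tf.Variable` of shape [128, 16] cannot serve as the initial state.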
      