如何向Tensorflow的LSTM / RNN添加单元状态?

时间:2020-04-15 14:43:50

标签: tensorflow tensorflow2.0

我想向Tensorflow LSTM模型添加一个附加的(二进制)单元状态。

因此,我正在尝试自定义以下LSTM步骤功能:

 def step(cell_inputs, cell_states):

  h_tm1 = cell_states[0]  # previous memory state
  c_tm1 = cell_states[1]  # previous carry state

  z = K.dot(cell_inputs, kernel)
  z += K.dot(h_tm1, recurrent_kernel)
  z = K.bias_add(z, bias)

  z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)

  i = recurrent_activation(z0)
  f = recurrent_activation(z1)
  c = f * c_tm1 + i * activation(z2)
  o = recurrent_activation(z3)

  h = o * activation(c)
  return h, [h, c]

此步骤函数将在Keras的RNN后端中使用:

last_output, outputs, new_states = K.rnn(
    step,
    inputs, [init_h, init_c],
    constants=None,
    unroll=False,
    time_major=time_major,
    mask=mask,
    go_backwards=go_backwards,
    input_length=sequence_lengths
    if sequence_lengths is not None else timesteps)

我的附加单元状态b在每一层和每个时间步都设置为0或1。我必须能够访问b(layer-1,timestep)b(layer,timestep-1)。 如何实现呢?我是否必须将单元格状态b添加到cell_inputs变量中才能读出b(layer-1,timestep)?而且我是否必须将b添加到cell_states变量中以便我可以读出b(layer,timestep-1)

非常感谢您的帮助!

0 个答案:

没有答案