如何将pytorch lstmcell转换为keras lstm或lstmcell

时间:2018-10-04 13:23:19

标签: keras

这是pytorch lstmcell的示例:

rnn = nn.LSTMCell(10, 20)
input = torch.randn(6, 3, 10)
hx = torch.randn(3, 20)
cx = torch.randn(3, 20)
output = []
hx, cx = rnn(input[0], (hx, cx))
output.append(hx)

不确定如何将其转换为keras lstm / lstmcell

1 个答案:

答案 0 :(得分:0)

原始Pytorch代码:
self.att_lstm = nn.LSTMCell(1536,512)
h_att,c_att = self.att_lstm(att_lstm_input,(state [0] [0],state [1] [0]))
状态[0] [0],状态[1] [0]是张量(10,512) 我在喀拉拉邦尝试过的东西: inputs = Input(shape=(10, 1536))
lstm, h_att, c_att = LSTM(units=512, input_shape=(10,1536), name='core.att_lstm', return_state=True)(inputs)
所以我不确定是否正确。