使用LSTMCell转发时,Gluon MXNet错误

时间:2018-08-01 07:30:48

标签: deep-learning lstm mxnet

使用mxnet 1.1, 当我尝试在以下网络上运行net(data)时:

net = gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(gluon.nn.Embedding(input_dim=MAX_EVENT_INDEX + 1, output_dim=EMBEDDING_VECTOR_LENGTH))
        net.add(gluon.nn.Conv1D(channels=conv1D_filters, kernel_size=conv1D_kernel_size, activation='relu'))
        net.add(gluon.nn.MaxPool1D(pool_size=max_pool_size, strides=2))
        net.add(gluon.rnn.LSTMCell(100))
        net.add(gluon.rnn.DropoutCell(dropout_rate))
        net.add(gluon.rnn.LSTMCell(100))
        net.add(gluon.rnn.DropoutCell(dropout_rate))
        net.add(gluon.rnn.LSTMCell(100))
        net.add(gluon.rnn.DropoutCell(dropout_rate))
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(1, activation="sigmoid"))
    net.hybridize()

错误:forward()缺少1个必需的位置参数:“状态”

当我将gluon.nn.Sequential()net.add(gluon.rnn.LSTM(100, dropout=dropout_rate))一起使用时,一切正常

谢谢

1 个答案:

答案 0 :(得分:0)

如果您研究this thread的实现,您会发现hybrid_forward需要明确的states参数。 LSTM类使用其LSTMCell不需要states参数(可以为None)。因此,从一个切换到另一个肯定会为您提供帮助。

LSTM类超过LSTMCell。它实际上在内部使用LSTMCell,但它还添加了额外的功能。例如,您可以指定LSTM为多层或双向,其中LSTMCell本质上只是一堆与LSTM相关的公式,用于计算门c和h。