Keras关注层超过LSTM

时间:2016-04-23 14:55:26

标签: python keras lstm

我正在使用keras 1.0.1我尝试在LSTM上添加注意层。这是我到目前为止所做的,但它不起作用。

input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = TimeDistributed(Dense(1)(lstm))
att = Reshape((-1, input_length))(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
merge = Merge([att, lstm], "mul")
hid = Merge("sum")(merge)

last = Dense(self.HID_DIM, activation="relu")(hid)

网络应在输入序列上应用LSTM。然后,应将LSTM的每个隐藏状态输入到完全连接的层,在该层上应用Softmax。 softmax针对每个隐藏维度进行复制,并按元素方式乘以LSTM隐藏状态。然后应对得到的矢量求平均值。

编辑:这是编译,但我不确定它是否符合我的想法。

input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = TimeDistributed(Dense(1))(lstm)
att = Flatten()(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
att = Permute((2,1))(att)
mer = merge([att, lstm], "mul")
hid = AveragePooling1D(pool_length=input_length)(mer)
hid = Flatten()(hid)

2 个答案:

答案 0 :(得分:2)

Here是Keras的Attention LSTM的实现,以及instantiation的示例。不过,我自己也没试过。

答案 1 :(得分:0)

您共享的第一段代码不正确。第二段代码看起来很正确,除了一件事。不要使用TimeDistributed,因为权重将相同。使用具有非线性激活的常规密集层。


    input_ = Input(shape=(input_length, input_dim))
    lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
    att = Dense(1, activation='tanh')(lstm_out )
    att = Flatten()(att)
    att = Activation(activation="softmax")(att)
    att = RepeatVector(self.HID_DIM)(att)
    att = Permute((2,1))(att)
    mer = merge([att, lstm], "mul")

现在,您有了体重调整状态。如何使用它取决于您。我已经看到了大多数版本的Attention,只需将它们加到时间轴上,然后将输出用作上下文即可。