Implementing Luong Attention in PyTorch

Date: 2018-05-28 18:41:20

Tags: pytorch attention-model seq2seq

I am trying to implement the attention described in Luong et al. 2015 in PyTorch myself, but I could not get it to work. Below is my code; I am only interested in the "general" attention case for now. I wonder whether I am making some obvious mistake. It runs, but doesn't seem to learn.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        # s: softmax
        self.Ws = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

        # [0] remove the dimension of directions x layers for now
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)  # eq. 7/8
        context = torch.mm(attn_weights, encoder_outputs)

        # hc: [hidden: context]
        out_hc = torch.tanh(self.Whc(torch.cat([hidden[0], context], dim=1)))  # eq. 5
        output = F.log_softmax(self.Ws(out_hc), dim=1)  # eq. 6

        return output, hidden, attn_weights
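
For reference, the equations that the code comments point to (eq. 5 to 8 of Luong et al. 2015, "general" score) can be written as follows. This is a transcription from memory of the paper's notation, where h_t is the decoder hidden state, \bar{h}_s are the encoder outputs, and c_t is the context vector; check it against the paper before relying on it.

```latex
\begin{align*}
\tilde{h}_t &= \tanh\left(W_c [c_t; h_t]\right) && \text{(eq. 5)} \\
p(y_t \mid y_{<t}, x) &= \operatorname{softmax}\left(W_s \tilde{h}_t\right) && \text{(eq. 6)} \\
a_t(s) &= \frac{\exp\left(\operatorname{score}(h_t, \bar{h}_s)\right)}{\sum_{s'} \exp\left(\operatorname{score}(h_t, \bar{h}_{s'})\right)} && \text{(eq. 7)} \\
\operatorname{score}(h_t, \bar{h}_s) &= h_t^\top W_a \bar{h}_s && \text{(eq. 8, general)} \\
c_t &= \sum_s a_t(s)\, \bar{h}_s
\end{align*}
```

In the code, self.attn plays the role of W_a (applied to the decoder state, and with a bias term because nn.Linear adds one by default), self.Whc corresponds to W_c, and self.Ws to W_s. The code concatenates [hidden, context] rather than [c_t; h_t], which only permutes the columns of W_c.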

I have studied two existing attention implementations:

  • The first one isn't the exact attention mechanism I am looking for. A major disadvantage is that its attention depends on the sequence length (self.attn = nn.Linear(self.hidden_size * 2, self.max_length)), which could be expensive for long sequences.
  • The second one is closer to what the paper describes, but still not the same because there is no tanh. Besides, it is really slow after updating it to the latest version of PyTorch (ref). Also, I don't know why it takes the last context (ref).
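
To make the first point concrete, here is a minimal sketch comparing the two layer shapes. The sizes (hidden_size = 8, max_length = 10) and the helper name general_attention are made up for illustration; the bias is switched off only to match the plain bilinear form of the "general" score. The score layer is hidden_size x hidden_size, so the same weights can attend over a source of any length, while a layer that outputs max_length scores has its width tied to the longest sequence.

```python
import torch
import torch.nn as nn

hidden_size = 8   # hypothetical size, for illustration only
max_length = 10   # hypothetical size, for illustration only

# "General" score layer: its shape does not involve the source length.
attn_general = nn.Linear(hidden_size, hidden_size, bias=False)

def general_attention(decoder_hidden, encoder_outputs):
    # decoder_hidden: (1, hidden_size); encoder_outputs: (src_len, hidden_size)
    scores = torch.mm(attn_general(decoder_hidden), encoder_outputs.t())  # (1, src_len)
    weights = torch.softmax(scores, dim=1)                                # (1, src_len)
    context = torch.mm(weights, encoder_outputs)                          # (1, hidden_size)
    return context, weights

# The same layer handles a short and a long source sequence.
for src_len in (3, 50):
    context, weights = general_attention(
        torch.randn(1, hidden_size), torch.randn(src_len, hidden_size))
    print(src_len, tuple(weights.shape), tuple(context.shape))

# A layer that emits one score per position has max_length output rows.
attn_fixed = nn.Linear(hidden_size * 2, max_length)
print(tuple(attn_general.weight.shape), tuple(attn_fixed.weight.shape))
```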

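1 answer:

Answer 0 (score: 2)

This version works, and it follows the definition of Luong attention (general). The main difference from the code in the question is the separation of embedding_size and hidden_size, which, judging from my experiments, appears to matter for training. Previously I made them both the same size (256), which caused trouble for learning: the network seemed to learn only half of the sequence.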

import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderRNN(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size,
                 num_layers=1, bidirectional=False, batch_size=1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.batch_size = batch_size

        self.embedding = nn.Embedding(input_size, embedding_size)

        self.gru = nn.GRU(embedding_size, hidden_size, num_layers,
                          bidirectional=bidirectional)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        directions = 2 if self.bidirectional else 1
        return torch.zeros(
            self.num_layers * directions,
            self.batch_size,
            self.hidden_size)

class AttnDecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0):
        super(AttnDecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(output_size, embedding_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(embedding_size, hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(hidden_size * 2, hidden_size)
        # s: softmax
        self.Ws = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)
        context = torch.mm(attn_weights, encoder_outputs)

        # hc: [hidden: context]
        hc = torch.cat([hidden[0], context], dim=1)
        out_hc = torch.tanh(self.Whc(hc))
        output = F.log_softmax(self.Ws(out_hc), dim=1)

        return output, hidden, attn_weights
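
A quick way to sanity-check a decoder like this is to run a single step on random tensors and look at the shapes. The sketch below is a standalone version of one decoding step, written as a function so it runs on its own; it is not the answer's class, the sizes are made up, and the encoder outputs are random stand-ins rather than the output of a trained encoder. Dropout is left out because the answer uses dropout_p=0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration; note embedding_size != hidden_size.
embedding_size, hidden_size, vocab_size, src_len = 5, 8, 11, 4

embedding = nn.Embedding(vocab_size, embedding_size)
gru = nn.GRU(embedding_size, hidden_size)
attn = nn.Linear(hidden_size, hidden_size)   # W_a in the "general" score
Whc = nn.Linear(hidden_size * 2, hidden_size)
Ws = nn.Linear(hidden_size, vocab_size)

def decode_step(token, hidden, encoder_outputs):
    embedded = embedding(token).view(1, 1, -1)                # (1, 1, embedding_size)
    _, hidden = gru(embedded, hidden)                         # (1, 1, hidden_size)
    scores = torch.mm(attn(hidden)[0], encoder_outputs.t())   # (1, src_len)
    attn_weights = F.softmax(scores, dim=1)                   # (1, src_len)
    context = torch.mm(attn_weights, encoder_outputs)         # (1, hidden_size)
    hc = torch.cat([hidden[0], context], dim=1)               # (1, 2 * hidden_size)
    out_hc = torch.tanh(Whc(hc))                              # (1, hidden_size)
    log_probs = F.log_softmax(Ws(out_hc), dim=1)              # (1, vocab_size)
    return log_probs, hidden, attn_weights

encoder_outputs = torch.randn(src_len, hidden_size)  # stand-in for real encoder outputs
hidden = torch.zeros(1, 1, hidden_size)
token = torch.tensor([[2]])                          # arbitrary token id below vocab_size

log_probs, hidden, attn_weights = decode_step(token, hidden, encoder_outputs)
print(tuple(log_probs.shape), tuple(hidden.shape), tuple(attn_weights.shape))
```

The attention weights should sum to 1 over the source positions, and exponentiating the log-probabilities should give a distribution over the vocabulary. If either check fails, the softmax is probably being taken over the wrong dimension.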