
Encoder-decoder with attention in NLP

Introduction

Encoder-decoder with attention lets a model focus on the most relevant parts of the input when making each prediction. It improves tasks such as machine translation by letting the decoder attend to the relevant source words at every step.

Translating a sentence from one language to another.
Summarizing a long paragraph into a short summary.
Answering questions based on a given text.
Generating captions for images by focusing on image parts.
Speech recognition where attention helps focus on sounds.
Syntax
Python
class Encoder(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, x):
        ...

class Attention(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, encoder_outputs, decoder_hidden):
        ...

class Decoder(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, input, hidden, encoder_outputs):
        attention_weights = self.attention(encoder_outputs, hidden)
        context = attention_weights @ encoder_outputs
        ...
        return output, hidden, attention_weights

The encoder processes the input sequence into a set of outputs.

The attention layer calculates weights that tell the decoder which encoder outputs to focus on at each step.

The decoder combines the resulting context vector with its current input token to predict the next output.

Examples
This computes attention weights by comparing the decoder hidden state with each encoder output (dot-product attention), then normalizing with softmax.
Python
attention_weights = torch.softmax(torch.bmm(decoder_hidden.unsqueeze(1), encoder_outputs.transpose(1,2)), dim=-1)
The context vector is a weighted sum of the encoder outputs, using the attention weights.
Python
context = torch.bmm(attention_weights, encoder_outputs)
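The two snippets above can be combined into a runnable sketch using toy tensors. The shapes and random values here are assumptions for illustration, not outputs of a real model:

```python
import torch

torch.manual_seed(0)
batch, src_len, hid = 2, 5, 16

# Stand-ins for real model activations (assumed shapes)
encoder_outputs = torch.randn(batch, src_len, hid)  # (batch, src_len, hid)
decoder_hidden = torch.randn(batch, hid)            # (batch, hid)

# Dot-product scores: compare the hidden state against every encoder output
scores = torch.bmm(decoder_hidden.unsqueeze(1), encoder_outputs.transpose(1, 2))
attention_weights = torch.softmax(scores, dim=-1)   # (batch, 1, src_len)

# Context vector: weighted sum of encoder outputs
context = torch.bmm(attention_weights, encoder_outputs)  # (batch, 1, hid)

print(attention_weights.shape, context.shape)
```

Each row of `attention_weights` sums to 1, so the context vector is a convex combination of the encoder outputs.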
Sample Model

This code builds a simple encoder-decoder model with attention for sequence tasks. It runs one training step on toy data and prints the total loss.

Python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, batch_first=True)
    def forward(self, src):
        embedded = self.embedding(src)
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden

# Attention Layer
class Attention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.attn = nn.Linear(hid_dim * 2, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)
    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[1]
        hidden = hidden.permute(1, 0, 2)  # (batch, 1, hid_dim)
        hidden = hidden.repeat(1, src_len, 1)  # (batch, src_len, hid_dim)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)
        return torch.softmax(attention, dim=1)

# Decoder with Attention
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, attention):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(hid_dim + emb_dim, hid_dim, batch_first=True)
        self.fc_out = nn.Linear(hid_dim * 2 + emb_dim, output_dim)
        self.attention = attention
    def forward(self, input, hidden, encoder_outputs):
        input = input.unsqueeze(1)  # (batch, 1)
        embedded = self.embedding(input)  # (batch, 1, emb_dim)
        attn_weights = self.attention(hidden, encoder_outputs)  # (batch, src_len)
        attn_weights = attn_weights.unsqueeze(1)  # (batch, 1, src_len)
        context = torch.bmm(attn_weights, encoder_outputs)  # (batch, 1, hid_dim)
        rnn_input = torch.cat((embedded, context), dim=2)  # (batch, 1, emb_dim + hid_dim)
        output, hidden = self.rnn(rnn_input, hidden)  # output: (batch,1,hid_dim)
        output = output.squeeze(1)  # (batch, hid_dim)
        context = context.squeeze(1)  # (batch, hid_dim)
        embedded = embedded.squeeze(1)  # (batch, emb_dim)
        pred_input = torch.cat((output, context, embedded), dim=1)  # (batch, hid_dim*2 + emb_dim)
        prediction = self.fc_out(pred_input)  # (batch, output_dim)
        return prediction, hidden, attn_weights.squeeze(1)

# Toy data and training loop
INPUT_DIM = 10
OUTPUT_DIM = 10
EMB_DIM = 8
HID_DIM = 16

encoder = Encoder(INPUT_DIM, EMB_DIM, HID_DIM)
attention = Attention(HID_DIM)
decoder = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM, attention)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))

# Example input: batch size 2, sequence length 5
src = torch.tensor([[1,2,3,4,5],[5,4,3,2,1]])
tgt = torch.tensor([[1,2,3,4,5],[5,4,3,2,1]])

encoder_outputs, hidden = encoder(src)
input_decoder = tgt[:,0]  # first token
loss_total = 0

for t in range(1, tgt.shape[1]):
    output, hidden, attn_weights = decoder(input_decoder, hidden, encoder_outputs)
    loss_total = loss_total + criterion(output, tgt[:, t])
    input_decoder = tgt[:, t]  # teacher forcing

# One optimization step on the accumulated loss
optimizer.zero_grad()
loss_total.backward()
optimizer.step()

print(f"Total loss: {loss_total.item():.4f}")
Important Notes

Attention helps the decoder look at different parts of the input for each output word.

Teacher forcing means using the true previous word as input during training.
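A minimal sketch of teacher forcing versus free running, using a hypothetical `decoder_step` that returns random logits (it is not part of the model above, just a placeholder):

```python
import torch

torch.manual_seed(0)

# Hypothetical one-step decoder: returns logits over a vocab of 10 tokens
def decoder_step(prev_token):
    return torch.randn(10)

tgt = torch.tensor([1, 2, 3, 4, 5])  # ground-truth target sequence

# Teacher forcing: feed the TRUE previous token at every step
inp = tgt[0]
for t in range(1, len(tgt)):
    _logits = decoder_step(inp)
    inp = tgt[t]               # ignore the prediction, use the ground truth
last_teacher_forced = inp

# Free running (no teacher forcing): feed the model's own prediction back in
inp = tgt[0]
for t in range(1, len(tgt)):
    logits = decoder_step(inp)
    inp = logits.argmax()      # use the model's best guess instead
```

Teacher forcing stabilizes training, while free running matches how the model is actually used at inference time.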

The batch dimension must match across src, tgt, and the hidden states passed between encoder and decoder; sequences in a batch are typically padded to a common length.

Summary

Encoder-decoder with attention improves sequence tasks by focusing on important input parts.

Attention weights show where the model looks when predicting each output.
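A toy illustration of reading attention weights, with made-up values and a hypothetical French-to-English example:

```python
import torch

# Toy attention matrix: rows = output tokens, columns = input tokens (assumed values)
attn = torch.tensor([[0.70, 0.10, 0.10, 0.10],
                     [0.10, 0.80, 0.05, 0.05],
                     [0.05, 0.05, 0.10, 0.80]])
src_tokens = ["le", "chat", "est", "noir"]   # hypothetical source sentence
tgt_tokens = ["the", "cat", "black"]         # hypothetical output tokens

for i, word in enumerate(tgt_tokens):
    focus = src_tokens[attn[i].argmax().item()]
    print(f"'{word}' attends most to '{focus}'")
```

Plotting such a matrix as a heatmap is a common way to inspect what a trained model has learned to align.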

This method is widely used in translation, summarization, and more.