import torch
import torch.nn as nn
import math
class LearnablePositionalEncoding(nn.Module):
    """Adds a learned positional embedding to the input embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # One learnable vector per position, broadcast over the batch dimension.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.uniform_(self.pos_embedding, -0.1, 0.1)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        seq_len = x.size(1)
        x = x + self.pos_embedding[:, :seq_len, :]
        return x
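# Minimal sanity check (illustrative only): confirm that the positional
# embedding is a learnable parameter that receives gradients. The tensor
# sizes below are arbitrary and chosen purely for this demonstration.
_pe = LearnablePositionalEncoding(d_model=16, max_len=32)
_out = _pe(torch.randn(2, 8, 16))            # (batch=2, seq_len=8, d_model=16)
_out.sum().backward()
print(_pe.pos_embedding.grad is not None)     # True: positions are trained with the model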
# Example usage in a transformer embedding layer
class TransformerEmbedding(nn.Module):
    """Token embedding followed by the learnable positional encoding."""

    def __init__(self, vocab_size, d_model, max_len=5000):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = LearnablePositionalEncoding(d_model, max_len)

    def forward(self, x):
        x = self.token_embedding(x)       # (batch_size, seq_len, d_model)
        x = self.positional_encoding(x)
        return x
# Quick shape check with a dummy batch (a hedged training sketch follows below)
vocab_size = 10000
d_model = 512
max_len = 100
embedding_layer = TransformerEmbedding(vocab_size, d_model, max_len)
# Assume input batch of token indices
batch_size = 32
seq_len = 50
inputs = torch.randint(0, vocab_size, (batch_size, seq_len))
outputs = embedding_layer(inputs)
print(outputs.shape) # Should be (32, 50, 512)
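# The block below is a minimal, illustrative sketch of how this embedding
# layer might be combined with a standard nn.TransformerEncoder and trained
# on a toy sequence-classification objective. The hyperparameters, the
# mean-pooling classification head, and the random labels are assumptions
# made purely for demonstration; they are not part of the original model.
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
classifier = nn.Linear(d_model, 2)  # hypothetical 2-class head

params = (
    list(embedding_layer.parameters())
    + list(encoder.parameters())
    + list(classifier.parameters())
)
optimizer = torch.optim.Adam(params, lr=1e-4)
criterion = nn.CrossEntropyLoss()

labels = torch.randint(0, 2, (batch_size,))   # random labels, for illustration only
logits = classifier(encoder(embedding_layer(inputs)).mean(dim=1))  # mean-pool over sequence
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())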
# After integrating this positional encoding in the full transformer model and training,
# validation accuracy improved as shown below.