import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into query/key/value spaces with three linear maps,
    computes softmax(Q K^T / sqrt(d)) V, and applies dropout to the
    attention weights (standard regularization of attention maps).
    """

    def __init__(self, embed_size, dropout=0.1):
        super().__init__()
        self.embed_size = embed_size
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.dropout = nn.Dropout(dropout)
        # Plain Python float instead of a CPU FloatTensor: the old tensor
        # version was not a registered buffer, so it did not follow
        # model.to(device)/dtype and forced a per-forward .to(x.device) copy.
        self.scale = embed_size ** 0.5

    def forward(self, x, mask=None):
        """Compute self-attention over ``x``.

        Args:
            x: input of shape (batch_size, seq_len, embed_size).
            mask: optional tensor broadcastable to (batch_size, seq_len,
                seq_len); positions where ``mask == 0`` are excluded from
                attention.

        Returns:
            Tuple ``(out, attention)`` where ``out`` has shape
            (batch_size, seq_len, embed_size) and ``attention`` has shape
            (batch_size, seq_len, seq_len).
        """
        Q = self.query(x)  # (batch_size, seq_len, embed_size)
        K = self.key(x)    # (batch_size, seq_len, embed_size)
        V = self.value(x)  # (batch_size, seq_len, embed_size)
        # Scaled dot-product scores: (batch_size, seq_len, seq_len).
        scores = torch.bmm(Q, K.transpose(1, 2)) / self.scale
        if mask is not None:
            # -inf zeroes masked positions after softmax.
            # NOTE(review): a row that is entirely masked yields NaN from
            # softmax — callers must ensure each query attends to >= 1 key.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)  # attention weights
        attention = self.dropout(attention)        # apply dropout
        out = torch.bmm(attention, V)              # weighted sum of values
        return out, attention
# Example usage in a simple model
class SimpleModel(nn.Module):
    """Sequence classifier: self-attention encoder, mean pooling, linear head."""

    def __init__(self, embed_size, num_classes):
        super().__init__()
        self.attention = SelfAttention(embed_size)
        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x, mask=None):
        """Return ``(logits, attention_weights)`` for input ``x``.

        ``x`` is attended, averaged over the sequence dimension, and passed
        through the final linear layer.
        """
        attended, weights = self.attention(x, mask)
        # Collapse the sequence dimension by averaging the attended vectors.
        summary = attended.mean(dim=1)
        logits = self.fc(summary)
        return logits, weights
# Training loop snippet (simplified)
# Assume X_train, y_train, X_val, y_val are tensors
import torch.optim as optim
embed_size = 64
num_classes = 2
model = SimpleModel(embed_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    # --- training step (single full-batch update per epoch) ---
    model.train()
    optimizer.zero_grad()
    outputs, _ = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # --- validation (no gradients needed) ---
    model.eval()
    with torch.no_grad():
        val_outputs, _ = model(X_val)
        val_loss = criterion(val_outputs, y_val)
        predictions = torch.argmax(val_outputs, dim=1)
        val_acc = predictions.eq(y_val).float().mean()

    print(f"Epoch {epoch+1}: Train Loss={loss.item():.3f}, Val Loss={val_loss.item():.3f}, Val Acc={val_acc.item()*100:.1f}%")