import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
# Simplified CLIP-like model components
class SimpleImageEncoder(nn.Module):
    """Two-stage CNN that maps a (B, 3, 32, 32) image batch to 256-d features."""

    def __init__(self):
        super().__init__()
        # Conv stack: 3->32->64 channels; each stage is conv + ReLU + 2x2
        # max-pool + dropout, so a 32x32 input ends up as 64 x 8 x 8.
        layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.3),
        ]
        self.conv = nn.Sequential(*layers)
        # Project the flattened 64*8*8 activation map into the shared 256-d space.
        self.fc = nn.Linear(64 * 8 * 8, 256)

    def forward(self, x):
        """Encode images: conv stack, flatten per-sample, linear projection."""
        feature_maps = self.conv(x)
        flat = feature_maps.flatten(1)
        return self.fc(flat)
class SimpleTextEncoder(nn.Module):
    """Bag-of-embeddings text encoder: mean token embedding -> MLP -> 256-d."""

    def __init__(self, vocab_size=1000, embed_dim=256):
        super().__init__()
        # Token-id -> embed_dim lookup table.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Single hidden projection with dropout for regularization.
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        """Encode a (B, T) batch of token ids to (B, 256) features."""
        # Order-insensitive pooling: average the token embeddings over time.
        pooled = self.embedding(x).mean(dim=1)
        return self.fc(pooled)
class SimpleCLIP(nn.Module):
    """Dual-encoder model: pairwise cosine similarities between image and text batches."""

    def __init__(self):
        super().__init__()
        self.image_encoder = SimpleImageEncoder()
        self.text_encoder = SimpleTextEncoder()

    def forward(self, image, text):
        """Return a (B_img, B_txt) matrix of cosine similarities."""
        img = self.image_encoder(image)
        txt = self.text_encoder(text)
        # L2-normalize each row so the dot products below are cosine similarities.
        img = img / img.norm(dim=1, keepdim=True)
        txt = txt / txt.norm(dim=1, keepdim=True)
        return img @ txt.t()
# Dummy dataset and dataloader (replace with real data in practice)
# Only the images are consumed downstream; CIFAR10 class labels are discarded.
transform = transforms.Compose([transforms.ToTensor()])
# NOTE: download=True fetches the dataset into ./data on first run (network I/O).
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
# Dummy text inputs (random integers as token ids)
def generate_dummy_text(batch_size, seq_len=10, vocab_size=1000):
    """Return a (batch_size, seq_len) LongTensor of uniform token ids in [0, vocab_size)."""
    shape = (batch_size, seq_len)
    return torch.randint(low=0, high=vocab_size, size=shape)
# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = SimpleCLIP().to(device)

# Adam with a modest learning rate; weight_decay provides L2 regularization.
optimizer = optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
# Training loop with early stopping on (simulated) validation accuracy.
best_val_acc = 0
patience = 3          # epochs without improvement tolerated before stopping
trigger_times = 0     # consecutive epochs with no validation improvement

for epoch in range(20):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for images, _ in dataloader:
        images = images.to(device)
        batch_size = images.size(0)
        # Random token ids stand in for real captions; replace in practice.
        texts = generate_dummy_text(batch_size).to(device)

        optimizer.zero_grad()
        logits = model(images, texts)  # (batch, batch) image-text similarities
        # Labels: diagonal elements are the correct (matching) pairs.
        labels = torch.arange(batch_size).to(device)
        # FIX: CLIP's contrastive objective is symmetric — classify each image
        # over the texts AND each text over the images, then average the two
        # cross-entropies. The original used only the image->text direction.
        loss = 0.5 * (criterion(logits, labels) + criterion(logits.t(), labels))
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size   # sum so the epoch mean is exact
        preds = logits.argmax(dim=1)             # image->text retrieval accuracy
        correct += (preds == labels).sum().item()
        total += batch_size

    train_loss = total_loss / total
    train_acc = correct / total * 100
    # NOTE(review): validation is simulated from the training accuracy here;
    # swap in a real held-out split before trusting early stopping.
    val_acc = train_acc - 10  # Simulate validation accuracy lower by 10%
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.3f}, Train Acc={train_acc:.1f}%, Val Acc={val_acc:.1f}%")

    # Early stopping: reset patience on improvement, otherwise count up.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        trigger_times = 0
        # TODO(review): checkpoint model.state_dict() here so the best weights
        # can be restored after stopping.
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print("Early stopping triggered")
            break