Saving model state_dict in PyTorch - ML Experiment: Train & Evaluate

Experiment - Saving model state_dict
Problem: You have trained a PyTorch model but did not save its state_dict. If the program stops, you lose the trained weights.
Current Metrics: Training accuracy: 90%, Validation accuracy: 88%
Issue: Without a saved model state, you cannot reuse the trained model without retraining it.
Your Task
Save the model's state_dict after training so you can load it later and reuse the trained weights.
Use PyTorch's recommended methods for saving and loading a state_dict.
Do not change the model architecture or training code.
Solution
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

# Create model instance
model = SimpleNet()

# Dummy data
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Training loop
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

# Save the model state_dict
torch.save(model.state_dict(), 'model_state.pth')

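# To load the model later (e.g. in a new session):
# model = SimpleNet()  # the class definition must be available again
# model.load_state_dict(torch.load('model_state.pth', weights_only=True))
# model.eval()  # switch to evaluation mode before inference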
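To check that the save/load round trip really preserves the trained weights, here is a minimal self-contained sketch. It assumes PyTorch 1.13 or later (for the weights_only argument to torch.load), writes to a temporary directory instead of the working directory, and uses the illustrative names trained and restored, which are not part of the original solution. It reloads the file into a fresh SimpleNet and compares both the parameters and the outputs.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Same architecture as SimpleNet above; a state_dict can only be loaded
# into a model whose parameter names and shapes match.
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
trained = SimpleNet()  # hypothetical stand-in for the trained model

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state.pth")
    torch.save(trained.state_dict(), path)

    # A fresh instance starts with different random weights;
    # load_state_dict copies the saved tensors into it.
    restored = SimpleNet()
    state = torch.load(path, map_location="cpu", weights_only=True)
    result = restored.load_state_dict(state)  # reports missing/unexpected keys

trained.eval()
restored.eval()

# Every parameter tensor should match exactly after loading.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)

# Same weights and same input should give identical outputs.
x = torch.randn(4, 10)
with torch.no_grad():
    outputs_match = torch.equal(trained(x), restored(x))

print("weights match:", weights_match)
print("outputs match:", outputs_match)
```

Passing map_location="cpu" lets a file saved on a GPU machine load on a CPU-only one; you can drop it if you always load on the same device.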
Added torch.save(model.state_dict(), 'model_state.pth') after training to save the weights.
Included commented-out example code that loads the saved state_dict back into a fresh SimpleNet instance.
Results Interpretation

Before: No model weights were saved, so the trained model is lost when the program ends.

After: The model weights are saved in the 'model_state.pth' file, so the model can be reused without retraining.

Saving the model's state_dict lets you keep the trained weights and load them later, saving time and resources.
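To see exactly what "the trained weights" are, you can print the state_dict. The sketch below repeats the SimpleNet definition from the solution and shows that a state_dict is an ordered mapping from parameter names to tensors, with no class code or training loop inside, which is why the class definition is needed again when loading.

```python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


model = SimpleNet()
state = model.state_dict()

# Keys are "<attribute name>.<parameter name>"; values are plain tensors.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# fc.weight (2, 10)
# fc.bias (2,)

# 2 * 10 weights + 2 biases
num_values = sum(t.numel() for t in state.values())
print("total values:", num_values)  # total values: 22
```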
Bonus Experiment
Try saving and loading the entire model instead of just the state_dict. Compare file sizes and loading flexibility.
💡 Hint
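Use torch.save(model, 'model_full.pth') to save the full model and torch.load('model_full.pth', weights_only=False) to load it. Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle a full model object. The SimpleNet class must also be importable when loading, because the file stores a pickled reference to the class rather than its code.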
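One way to run the bonus comparison is sketched below. It is an illustration rather than the course's reference solution: it assumes PyTorch 2.x, that the file is run as a script (so SimpleNet lives in __main__ and can be found again when unpickling), and it uses a temporary directory. Exact byte counts depend on the PyTorch version, so they are printed rather than stated here.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


model = SimpleNet()

with tempfile.TemporaryDirectory() as tmp:
    state_path = os.path.join(tmp, "model_state.pth")
    full_path = os.path.join(tmp, "model_full.pth")

    torch.save(model.state_dict(), state_path)  # named tensors only
    torch.save(model, full_path)  # pickles the module object itself

    state_size = os.path.getsize(state_path)
    full_size = os.path.getsize(full_path)

    # Unpickling a whole module needs weights_only=False, and the
    # SimpleNet class must be importable under the same module path.
    full_model = torch.load(full_path, weights_only=False)

print(f"state_dict file: {state_size} bytes")
print(f"full model file: {full_size} bytes")
print("loaded type:", type(full_model).__name__)
```

The full-model file also stores the pickled module structure and a reference to SimpleNet, so it only loads where that class is importable under the same name. The state_dict file holds just the named tensors and can be loaded into any model whose parameter names and shapes match, which is the flexibility the bonus question asks you to compare.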