
Loading model state_dict in PyTorch - ML Experiment: Train & Evaluate

Experiment - Loading model state_dict
Problem: You have trained a PyTorch model and saved its state_dict. Now you want to load this saved state_dict into a new model instance to continue training or make predictions.
Current Metrics: N/A (the focus is on loading the model weights correctly)
Issue: If the state_dict is not loaded correctly, the new model either keeps its random initial weights and gives poor predictions, or load_state_dict raises an error because the saved keys or tensor shapes do not match the model.
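To make these failure modes concrete, the sketch below uses a hypothetical `OtherModel` that has the same layer shape as `SimpleModel` but names its layer `fc` instead of `linear`, so its state_dict keys differ. With the default `strict=True`, `load_state_dict` raises a `RuntimeError`; with `strict=False` it copies nothing for the mismatched keys and returns them in `missing_keys` and `unexpected_keys`, which leaves the model at its random initialization.

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

# Hypothetical model: same shapes, different attribute name,
# so its keys are "fc.weight"/"fc.bias" rather than "linear.weight"/"linear.bias".
class OtherModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc(x)

state_dict = SimpleModel().state_dict()

# Default strict=True: mismatched keys raise a RuntimeError.
strict_failed = False
try:
    OtherModel().load_state_dict(state_dict)
except RuntimeError as err:
    strict_failed = True
    print("strict load failed:", str(err).splitlines()[0])

# strict=False: no error, but the mismatched keys are reported and nothing is copied.
result = OtherModel().load_state_dict(state_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```

Printing these two lists after a non-strict load is a quick way to tell whether any weights were actually restored.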
Your Task
Load the saved state_dict into a new model instance correctly and verify that the model produces the same output as the original model for the same input.
Do not retrain the model.
Use the same model architecture for loading the state_dict.
Use PyTorch's recommended methods for saving and loading.
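Because the goal includes continuing training, a common pattern in the PyTorch tutorials is to save a dictionary that bundles the model's state_dict with the optimizer's state_dict and any bookkeeping. The sketch below assumes SGD with momentum, a hypothetical `epoch` field, and an in-memory `io.BytesIO` buffer standing in for a checkpoint file. The `weights_only=True` argument requires PyTorch 1.13 or newer.

```python
import io
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

torch.manual_seed(0)
model = SimpleModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One dummy training step so the optimizer holds momentum buffers worth saving.
loss = model(torch.randn(4, 5)).pow(2).mean()
loss.backward()
optimizer.step()

# Bundle everything needed to resume into one checkpoint dictionary.
# "epoch" is a hypothetical bookkeeping field used only for illustration.
buffer = io.BytesIO()  # stands in for a file path such as 'checkpoint.pth'
torch.save({
    "epoch": 3,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, buffer)

# Later: rebuild the same architecture and optimizer, then restore both states.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu", weights_only=True)
resumed_model = SimpleModel()
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.1, momentum=0.9)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1

resumed_model.train()  # switch back to training mode before continuing
print("resuming at epoch", start_epoch)
```

Restoring the optimizer state matters for optimizers with internal buffers (momentum here). Without it, training resumes with fresh buffers even though the weights are correct.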
Solution
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

# Create a model and simulate training by setting weights manually (for reproducibility)
original_model = SimpleModel()
with torch.no_grad():
    original_model.linear.weight.fill_(1.0)
    original_model.linear.bias.fill_(0.5)

# Save only the state_dict (learned parameters and buffers), not the whole model object
torch.save(original_model.state_dict(), 'model_state.pth')

# Create a new model instance with the same architecture
loaded_model = SimpleModel()

# Load the saved state_dict. weights_only=True (PyTorch 1.13+) restricts unpickling to
# tensors and basic containers; map_location='cpu' lets a GPU-saved file load on CPU.
state_dict = torch.load('model_state.pth', map_location='cpu', weights_only=True)
loaded_model.load_state_dict(state_dict)  # strict=True by default: key mismatches raise an error

# Set both models to eval mode (matters for layers such as dropout and batch norm)
original_model.eval()
loaded_model.eval()

# Test input
input_tensor = torch.ones(1, 5)

# Get outputs without tracking gradients
with torch.no_grad():
    original_output = original_model(input_tensor)
    loaded_output = loaded_model(input_tensor)

# Check if outputs are the same
outputs_match = torch.allclose(original_output, loaded_output)

print(f"Outputs match: {outputs_match}")
print(f"Original output: {original_output}")
print(f"Loaded output: {loaded_output}")
Defined a simple model architecture.
Simulated training by manually setting weights and bias for reproducibility.
Saved the model's state_dict to a file.
Created a new model instance with the same architecture.
Loaded the saved state_dict into the new model.
Verified that outputs from original and loaded models match for the same input.
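Matching outputs on a single input is a useful smoke test, but comparing the two state_dicts entry by entry checks every stored tensor directly. The sketch below assumes the same `SimpleModel`, uses an in-memory buffer instead of `model_state.pth`, and introduces a hypothetical helper named `state_dicts_equal`.

```python
import io
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

def state_dicts_equal(a, b):
    """Hypothetical helper: True if both dicts share keys and every tensor is identical."""
    if a.keys() != b.keys():
        return False
    return all(torch.equal(a[key], b[key]) for key in a)

original_model = SimpleModel()

buffer = io.BytesIO()  # stands in for 'model_state.pth'
torch.save(original_model.state_dict(), buffer)
buffer.seek(0)

loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load(buffer, weights_only=True))

print(state_dicts_equal(original_model.state_dict(), loaded_model.state_dict()))  # True
```

`torch.equal` demands exact equality, which is appropriate here because loading copies the saved tensors bit for bit. `torch.allclose` is the better tool for comparing computed outputs.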
Results Interpretation

Before loading the state_dict, the new model has random weights and produces different outputs from the original model.

After loading the state_dict, the new model produces the same outputs as the original model for the same input.

Loading a saved state_dict correctly restores the model's learned parameters (and any registered buffers, such as batch-norm running statistics), so the model produces consistent outputs without retraining.
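The before/after contrast can be checked directly. This sketch reuses the solution's manually set weights and compares a fresh model's output with the original both before and after loading. An in-memory buffer stands in for `model_state.pth`. With all weights at 1.0 and the bias at 0.5, the original model outputs 5.5 per unit on an all-ones input. Default `nn.Linear` initialization draws weights and bias from roughly ±0.447 for 5 inputs, so a fresh model's outputs stay below about 2.7 in magnitude and cannot match.

```python
import io
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

original_model = SimpleModel()
with torch.no_grad():
    original_model.linear.weight.fill_(1.0)
    original_model.linear.bias.fill_(0.5)

buffer = io.BytesIO()  # stands in for 'model_state.pth'
torch.save(original_model.state_dict(), buffer)

new_model = SimpleModel()  # freshly initialized, random weights
input_tensor = torch.ones(1, 5)

with torch.no_grad():
    match_before = torch.allclose(original_model(input_tensor), new_model(input_tensor))
    buffer.seek(0)
    new_model.load_state_dict(torch.load(buffer, weights_only=True))
    match_after = torch.allclose(original_model(input_tensor), new_model(input_tensor))

print("before loading:", match_before)  # before loading: False
print("after loading:", match_after)  # after loading: True
```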
Bonus Experiment
Try saving and loading the entire model with torch.save(model, PATH) and torch.load(PATH, weights_only=False), then compare its outputs with those of the state_dict method.
💡 Hint
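Saving the entire model pickles the whole module object, so the model's class definition must still be importable when you load it, and loading requires weights_only=False (the default became True in PyTorch 2.6), which should only be used on files you trust. This ties the file to your code layout, which is why saving the state_dict is the more flexible and recommended approach.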
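A sketch of the bonus comparison, assuming PyTorch 1.13 or newer and in-memory buffers in place of files. The whole-model path works only because `SimpleModel` is defined in the same script. Pickle stores a reference to the class, not its source code.

```python
import io
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

original_model = SimpleModel()
input_tensor = torch.ones(1, 5)

# Approach 1 (recommended): save the state_dict; the architecture comes from your code.
sd_buffer = io.BytesIO()
torch.save(original_model.state_dict(), sd_buffer)
sd_buffer.seek(0)
sd_model = SimpleModel()
sd_model.load_state_dict(torch.load(sd_buffer, weights_only=True))

# Approach 2: save the whole module object. Loading runs full unpickling,
# so weights_only=False is required; only do this with trusted files.
full_buffer = io.BytesIO()
torch.save(original_model, full_buffer)
full_buffer.seek(0)
full_model = torch.load(full_buffer, weights_only=False)

for m in (original_model, sd_model, full_model):
    m.eval()

with torch.no_grad():
    reference = original_model(input_tensor)
    sd_match = torch.allclose(reference, sd_model(input_tensor))
    full_match = torch.allclose(reference, full_model(input_tensor))

print("state_dict method matches:", sd_match)  # True
print("whole-model method matches:", full_match)  # True
```

Both approaches reproduce the outputs here. The difference shows up when the code changes: renaming or moving `SimpleModel` breaks the whole-model file, while the state_dict still loads into any class whose parameter names and shapes match.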