Fix State Dict Key Mismatch Error in PyTorch Models
A state_dict key mismatch in PyTorch happens when the keys in the saved model weights do not match the keys expected by the model architecture. To fix it, ensure the model architecture matches the saved weights exactly, or adjust the keys manually before loading. You can also load the weights with strict=False to ignore missing or unexpected keys.
Why This Happens
This error occurs because the keys in the saved state_dict (which holds model weights) do not match the keys expected by the model you are trying to load them into. This usually happens if the model architecture has changed, or if you saved weights from a model wrapped in a container like nn.DataParallel but try to load them into a plain model.
python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

model = SimpleModel()
torch.save(model.state_dict(), 'model.pth')

# Later, change the model architecture
class ChangedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 6)  # output size changed

changed_model = ChangedModel()
changed_model.load_state_dict(torch.load('model.pth'))
Output
RuntimeError: Error(s) in loading state_dict for ChangedModel:
	size mismatch for linear.weight: copying a param with shape torch.Size([5, 10]) from checkpoint, the shape in current model is torch.Size([6, 10]).
	size mismatch for linear.bias: copying a param with shape torch.Size([5]) from checkpoint, the shape in current model is torch.Size([6]).
The Fix
To fix this, make sure the model architecture matches the saved weights exactly. If you saved weights from a model wrapped in nn.DataParallel, remove the module. prefix from the keys before loading. Alternatively, load the state dict with strict=False to ignore missing or unexpected keys, but be careful: this may leave some weights uninitialized, and it does not help with shape mismatches, which still raise an error.
python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

model = SimpleModel()
torch.save(model.state_dict(), 'model.pth')

# Correct model architecture
model2 = SimpleModel()
state_dict = torch.load('model.pth')
model2.load_state_dict(state_dict)  # strict=True by default
print('Model loaded successfully')
Output
Model loaded successfully
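If the checkpoint came from a model wrapped in nn.DataParallel, its keys carry a module. prefix (e.g. module.linear.weight). A minimal sketch of stripping that prefix before loading, with an illustrative SimpleModel standing in for your architecture:

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

# Simulate a checkpoint saved from a DataParallel-wrapped model:
# its keys are prefixed with 'module.' (e.g. 'module.linear.weight')
wrapped = nn.DataParallel(SimpleModel())
state_dict = wrapped.state_dict()

# Strip the 'module.' prefix so the keys match a plain SimpleModel
# (str.removeprefix requires Python 3.9+)
cleaned = {k.removeprefix('module.'): v for k, v in state_dict.items()}

plain_model = SimpleModel()
plain_model.load_state_dict(cleaned)  # loads without a key mismatch
```

The dict comprehension preserves key order and leaves unprefixed keys untouched, so it is safe even on checkpoints that were never wrapped.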
Prevention
To avoid this error in the future:
- Always save and load weights from the same model architecture.
- If using nn.DataParallel, save the underlying model's state dict with model.module.state_dict().
- Use consistent naming and avoid changing layer names after saving.
- Use strict=False only when you understand the implications.
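The DataParallel tip above can be sketched as follows; SimpleModel is illustrative:

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

wrapped = nn.DataParallel(SimpleModel())

# Saving wrapped.state_dict() would embed 'module.' prefixes in every key.
# Saving the underlying model's state dict keeps the checkpoint portable.
torch.save(wrapped.module.state_dict(), 'model.pth')

# The checkpoint now loads directly into a plain, unwrapped model
model = SimpleModel()
model.load_state_dict(torch.load('model.pth'))
```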
Related Errors
Other common errors include:
- Missing keys: Some weights are missing in the loaded state dict. Fix by ensuring all layers are saved and loaded.
- Unexpected keys: Extra keys in the state dict not used by the model. Fix by cleaning the state dict or using strict=False.
- Shape mismatch: Weights have different shapes due to architecture changes. Fix by matching architectures exactly.
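When diagnosing missing or unexpected keys, note that load_state_dict with strict=False returns a result listing both kinds of mismatch instead of raising. A small sketch, with the two illustrative model classes standing in for an old and a new architecture:

```python
import torch
import torch.nn as nn

class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.head = nn.Linear(5, 2)  # layer absent from the old checkpoint

state_dict = OldModel().state_dict()

# strict=False reports which keys could not be matched instead of raising
result = NewModel().load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # keys the model expects but the checkpoint lacks
print(result.unexpected_keys)  # keys in the checkpoint the model does not use
```

Inspecting these two lists tells you whether you are dealing with missing keys, unexpected keys, or both, before deciding how to repair the checkpoint.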
Key Takeaways
- Ensure the model architecture matches the saved weights exactly to avoid key mismatches.
- Remove the 'module.' prefix from keys when loading weights saved from nn.DataParallel.
- Use load_state_dict with strict=False cautiously; it silently ignores missing or unexpected keys.
- Save and load weights consistently from the same model class and structure.
- Read the error message carefully to identify missing keys, unexpected keys, or shape mismatches.