Introduction
Saving the model's state_dict keeps its learned parameters, so you can reuse the model or continue training it later without starting over.
torch.save(model.state_dict(), 'filename.pth')

torch.save(model.state_dict(), 'model_weights.pth')

PATH = 'checkpoint.pth'
torch.save(model.state_dict(), PATH)

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

# Create model instance
model = SimpleModel()

# Print initial weights
print('Initial weights:', model.linear.weight)

# Save the model's state_dict
torch.save(model.state_dict(), 'simple_model.pth')

# Load the state_dict into a new model to check
new_model = SimpleModel()
new_model.load_state_dict(torch.load('simple_model.pth'))

# Print loaded weights
print('Loaded weights:', new_model.linear.weight)
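To see what actually gets written to disk, it can help to look inside the state_dict before saving it. The sketch below rebuilds the same SimpleModel and prints each entry's name and shape. The variable names state, keys, and shapes are illustrative, not part of any PyTorch API.

```python
import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)


model = SimpleModel()
state = model.state_dict()

# Each key has the form "<attribute name>.<parameter name>";
# each value is the tensor holding that parameter.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# Collected here only so the result is easy to inspect (illustrative names)
keys = list(state.keys())
shapes = {name: tuple(tensor.shape) for name, tensor in state.items()}
```

For this model the entries are linear.weight with shape (1, 2) and linear.bias with shape (1,). These key names are what load_state_dict later matches against the new model's layers.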
What does model.state_dict() in PyTorch contain?
The state_dict stores all the learned parameters of the model's layers, such as weights and biases.
torch.save() is used to save objects to a file.
Pass model.state_dict() to torch.save() along with the filename.

import torch
import torch.nn as nn
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
model = SimpleModel()
torch.save(model.state_dict(), 'weights.pth')
loaded_state = torch.load('weights.pth')
print(type(loaded_state))

model.state_dict() stores an OrderedDict of parameter tensors.

You saved a model with torch.save(model.state_dict(), 'model.pth'). Later, you try to load it with model.load_state_dict(torch.load('model.pth')) but get a runtime error about missing keys. What is the most likely cause?

Use torch.save(model.state_dict(), 'file.pth') to save learned weights.
Use model.load_state_dict(torch.load('file.pth')) to load parameters.
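On the missing-keys question above: one common cause is that the model you load into defines layers that did not exist in the model that produced the file. The sketch below reproduces that situation under a few assumptions. SmallModel and BiggerModel are hypothetical classes made up for illustration, and an in-memory buffer stands in for the .pth file so the example leaves nothing on disk.

```python
import io

import torch
import torch.nn as nn


class SmallModel(nn.Module):  # hypothetical: the model that was saved
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)


class BiggerModel(nn.Module):  # hypothetical: the model we try to load into
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        self.extra = nn.Linear(1, 1)  # this layer is absent from the saved state


# Save to an in-memory buffer instead of a file (assumption for this sketch)
buffer = io.BytesIO()
torch.save(SmallModel().state_dict(), buffer)
buffer.seek(0)
saved_state = torch.load(buffer)

model = BiggerModel()
error_message = ""
try:
    # The default strict=True requires the keys to match exactly
    model.load_state_dict(saved_state)
except RuntimeError as err:
    error_message = str(err)
    print(error_message)

# strict=False loads the keys that match and reports the ones that do not
result = model.load_state_dict(saved_state, strict=False)
print(result.missing_keys)
```

Here the error names extra.weight and extra.bias as missing, because the saved state only has entries for linear. Passing strict=False is one way to load the matching part anyway, but the unmatched layers keep their fresh random initialization, so it is worth checking result.missing_keys rather than ignoring it.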