Introduction
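Saving the entire model stores both its architecture and its learned weights, so you can use it later without retraining. Because the whole model object is pickled, its class definition must be importable when you load it, and on PyTorch 2.6 and later torch.load must be called with weights_only=False.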
Jump into concepts and practice - no test required
# Save the entire model object (architecture + learned weights) with pickle.
# Generic pattern (``model`` and ``PATH`` are whatever you are working with):
#
#     torch.save(model, PATH)
#     # To load:
#     model = torch.load(PATH, weights_only=False)
#     model.eval()
#
# The same pattern with a concrete file name:
#
#     torch.save(model, 'model.pth')
#     model = torch.load('model.pth', weights_only=False)
#     model.eval()
#
# ``weights_only=False`` is required for a pickled nn.Module: since PyTorch 2.6
# ``torch.load`` defaults to ``weights_only=True``, which refuses to unpickle
# arbitrary classes such as ``SimpleNet``. Only load files you trust this way --
# full unpickling can execute arbitrary code.
import torch
import torch.nn as nn
import torch.optim as optim


# Define a simple model
class SimpleNet(nn.Module):
    """Single linear layer mapping 2 input features to 1 output."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)


# Create model instance
model = SimpleNet()

# Dummy input and target
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train for 1 step
model.train()
optimizer.zero_grad()
outputs = model(inputs)
# NOTE: ``loss`` is computed from the outputs *before* optimizer.step(), so the
# value printed below is the loss at the start of the step, not after the update.
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# Save entire model
PATH = 'entire_model.pth'
torch.save(model, PATH)

# Load model. Unpickling looks SimpleNet up by its module path at save time;
# if that class is not importable there, loading fails with
# AttributeError: Can't get attribute 'SimpleNet'.
loaded_model = torch.load(PATH, weights_only=False)
loaded_model.eval()  # switch to evaluation mode before inference

# Predict with loaded model
with torch.no_grad():
    pred = loaded_model(torch.tensor([[5.0, 6.0]]))

print(f"Loss after 1 step: {loss.item():.4f}")
print(f"Prediction for input [5.0, 6.0]: {pred.item():.4f}")
What does torch.save(model, PATH) do in PyTorch?
torch.save(model, PATH) saves the whole model object, which includes both the architecture and the weights. Saving only the weights would use model.state_dict(), but here the entire model is saved.
Which call saves the entire model to a file named model.pth?
torch.save(model, 'model.pth'). Because model.state_dict() contains only the weights, torch.save(model.state_dict(), 'model.pth') does not save the entire model.
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """Single linear layer mapping 2 input features to 1 output."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)


model = SimpleNet()
# Save the whole model object (architecture + weights) via pickle.
torch.save(model, 'model.pth')
# weights_only=False: since PyTorch 2.6 torch.load defaults to weights_only=True,
# which cannot unpickle a full nn.Module. Only load trusted files this way.
loaded_model = torch.load('model.pth', weights_only=False)
loaded_model.eval()
input_tensor = torch.tensor([[1.0, 2.0]])
# Inference only: no autograd graph is needed to read the scalar output.
with torch.no_grad():
    output = loaded_model(input_tensor).item()
print(round(output, 2))
The model is saved and loaded with torch.save and torch.load. Calling eval() sets the model to evaluation mode.
You saved a model with torch.save(model, 'model.pth'). When loading it with loaded_model = torch.load('model.pth'), you get an error: AttributeError: Can't get attribute 'SimpleNet'. What is the likely cause?