PyTorch ~5 mins

Saving entire model in PyTorch

Introduction
Saving the entire model lets you keep the model's structure and learned weights so you can reuse it later without retraining. Typical situations:
You want to pause training and continue later from the same model.
You finished training and want to share the model with others.
You want to deploy the model to make predictions in a real app.
You want to keep a backup of your model after training.
You want to load the model later to test or improve it.
Syntax
PyTorch
torch.save(model, PATH)

# To load (PyTorch 2.6+ requires weights_only=False for full model objects):
model = torch.load(PATH, weights_only=False)
model.eval()
PATH is a string with the file name or path where the model is saved.
model.eval() switches the model to evaluation mode, which matters for layers like dropout and batch normalization.
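To see why eval mode matters, here is a small illustrative sketch (the names drop and x are not from the lesson) showing that a dropout layer behaves differently in train and eval mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()            # training mode: zeroes elements at random
train_out = drop(x)     # survivors are scaled by 1 / (1 - p) = 2

drop.eval()             # evaluation mode: dropout is a no-op
eval_out = drop(x)

print(eval_out)         # identical to x
```

In eval mode the layer passes the input through unchanged, so predictions are deterministic; in train mode the output is randomly masked and rescaled.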
Examples
Save the entire model to a file named 'model.pth'.
PyTorch
torch.save(model, 'model.pth')
Load the saved model from 'model.pth' and set it to evaluation mode.
PyTorch
model = torch.load('model.pth', weights_only=False)
model.eval()
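When loading on a different machine, two torch.load arguments often come up: map_location (to remap GPU-saved tensors onto the CPU) and weights_only=False (required for full model objects since PyTorch 2.6). A self-contained sketch, with an illustrative model and filename; it saves a throwaway nn.Linear first so the load call has a file to read:

```python
import torch
import torch.nn as nn

# Save a small model so the load below has a file to read.
torch.save(nn.Linear(2, 1), 'model.pth')

# map_location remaps saved tensors to the CPU, useful when the model
# was saved on a GPU machine; weights_only=False is needed because we
# are loading a pickled model object, not just tensors.
model = torch.load('model.pth', map_location=torch.device('cpu'),
                   weights_only=False)
model.eval()
print(type(model).__name__)
```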
Sample Model
This code trains a simple model for one step, saves the entire model, loads it back, and makes a prediction.
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)
    def forward(self, x):
        return self.fc(x)

# Create model instance
model = SimpleNet()

# Dummy input and target
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[1.0], [2.0]])

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train for 1 step
model.train()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# Save entire model
PATH = 'entire_model.pth'
torch.save(model, PATH)

# Load model
loaded_model = torch.load(PATH, weights_only=False)
loaded_model.eval()

# Predict with loaded model
with torch.no_grad():
    pred = loaded_model(torch.tensor([[5.0, 6.0]]))

print(f"Loss after 1 step: {loss.item():.4f}")
print(f"Prediction for input [5.0, 6.0]: {pred.item():.4f}")
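A quick way to confirm the round trip preserved the weights is to compare predictions from the original and the loaded model. A standalone sketch (it rebuilds a minimal SimpleNet rather than reusing the variables above):

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)
    def forward(self, x):
        return self.fc(x)

model = SimpleNet()
torch.save(model, 'entire_model.pth')

loaded_model = torch.load('entire_model.pth', weights_only=False)
loaded_model.eval()

# Same input through both models should give identical outputs.
x = torch.tensor([[5.0, 6.0]])
with torch.no_grad():
    same = torch.equal(model(x), loaded_model(x))
print(same)
```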
Important Notes
Saving the entire model stores both the architecture and the weights, so you don't have to rebuild the model in the loading script; however, because the object is pickled, the model class definition must still be importable from the same module path when you load.
This method breaks if you later rename, move, or refactor the model class, so saving only the state_dict is safer for long-term projects.
Since PyTorch 2.6, torch.load defaults to weights_only=True, so loading a full model object requires passing weights_only=False; only do this for files you trust, since unpickling can execute arbitrary code.
Always call model.eval() before using the loaded model for prediction so dropout and batch-norm layers behave correctly.
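For comparison, the state_dict approach mentioned above looks like this. A sketch using an illustrative nn.Sequential and the filename 'weights.pth'; only tensors are saved, so loading requires recreating the architecture first:

```python
import torch
import torch.nn as nn

# Save only the learned weights, not the model object.
model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
torch.save(model.state_dict(), 'weights.pth')

# Recreate the same architecture, then load the weights into it.
restored = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
restored.load_state_dict(torch.load('weights.pth'))
restored.eval()
```

Because only tensors are stored, this file keeps working even if you later rename or move the model class, which is why it is the recommended approach for long-term projects.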
Summary
Use torch.save(model, PATH) to save the whole model including structure and weights.
Load the model with torch.load(PATH, weights_only=False) and set it to eval mode before prediction.
Saving the entire model is quick and easy but less flexible than saving only weights.