How to Create a Model in PyTorch: Simple Guide
To create a model in PyTorch, define a class that inherits from torch.nn.Module and implement the __init__ and forward methods. The __init__ method sets up the layers, and forward defines how data flows through the model.
Syntax
Creating a model in PyTorch involves these parts:
- Subclass torch.nn.Module: This is the base class for all models.
- __init__ method: Define layers here.
- forward method: Define how input data passes through the layers.
```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Define layers here
        self.layer = nn.Linear(in_features=10, out_features=5)

    def forward(self, x):
        # Define forward pass
        return self.layer(x)
```
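To see what "define layers in __init__" buys you, the sketch below lists the parameters that nn.Module registers for the MyModel class above. The variable name param_shapes is only illustrative; named_parameters() is the standard nn.Module method for inspecting registered weights.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning a layer as an attribute registers its parameters
        self.layer = nn.Linear(in_features=10, out_features=5)

    def forward(self, x):
        return self.layer(x)


model = MyModel()

# Map each registered parameter name to its shape
param_shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}
print(param_shapes)
# {'layer.weight': (5, 10), 'layer.bias': (5,)}
```

The weight of nn.Linear is stored as (out_features, in_features), which is why it shows up as (5, 10) rather than (10, 5).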
Example
This example shows a simple model with one linear layer. It takes input of size 10 and outputs size 5. We create a random input tensor and get the model's output.
```python
import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)


# Create model instance
model = SimpleModel()

# Create random input tensor with batch size 2 and 10 features
input_tensor = torch.randn(2, 10)

# Get model output
output = model(input_tensor)
print(output)
```
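Output
The exact values differ on every run because the layer weights and the input tensor are randomly initialized; the numbers below are only illustrative.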
tensor([[ 0.1234, -0.5678, 0.9101, -0.2345, 0.6789],
[-0.3456, 0.7890, -0.1234, 0.4567, -0.8901]], grad_fn=<AddmmBackward0>)
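Since the printed numbers are random, the output shape is usually the more useful thing to verify. The following is a minimal sketch that reuses SimpleModel from the example; the torch.manual_seed(0) call and the torch.no_grad() block are additions for illustration and are not part of the original example.

```python
import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)


# Assumption: fix the seed so repeated runs produce the same numbers
torch.manual_seed(0)

model = SimpleModel()
input_tensor = torch.randn(2, 10)  # batch of 2 samples, 10 features each

output = model(input_tensor)
print(output.shape)  # torch.Size([2, 5])

# For inference only, gradient tracking can be switched off
with torch.no_grad():
    inference_output = model(input_tensor)
print(inference_output.requires_grad)  # False
```

The batch dimension (2) passes through unchanged, while the feature dimension goes from 10 to 5, matching the layer definition.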
Common Pitfalls
Common mistakes when creating PyTorch models include:
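- Not calling super().__init__() in the constructor. PyTorch then raises an AttributeError as soon as you assign a layer as an attribute, because the module's internal registries have not been set up.
- Defining layers inside the forward method instead of __init__. The layer is recreated with fresh random weights on every call and is never registered as a parameter, so it cannot be trained.
- Forgetting to return the output tensor from forward, in which case calling the model returns None.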
```python
import torch.nn as nn


# Wrong way: defining layer inside forward
class WrongModel(nn.Module):
    def __init__(self):
        super(WrongModel, self).__init__()

    def forward(self, x):
        linear = nn.Linear(10, 5)  # This creates a new layer every call
        return linear(x)


# Right way: define layer in __init__
class RightModel(nn.Module):
    def __init__(self):
        super(RightModel, self).__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)
```
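The difference between the two classes can be observed directly. The sketch below counts registered parameters and calls each model twice on the same input. NoSuperModel is a hypothetical class added here only to show what happens when super().__init__() is omitted.

```python
import torch
import torch.nn as nn


class WrongModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        linear = nn.Linear(10, 5)  # new random weights on every call
        return linear(x)


class RightModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)


class NoSuperModel(nn.Module):  # hypothetical, for illustration only
    def __init__(self):
        # super().__init__() is intentionally missing
        self.linear = nn.Linear(10, 5)


x = torch.randn(2, 10)
wrong, right = WrongModel(), RightModel()

# Parameters an optimizer would see: 10 * 5 weights + 5 biases = 55
wrong_param_count = sum(p.numel() for p in wrong.parameters())
right_param_count = sum(p.numel() for p in right.parameters())
print(wrong_param_count, right_param_count)  # 0 55

# Same input, two calls: only RightModel gives a repeatable result
wrong_is_stable = torch.equal(wrong(x), wrong(x))
right_is_stable = torch.equal(right(x), right(x))
print(wrong_is_stable, right_is_stable)  # False True

# Omitting super().__init__() fails when the layer is assigned
try:
    NoSuperModel()
    missing_super_error = None
except AttributeError as err:
    missing_super_error = err
print(type(missing_super_error).__name__)  # AttributeError
```

Because WrongModel exposes no parameters, passing wrong.parameters() to an optimizer would give it nothing to update.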
Quick Reference
Remember these tips when creating PyTorch models:
- Always subclass nn.Module.
- Call super().__init__() in your constructor.
- Define layers in __init__, not in forward.
- The forward method must return the output tensor.
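The same four tips carry over when a model has more than one layer. As a hedged sketch, the hypothetical TwoLayerModel below (the class name, the hidden size of 8, and the ReLU activation are all assumptions chosen for illustration) chains two linear layers in forward.

```python
import torch
import torch.nn as nn


class TwoLayerModel(nn.Module):  # hypothetical example model
    def __init__(self, in_features=10, hidden_features=8, out_features=5):
        super().__init__()  # tip: call the parent constructor first
        # tip: define every layer in __init__
        self.hidden = nn.Linear(in_features, hidden_features)
        self.activation = nn.ReLU()
        self.output = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        # tip: reuse the layers here and return the result
        x = self.activation(self.hidden(x))
        return self.output(x)


model = TwoLayerModel()
batch = torch.randn(4, 10)  # 4 samples, 10 features each
predictions = model(batch)
print(predictions.shape)  # torch.Size([4, 5])
```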
Key Takeaways
Create models by subclassing torch.nn.Module and implementing __init__ and forward methods.
Define all layers inside __init__ and use them in forward to process input data.
Always call super().__init__() in your model's constructor to register layers properly.
Avoid creating layers inside forward to prevent errors and inefficiency.
Return the output tensor from the forward method to get model predictions.