How to Define Forward Method in PyTorch: Simple Guide
In PyTorch, you define the forward method inside a subclass of torch.nn.Module to specify how input data passes through the model layers. This method takes input tensors and returns output tensors, representing the model's prediction or transformation.
Syntax
The forward method is defined inside a class that inherits from torch.nn.Module. It takes self and input tensors as arguments and returns the output tensor after applying model layers.
- self: Refers to the instance of the model class.
- input: The input tensor(s) to the model.
- return: The output tensor after computation.
```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # define layers here

    def forward(self, x):
        # define forward pass here
        return x
```
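The template takes a single input x, but forward can accept several tensors and return more than one result. Here is a minimal sketch, assuming a hypothetical TwoInputModel with layer sizes chosen only for illustration. Calling the instance passes its arguments on to forward unchanged:

```python
import torch


class TwoInputModel(torch.nn.Module):
    """Hypothetical model whose forward takes two input tensors."""

    def __init__(self):
        super().__init__()
        # One linear layer per input; sizes are illustrative only.
        self.proj_a = torch.nn.Linear(4, 2)
        self.proj_b = torch.nn.Linear(3, 2)

    def forward(self, a, b):
        # Project both inputs to the same width, then add them.
        combined = self.proj_a(a) + self.proj_b(b)
        # forward may return any Python structure, here a tuple of tensors.
        return combined, combined.sum(dim=1)


model = TwoInputModel()
a = torch.randn(5, 4)  # batch of 5 samples, 4 features each
b = torch.randn(5, 3)  # batch of 5 samples, 3 features each
combined, totals = model(a, b)  # runs forward(a, b)
print(combined.shape, totals.shape)
```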
Example
This example shows a simple neural network with one linear layer. The forward method applies this layer to the input tensor and returns the result.
```python
import torch


class SimpleNet(torch.nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = torch.nn.Linear(3, 1)  # 3 inputs to 1 output

    def forward(self, x):
        return self.linear(x)


# Create model instance
model = SimpleNet()

# Example input tensor with batch size 2 and 3 features
input_tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Get model output
output = model(input_tensor)
print(output)
```
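Output
The numbers below are only an example. nn.Linear initializes its weights and bias randomly, so your values will differ from run to run.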
tensor([[ 0.1234],
[-0.5678]], grad_fn=<AddmmBackward0>)
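To check that forward really computes x @ W.T + b, you can overwrite the initial parameters with known values. In this sketch, the values 1.0 (weights) and 0.5 (bias) are arbitrary choices that make the result easy to verify by hand:

```python
import torch


class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)  # 3 inputs to 1 output

    def forward(self, x):
        return self.linear(x)


model = SimpleNet()

# Overwrite the random parameters without recording the change for autograd.
with torch.no_grad():
    model.linear.weight.fill_(1.0)  # every weight = 1.0
    model.linear.bias.fill_(0.5)    # bias = 0.5

input_tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
output = model(input_tensor)
# Each row is the sum of its features plus 0.5: 6.5 and 15.5.
print(output)
```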
Common Pitfalls
Common mistakes when defining the forward method include:
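- Not calling super().__init__() in the constructor. nn.Module's internal bookkeeping is then never set up, so the first layer assignment (such as self.linear = ...) raises an AttributeError.
- Forgetting to return the output tensor at the end of forward. The model then silently returns None.
- Calling forward directly (model.forward(x)) instead of calling the model instance (model(x)), which bypasses hooks and other PyTorch features.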
```python
import torch


# Wrong way: missing return
class BadModel(torch.nn.Module):
    def __init__(self):
        super(BadModel, self).__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, x):
        self.linear(x)  # forgot to return


# Right way
class GoodModel(torch.nn.Module):
    def __init__(self):
        super(GoodModel, self).__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)
```
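The two construction mistakes can be reproduced in a short sketch. BadModel mirrors the class above, and NoSuperInit is a hypothetical class written only for this check. Calling the model that lacks a return statement yields None, and skipping super().__init__() makes the first submodule assignment fail with an AttributeError:

```python
import torch


class BadModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, x):
        self.linear(x)  # result is computed but never returned


class NoSuperInit(torch.nn.Module):
    def __init__(self):
        # super().__init__() is missing, so nn.Module is not initialized.
        self.linear = torch.nn.Linear(2, 1)


x = torch.randn(4, 2)

# Missing return: a Python function without a return statement gives None.
result = BadModel()(x)
print(result)  # None

# Missing super().__init__(): assigning a submodule raises AttributeError.
try:
    NoSuperInit()
    init_error = None
except AttributeError as err:
    init_error = err
print(init_error)
```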
Quick Reference
Tips for defining forward in PyTorch:
- Always inherit from torch.nn.Module.
- Call super().__init__() in __init__.
- Define layers as class attributes in __init__.
- Use forward(self, x) to define the data flow.
- Return the output tensor at the end of forward.
- Call the model instance like a function (model(x)), not model.forward(x).
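A forward hook makes the difference between model(x) and model.forward(x) easy to see. This sketch uses a bare torch.nn.Linear as the model and a hypothetical record_call function as the hook. register_forward_hook is a standard nn.Module method, and its hook receives (module, inputs, output):

```python
import torch

model = torch.nn.Linear(3, 1)
calls = []


def record_call(module, inputs, output):
    # Forward hooks run after forward() when the module is called as model(x).
    calls.append(output.shape)


handle = model.register_forward_hook(record_call)
x = torch.randn(2, 3)

model(x)          # goes through nn.Module.__call__, so the hook fires
model.forward(x)  # calls forward() directly, so the hook is skipped
print(len(calls))  # prints 1

handle.remove()  # detach the hook once it is no longer needed
```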
Key Takeaways
Define the forward method inside a class inheriting from torch.nn.Module to specify model computation.
Always return the output tensor from the forward method after applying layers.
Call super().__init__() in the constructor to properly initialize the model.
Use model(input) to run the forward pass, not model.forward(input) directly.
Define layers in __init__ and use them inside forward for clean, reusable code.
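The last takeaway also works for deeper models. In this sketch (TwoLayerNet and its sizes are hypothetical), every layer is created in __init__ and forward only describes the data flow. Because the layers are attributes, model.parameters() finds all of them: 3×8+8 parameters for the first layer plus 8×1+1 for the second, 41 in total.

```python
import torch


class TwoLayerNet(torch.nn.Module):
    """Hypothetical two-layer network: layers in __init__, data flow in forward."""

    def __init__(self, in_features=3, hidden_features=8, out_features=1):
        super().__init__()
        # Assigning layers as attributes registers their parameters.
        self.hidden = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.output = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x):
        # Linear -> ReLU -> Linear, reusing the layers defined above.
        x = self.activation(self.hidden(x))
        return self.output(x)


model = TwoLayerNet()
y = model(torch.randn(2, 3))
n_params = sum(p.numel() for p in model.parameters())
print(y.shape, n_params)
```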