
Forward pass computation in PyTorch

Introduction

The forward pass is how a model turns input data into predictions. The input tensor flows through the model's layers in order, and each layer's output becomes the next layer's input.
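The following sketch makes "in order" concrete. It uses a small nn.Sequential stack with arbitrary layer sizes chosen for illustration, and checks that calling the model once gives the same result as feeding the tensor through each layer by hand.

```python
import torch
import torch.nn as nn

# A small stack of layers; nn.Sequential's forward runs them in order
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

x = torch.randn(3, 4)  # batch of 3 samples, 4 features each

# Forward pass in a single call
out = model(x)

# The same computation done by hand, one layer at a time
h = x
for layer in model:
    h = layer(h)

print(out.shape)               # torch.Size([3, 2])
print(torch.allclose(out, h))  # True
```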

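Typical situations that call for a forward pass:
To see what your model predicts for new data.
During training, to compute the predictions that the loss is calculated from. The backward pass and the optimizer then update the weights.
During testing, to measure the model's accuracy on unseen data.
When debugging, to check that data flows through the model with the expected shapes.
When visualizing intermediate outputs inside the model.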
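For the debugging and visualization cases, one option is a forward hook. PyTorch calls the hook with each module's output during the forward pass. The sketch below records each layer's output shape. The `captured` dictionary and the `make_hook` helper are illustrative names, not part of PyTorch.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

captured = {}  # layer name -> output tensor recorded during the forward pass

def make_hook(name):
    # register_forward_hook calls hook(module, inputs, output) after module.forward
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_children()
]

out = model(torch.randn(5, 4))

for name, tensor in captured.items():
    print(name, tuple(tensor.shape))
# 0 (5, 8)
# 1 (5, 8)
# 2 (5, 2)

for handle in handles:
    handle.remove()  # detach the hooks once you are done inspecting
```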
Syntax
PyTorch
output = model(input_data)

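Calling the model like a function runs nn.Module.__call__. That method invokes any registered hooks and then the model's forward method. Prefer model(input_data) over calling model.forward(input_data) directly.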
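The sketch below shows the difference between the two calls. Both produce the same numbers, but only the call through the module triggers a registered forward hook. The single nn.Linear layer and the `calls` list are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []  # records every time the forward hook fires

layer.register_forward_hook(lambda module, inp, out: calls.append("hook ran"))

x = torch.randn(1, 4)

y1 = layer(x)          # goes through nn.Module.__call__: hooks + forward
y2 = layer.forward(x)  # calls forward directly: hooks are skipped

print(torch.equal(y1, y2))  # True: same computation
print(calls)                # ['hook ran']: only the first call triggered the hook
```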

input_data must be a tensor whose shape matches what the model's first layer expects. By convention, the first dimension is the batch dimension.
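When you have a single sample, a common pattern is to add a batch dimension of size 1 with unsqueeze(0) before the forward pass. The sketch below assumes a bare nn.Linear(4, 2) layer as the model.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # expects inputs whose last dimension is 4

sample = torch.tensor([1.0, 2.0, 3.0, 4.0])  # one sample, shape (4,)
batch = sample.unsqueeze(0)                   # add a batch dimension -> shape (1, 4)

out = model(batch)
print(sample.shape, batch.shape, out.shape)
# torch.Size([4]) torch.Size([1, 4]) torch.Size([1, 2])
```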

Examples
Passes a random image-shaped tensor (batch of 1, 3 channels, 28×28 pixels) through a model that accepts that shape, and returns its predictions.
PyTorch
output = model(torch.randn(1, 3, 28, 28))
Runs the forward pass on an existing input tensor, here called input_tensor, and returns the output.
PyTorch
output = model(input_tensor)
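The first example needs a model whose first layer accepts image tensors of shape (batch, channels, height, width). Below is a minimal sketch of such a model. `TinyConvNet`, its layer sizes, and the 10 output classes are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

class TinyConvNet(nn.Module):
    """Hypothetical example model that accepts 3-channel 28x28 images."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # keeps 28x28
        self.pool = nn.AdaptiveAvgPool2d(1)                     # -> (N, 8, 1, 1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)  # -> (N, 8)
        return self.fc(x)

model = TinyConvNet()
output = model(torch.randn(1, 3, 28, 28))
print(output.shape)  # torch.Size([1, 10])
```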
Sample Model

This code defines a simple model with one linear layer. It creates an input tensor with 4 features and runs the forward pass to get the output predictions.

PyTorch
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)  # input 4 features, output 2

    def forward(self, x):
        return self.linear(x)

# Create model instance
model = SimpleModel()

# Create example input tensor (batch size 1, 4 features)
input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

# Run forward pass
output = model(input_data)

print("Input:", input_data)
print("Output:", output)
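The printed output values change from run to run, because nn.Linear starts with random weights. The computation itself is fixed: the layer returns input @ weight.T + bias. The check below repeats the SimpleModel definition so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Same model as in the sample above, repeated so this snippet runs on its own
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
output = model(input_data)

# nn.Linear computes output = input @ weight.T + bias
weight, bias = model.linear.weight, model.linear.bias
manual = input_data @ weight.T + bias

print(output.shape)                    # torch.Size([1, 2])
print(torch.allclose(output, manual))  # True
```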
Important Notes

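The forward pass computes outputs but does not update the model's weights. Weights change only when you call loss.backward() and then optimizer.step(). (In training mode, layers such as BatchNorm do update their running statistics during the forward pass. These are buffers, not learnable weights.)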
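One way to confirm this is to snapshot a layer's weights, run a forward pass, and compare. The sketch below uses an arbitrary linear layer, an SGD optimizer, and a mean-squared-error loss, all chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, target = torch.randn(8, 4), torch.randn(8, 2)

before = model.weight.detach().clone()  # snapshot of the current weights

pred = model(x)  # forward pass only
unchanged = torch.equal(model.weight.detach(), before)
print(unchanged)  # True

loss = nn.functional.mse_loss(pred, target)
optimizer.zero_grad()
loss.backward()   # computes gradients; the weights are still untouched
optimizer.step()  # this is the call that updates the weights
changed = not torch.equal(model.weight.detach(), before)
print(changed)  # True
```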

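Call model.eval() before testing or inference. It disables dropout and makes batch normalization use its stored running statistics. Also wrap the forward pass in torch.no_grad() to skip gradient tracking, which saves memory and time. Call model.train() to switch back before training again.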
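The sketch below shows the effect of both settings on a tiny model with a dropout layer. The layer sizes and dropout rate are arbitrary. The training-mode comparison is not guaranteed, because dropout draws a fresh random mask on each call.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.train()  # training mode: dropout zeroes random elements on every call
train_a, train_b = model(x), model(x)
print(torch.equal(train_a, train_b))  # usually False: each call draws a new mask

model.eval()  # evaluation mode: dropout is disabled
with torch.no_grad():  # don't record operations for autograd during inference
    eval_a, eval_b = model(x), model(x)
print(torch.equal(eval_a, eval_b))  # True: same input, same output
print(eval_a.requires_grad)         # False: no graph was built
```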

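Input shapes must match what the model expects. For example, nn.Linear(4, 2) needs inputs whose last dimension is 4, and a mismatch raises a RuntimeError during the forward pass.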
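The sketch below triggers and catches such an error. The layer and tensor sizes are arbitrary, and the exact wording of the error message can differ between PyTorch versions.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # expects the last input dimension to be 4

good = torch.randn(1, 4)
bad = torch.randn(1, 3)  # wrong number of features

print(model(good).shape)  # torch.Size([1, 2])

raised = False
try:
    model(bad)
except RuntimeError as err:
    raised = True
    print("Shape error:", err)  # the message describes the mismatched shapes
```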

Summary

The forward pass moves input data through the model to get predictions.

Call the model with input tensors to run the forward pass.

It is used during training, testing, and debugging.