
What is nn.Module in PyTorch: Explanation and Example

nn.Module is the base class for all neural network modules in PyTorch. It organizes layers and operations, registers their parameters, and provides the methods used to train, save, and load models.
⚙️

How It Works

Think of nn.Module as a blueprint for the building blocks of a neural network. Each block can be a layer, such as a linear transformation or an activation function. To create a new model, you write a class that inherits from nn.Module and assign layers to attributes in __init__. Each layer assigned this way is registered automatically, so the model keeps track of all of its parts.
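To see that bookkeeping in action, here is a minimal sketch. The class name TinyNet and the note attribute are made up for illustration. It lists the submodules the model has registered.

```python
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        # Assigning an nn.Module to an attribute registers it as a submodule.
        self.hidden = nn.Linear(10, 5)
        self.output = nn.Linear(5, 1)
        # A plain Python value is stored as a normal attribute and is not registered.
        self.note = "not a layer"

    def forward(self, x):
        return self.output(self.hidden(x))

model = TinyNet()
child_names = [name for name, _ in model.named_children()]
print(child_names)  # ['hidden', 'output']
```

Only the two nn.Linear layers show up. The string attribute is ignored by the module machinery.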

This class also handles the parameters (weights and biases) inside your layers. It registers them so that you can collect them with parameters(), hand them to an optimizer that updates them during training, and save or load them through state_dict() and load_state_dict(). This way, you don't have to track each tensor yourself.
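The sketch below shows one way this could look in practice. It assumes a small two-layer model (the name TwoLayerNet, the random batch, and the learning rate are illustrative choices, not part of the original example). Note that the optimizer, not the module itself, performs the update.

```python
import torch
import torch.nn as nn

class TwoLayerNet(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(10, 5)
        self.output = nn.Linear(5, 1)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))

model = TwoLayerNet()

# named_parameters() yields every registered weight and bias with its name.
shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}
print(shapes)
# {'hidden.weight': (5, 10), 'hidden.bias': (5,), 'output.weight': (1, 5), 'output.bias': (1,)}

total = sum(p.numel() for p in model.parameters())
print(total)  # 50 + 5 + 5 + 1 = 61

# One training step: the optimizer receives the parameters and updates them.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # lr chosen arbitrarily
x = torch.randn(4, 10)       # made-up batch of 4 samples
target = torch.zeros(4, 1)   # made-up targets
loss = nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()              # fills .grad on each registered parameter
optimizer.step()             # applies the gradient update
```

Because model.parameters() already contains every weight and bias, the optimizer needs no per-layer setup.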

💻

Example

This example shows a simple neural network with one hidden layer using nn.Module. It defines the layers and the forward pass, which is how data moves through the network.

python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module before registering layers
        self.hidden = nn.Linear(10, 5)  # hidden layer
        self.output = nn.Linear(5, 1)   # output layer

    def forward(self, x):
        x = torch.relu(self.hidden(x))
        x = self.output(x)
        return x

model = SimpleNet()
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
Output (the exact number differs on every run because the weights and the input are randomly initialized)
tensor([[0.1234]], grad_fn=<AddmmBackward0>)
🎯

When to Use

Use nn.Module whenever you want to build a neural network in PyTorch. It is essential for creating custom models, combining layers, and managing parameters easily. For example, if you want to build a classifier, a regression model, or any deep learning architecture, nn.Module is the foundation.

It also helps when you want to save your trained model to disk or load it later for predictions. Without nn.Module, you would have to track and serialize every weight tensor by hand, which is harder and more error-prone.
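A minimal sketch of saving and restoring parameters follows. It uses an in-memory buffer as a stand-in for a file on disk (an assumption made here so the snippet runs without touching the filesystem; a path such as "model.pt" would work the same way with torch.save and torch.load).

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # any nn.Module is handled the same way

# state_dict() maps each parameter name to its tensor.
buffer = io.BytesIO()     # in-memory stand-in for a file path
torch.save(model.state_dict(), buffer)

# Later: rebuild the same architecture, then load the saved parameters into it.
buffer.seek(0)
restored = nn.Linear(10, 1)
restored.load_state_dict(torch.load(buffer))
restored.eval()           # switch to inference mode before predicting

x = torch.randn(3, 10)
same = torch.equal(model(x), restored(x))
print(same)  # True
```

Saving the state_dict rather than the whole model object keeps the file independent of the class definition's location, but the architecture has to be recreated in code before loading.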

✅

Key Points

  • Base class: nn.Module is the base for all neural network models in PyTorch.
  • Parameter management: It automatically tracks weights and biases.
  • Forward method: You define how data flows through the model here.
  • Easy saving/loading: state_dict() and load_state_dict() export and restore model parameters, typically together with torch.save and torch.load.
✅

Key Takeaways

  • nn.Module is the foundation for building neural networks in PyTorch.
  • It helps organize layers and automatically manages model parameters.
  • You define the data flow by implementing the forward method.
  • It simplifies saving and loading models for training and inference.