PyTorch · How-To · Beginner · 3 min read

How to Use nn.ModuleList in PyTorch: Syntax and Example

Use nn.ModuleList in PyTorch to store a list of layers or submodules that should be treated as part of your model. It supports dynamic layer management and ensures every module is registered properly, so its parameters are tracked for training and saving. During the forward pass you can iterate over a ModuleList just like a regular Python list.
📐 Syntax

nn.ModuleList is initialized with a list of PyTorch modules (layers). You can add modules dynamically or pass them all at once. It behaves like a Python list but registers the modules so PyTorch tracks their parameters.

  • nn.ModuleList([module1, module2, ...]): Create with initial modules.
  • append(module): Add a module later.
  • Use indexing like a list to access modules.
python
import torch.nn as nn

layers = nn.ModuleList([
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
])

layers.append(nn.Sigmoid())

print(layers[0])  # Access first layer
Output
Linear(in_features=10, out_features=20, bias=True)
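Because nn.ModuleList accepts any iterable of modules, a common pattern is to build the layers in a loop. A minimal sketch (the layer sizes here are arbitrary, chosen only for illustration):

```python
import torch.nn as nn

# Build a chain of Linear layers from a list of sizes
sizes = [10, 20, 20, 5]
layers = nn.ModuleList(
    nn.Linear(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)
)

print(len(layers))  # 3 linear layers
```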
💻 Example

This example shows how to use nn.ModuleList inside a custom model to create a sequence of linear layers with ReLU activations. It demonstrates dynamic layer storage and iteration in the forward pass.

python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 15),
            nn.ReLU(),
            nn.Linear(15, 5)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = SimpleModel()
input_tensor = torch.randn(2, 10)
output = model(input_tensor)
print(output)
Output
tensor([[ 0.0313,  0.0346, -0.0108,  0.0077,  0.0077],
        [-0.0318,  0.0273, -0.0093,  0.0013,  0.0039]], grad_fn=<AddmmBackward0>)
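Because the ModuleList registers its contents, the linear layers' weights and biases show up in model.parameters() and will be updated by an optimizer. A quick check on the same model:

```python
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 15),
            nn.ReLU(),
            nn.Linear(15, 5)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = SimpleModel()
# Each Linear contributes a weight and a bias; ReLU has no parameters
print(len(list(model.parameters())))  # 4 parameter tensors
print(sum(p.numel() for p in model.parameters()))
```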
⚠️ Common Pitfalls

A common mistake is storing layers in a regular Python list instead of nn.ModuleList. PyTorch then does not register the layers, so their parameters are invisible to model.parameters() and the state_dict, and they won't be updated during training. Another pitfall is forgetting to iterate over the ModuleList in the forward method: ModuleList has no forward of its own, so each module must be applied explicitly.

python
import torch.nn as nn

# Wrong: a plain Python list (layers won't be registered)
class ModelWrong(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 5), nn.ReLU()]  # not a ModuleList

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Right: nn.ModuleList registers the layers and their parameters
class ModelRight(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 5), nn.ReLU()])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
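You can see the difference directly by counting registered parameters. A minimal sketch, reusing the same wrong/right pattern in stripped-down form:

```python
import torch.nn as nn

class ModelWrong(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 5), nn.ReLU()]  # plain list: not registered

class ModelRight(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 5), nn.ReLU()])

print(len(list(ModelWrong().parameters())))  # 0, nothing registered
print(len(list(ModelRight().parameters())))  # 2, the Linear weight and bias
```

An optimizer built from ModelWrong().parameters() would receive an empty parameter list, so training would silently do nothing.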
📊 Quick Reference

  • Initialization: nn.ModuleList([modules])
  • Add module: module_list.append(module)
  • Access module: module_list[index]
  • Use in forward: Iterate over ModuleList to apply layers
  • Difference from list: Registers modules for training and saving
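The last point, registration for saving, means modules held in a ModuleList appear in the model's state_dict under indexed keys. A small sketch (the TinyModel name is hypothetical, used only for illustration):

```python
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 3)])

# ModuleList entries are keyed by their index in the list
print(list(TinyModel().state_dict().keys()))
# ['layers.0.weight', 'layers.0.bias']
```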

Key Takeaways

  • Use nn.ModuleList to store layers so PyTorch tracks their parameters automatically.
  • You can add layers dynamically (e.g. with append) and iterate over them in the forward method.
  • Avoid plain Python lists for layers; they won't be registered for training or saving.
  • ModuleList behaves like a list but integrates with PyTorch's module system.
  • Unlike nn.Sequential, ModuleList has no forward of its own, so apply each layer yourself in forward.