How to Use nn.ModuleList in PyTorch: Syntax and Example
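Use nn.ModuleList in PyTorch to store a list of layers or other modules that should be part of your model. Every module in a ModuleList is registered as a submodule of the model that holds it. Its parameters are therefore returned by model.parameters() (so the optimizer trains them), moved by model.to(device), and included in model.state_dict() when you save. Because the list can be built in a loop, the number of layers can depend on constructor arguments, and you can iterate over a ModuleList just like a Python list during the forward pass.
Syntax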
nn.ModuleList is initialized with an iterable of PyTorch modules (layers), or with nothing and filled later. You can pass all modules at once or add them afterwards. It behaves like a Python list, but it registers the modules so PyTorch tracks their parameters.
- nn.ModuleList([module1, module2, ...]): create the list with initial modules.
- append(module): add a module later.
- Use indexing like a list (for example layers[0]) to access modules.
python
import torch.nn as nn

layers = nn.ModuleList([
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
])
layers.append(nn.Sigmoid())  # Add a module after creation
print(layers[0])  # Access first layer
Output
Linear(in_features=10, out_features=20, bias=True)
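Because ModuleList behaves like a list, other familiar list operations work as well. The sketch below uses the standard ModuleList methods extend() and insert(), plus len(), slicing, and iteration; the layer sizes are arbitrary and chosen only for illustration. After these calls the list holds six modules in the order Identity, Linear, ReLU, Linear, ReLU, Linear.

```python
import torch.nn as nn

# Start with two layers, then grow the list in place, as with a Python list
layers = nn.ModuleList([nn.Linear(10, 20), nn.ReLU()])
layers.append(nn.Linear(20, 5))              # add one module at the end
layers.extend([nn.ReLU(), nn.Linear(5, 1)])  # add several modules at once
layers.insert(0, nn.Identity())              # insert a module at a given index

print(len(layers))                # 6
print(type(layers[1:]).__name__)  # ModuleList: slicing returns another ModuleList
for i, layer in enumerate(layers):  # iteration works like a list
    print(i, type(layer).__name__)
```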
Example
This example shows how to use nn.ModuleList inside a custom model to hold two linear layers with a ReLU activation between them. The forward pass iterates over the list and applies each layer in order.
python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 15),
            nn.ReLU(),
            nn.Linear(15, 5)
        ])

    def forward(self, x):
        # Apply each stored layer in order
        for layer in self.layers:
            x = layer(x)
        return x

model = SimpleModel()
input_tensor = torch.randn(2, 10)  # batch of 2 samples, 10 features each
output = model(input_tensor)
print(output)
Output (the exact values differ from run to run, because the weights and the input are random)
tensor([[ 0.0313, 0.0346, -0.0108, 0.0077, 0.0077],
[-0.0318, 0.0273, -0.0093, 0.0013, 0.0039]], grad_fn=<AddmmBackward0>)
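The real benefit of ModuleList over hard-coded attributes is that the number of layers can be decided when the model is constructed. The following is a minimal sketch, assuming a hypothetical MLP class whose layer widths come from a hypothetical sizes argument: one nn.Linear is created per consecutive pair of sizes, and a ReLU is applied between layers but not after the last one.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Hypothetical MLP whose depth is set by the `sizes` argument."""

    def __init__(self, sizes):
        super().__init__()
        # e.g. sizes=[10, 32, 16, 5] -> Linear(10, 32), Linear(32, 16), Linear(16, 5)
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.act = nn.ReLU()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # ReLU between layers, but not after the final one
            if i < len(self.layers) - 1:
                x = self.act(x)
        return x

model = MLP([10, 32, 16, 5])
out = model(torch.randn(2, 10))
print(out.shape)          # torch.Size([2, 5])
print(len(model.layers))  # 3
```

Since every Linear is registered through the ModuleList, model.parameters() returns all six weight and bias tensors without any extra bookkeeping.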
Common Pitfalls
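One common mistake is storing layers in a regular Python list instead of nn.ModuleList. PyTorch does not register modules held in a plain list, so their parameters are missing from model.parameters(): the optimizer never updates them, model.to(device) does not move them, and model.state_dict() does not save them. Another pitfall is treating a ModuleList like nn.Sequential. A ModuleList has no forward method of its own, so calling self.layers(x) fails; you must loop over the layers yourself in the model's forward method.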
python
import torch.nn as nn

# Wrong: a plain Python list (the layers are not registered as submodules)
class ModelWrong(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 5), nn.ReLU()]  # Not a ModuleList

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Right: nn.ModuleList registers each layer with the parent module.
# The layers are created inside __init__ so each instance gets its own weights.
class ModelRight(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 5), nn.ReLU()])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
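You can check the difference directly by inspecting what each model registers. This sketch redefines the two classes without forward methods, since only registration is inspected; it also shows that calling a ModuleList directly raises NotImplementedError, because the container itself has no forward.

```python
import torch
import torch.nn as nn

# forward() is omitted in both classes: only registration is inspected here
class ModelWrong(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 5), nn.ReLU()]  # plain list: not registered

class ModelRight(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 5), nn.ReLU()])  # registered

wrong, right = ModelWrong(), ModelRight()
print(len(list(wrong.parameters())))    # 0: an optimizer would receive nothing
print(len(list(right.parameters())))    # 2: weight and bias of the Linear layer
print(list(right.state_dict().keys()))  # ['layers.0.weight', 'layers.0.bias']

# A ModuleList is a container, not a layer: calling it directly fails
try:
    right.layers(torch.randn(1, 10))
    call_failed = False
except NotImplementedError:
    call_failed = True
print(call_failed)  # True
```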
Quick Reference
- Initialization: nn.ModuleList([modules])
- Add module: module_list.append(module)
- Access module: module_list[index]
- Use in forward: iterate over the ModuleList to apply the layers
- Difference from a Python list: registers modules for training and saving
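The "training and saving" point can be seen end to end in a short sketch. It assumes a hypothetical Stack model with two Linear layers in a ModuleList, and uses an in-memory io.BytesIO buffer as a stand-in for a checkpoint file: the optimizer receives all four weight and bias tensors, the state_dict keys follow the layers.<index>.<name> pattern, and a fresh instance loaded from the saved state_dict produces identical outputs.

```python
import io
import torch
import torch.nn as nn

class Stack(nn.Module):
    """Hypothetical model: two Linear layers held in a ModuleList."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 8), nn.Linear(8, 2)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = Stack()

# Training: the optimizer receives the weight and bias of both Linear layers
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(len(optimizer.param_groups[0]["params"]))    # 4 tensors
print(sum(p.numel() for p in model.parameters()))  # (4*8 + 8) + (8*2 + 2) = 58

# Saving: ModuleList entries appear in the state_dict as layers.<index>.<name>
print(list(model.state_dict().keys()))
buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Loading into a fresh instance restores the same weights
restored = Stack()
restored.load_state_dict(torch.load(buffer))
x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True
```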
Key Takeaways
Use nn.ModuleList to store layers so PyTorch tracks their parameters automatically.
You can add layers dynamically and iterate over them in the forward method.
Avoid using plain Python lists for layers: they are not registered as submodules, so their parameters are not trained, moved to the device, or saved.
ModuleList behaves like a list but integrates with PyTorch's model system.
A ModuleList has no forward method of its own, so iterate over it in forward to apply each layer to the input.