How to Use nn.ModuleDict in PyTorch: Syntax and Example
Use nn.ModuleDict in PyTorch to store a dictionary of modules with string keys, which makes layers easy to access and manage by name. Initialize it with a dictionary of modules, then access or iterate over them by key, just like a normal Python dictionary.
Syntax
nn.ModuleDict is initialized with a dictionary (or an iterable of (key, module) pairs) whose keys are strings and whose values are PyTorch modules. You can then access each module by its key and call it in your model's forward method.
Example parts:
- nn.ModuleDict({'layer1': nn.Linear(10, 20), 'layer2': nn.ReLU()}): creates a ModuleDict holding two modules.
- self.layers['layer1']: retrieves the module stored under 'layer1'. To use it, call it with an input, as in self.layers['layer1'](x).
```python
import torch.nn as nn

layers = nn.ModuleDict({
    'fc1': nn.Linear(10, 20),
    'relu': nn.ReLU(),
    'fc2': nn.Linear(20, 5),
})
```
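nn.ModuleDict mirrors an ordered Python dict, so the familiar dictionary operations apply to it. The minimal sketch below rebuilds the same three modules from a list of (name, module) pairs, a form the constructor also accepts. It then exercises key lookup, membership, length, assignment and iteration. The 'dropout' entry is an illustrative addition and is not part of the original example.

```python
import torch.nn as nn

# Build the same container from a list of (name, module) pairs instead of a dict
layers = nn.ModuleDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(20, 5)),
])

print(list(layers.keys()))  # ['fc1', 'relu', 'fc2'] (insertion order is kept)
print('relu' in layers)     # True
print(len(layers))          # 3

# Assigning to a new key registers another module (illustrative addition)
layers['dropout'] = nn.Dropout(p=0.1)
print(list(layers.keys()))  # ['fc1', 'relu', 'fc2', 'dropout']

# Iterate over (name, module) pairs like a normal dict
for name, module in layers.items():
    print(name, type(module).__name__)
```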
Example
This example shows how to define a simple neural network using nn.ModuleDict to store layers by name and use them in the forward method.
```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Store the layers by name in a ModuleDict so they are registered
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(10, 20),
            'relu': nn.ReLU(),
            'fc2': nn.Linear(20, 5),
        })

    def forward(self, x):
        # Look up each module by key and call it on the input
        x = self.layers['fc1'](x)
        x = self.layers['relu'](x)
        x = self.layers['fc2'](x)
        return x

model = SimpleNet()
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
```
Output (the exact values differ from run to run, because both the weights and the input are random):
tensor([[ 0.0497, 0.0807, -0.0223, 0.0106, -0.0400]], grad_fn=<AddmmBackward0>)
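You can check that SimpleNet tracks the layers stored in its ModuleDict by counting its parameters and listing its state_dict keys. The sketch below repeats the class so it runs on its own. The expected count follows from the layer sizes: fc1 has 10*20 weights plus 20 biases, fc2 has 20*5 weights plus 5 biases, and ReLU has none.

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(10, 20),
            'relu': nn.ReLU(),
            'fc2': nn.Linear(20, 5),
        })

    def forward(self, x):
        x = self.layers['fc1'](x)
        x = self.layers['relu'](x)
        return self.layers['fc2'](x)

model = SimpleNet()

# Every weight and bias inside the ModuleDict is visible to the parent model
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 325

# state_dict keys combine the attribute name, the dict key and the parameter name
print(list(model.state_dict().keys()))
# ['layers.fc1.weight', 'layers.fc1.bias', 'layers.fc2.weight', 'layers.fc2.bias']

out = model(torch.randn(4, 10))
print(out.shape)  # torch.Size([4, 5])
```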
Common Pitfalls
Common mistakes when using nn.ModuleDict include:
- Using a normal Python dict instead of nn.ModuleDict. The modules are then not registered as submodules, so their parameters are missing from model.parameters(), are not moved by .to(device), and are not saved in state_dict().
- Using non-string keys, which raises a TypeError.
- Forgetting to call the modules inside forward. On its own, self.layers['fc1'] returns the module object, while self.layers['fc1'](x) applies it to x.
```python
import torch.nn as nn

# Wrong: a plain dict; assigned to a module attribute, its layers are not registered
layers_wrong = {
    'fc1': nn.Linear(10, 20),
    'relu': nn.ReLU(),
}

# Right: nn.ModuleDict registers the modules
layers_right = nn.ModuleDict({
    'fc1': nn.Linear(10, 20),
    'relu': nn.ReLU(),
})
```
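The pitfalls above can be checked directly. In this sketch, WrongNet and RightNet are hypothetical classes that differ only in the container they assign to self.layers. Counting parameters shows that only the ModuleDict version registers fc1. The snippet also shows three other behaviors:

- A non-string key raises TypeError.
- A key containing '.' is rejected with KeyError, a further restriction enforced when PyTorch adds the module.
- Indexing without calling returns the module itself rather than a tensor.

```python
import torch.nn as nn

class WrongNet(nn.Module):  # hypothetical: uses a plain dict
    def __init__(self):
        super().__init__()
        # nn.Module does not look inside plain dicts, so fc1 is not registered
        self.layers = {'fc1': nn.Linear(10, 20)}

class RightNet(nn.Module):  # hypothetical: uses nn.ModuleDict
    def __init__(self):
        super().__init__()
        # fc1 becomes a registered submodule of RightNet
        self.layers = nn.ModuleDict({'fc1': nn.Linear(10, 20)})

wrong_count = sum(p.numel() for p in WrongNet().parameters())
right_count = sum(p.numel() for p in RightNet().parameters())
print(wrong_count, right_count)  # 0 220

# Non-string keys are rejected when the module is added
try:
    nn.ModuleDict({1: nn.ReLU()})
    int_key_rejected = False
except TypeError as err:
    int_key_rejected = True
    print('TypeError:', err)

# Keys may not contain '.', which separates names in state_dict paths
try:
    nn.ModuleDict({'fc.1': nn.ReLU()})
    dot_key_rejected = False
except KeyError as err:
    dot_key_rejected = True
    print('KeyError:', err)

# Indexing without calling returns the module object, not a tensor
layer = RightNet().layers['fc1']
print(type(layer).__name__)  # Linear
```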
Quick Reference
nn.ModuleDict Cheat Sheet:
- Initialize with a dict of modules: nn.ModuleDict({'name': module})
- Access a module by key and call it: self.layers['name'](input)
- Iterate like a dict: for name, module in self.layers.items():
- Only string keys are allowed
- Stored modules are registered as submodules, so their parameters are trained, moved to devices, and saved together with the parent model
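The iteration pattern from the cheat sheet can also drive the forward pass itself. This sketch assumes the stored modules should run one after another in insertion order; for a strictly linear stack, nn.Sequential is the more direct tool. LoopNet and its skip argument are hypothetical names, used only to show that the keys remain available during iteration.

```python
import torch
import torch.nn as nn

class LoopNet(nn.Module):  # hypothetical example class
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(10, 20),
            'relu': nn.ReLU(),
            'fc2': nn.Linear(20, 5),
        })

    def forward(self, x, skip=()):
        # items() yields (name, module) pairs in insertion order
        for name, module in self.layers.items():
            if name in skip:  # hypothetical option: bypass modules by name
                continue
            x = module(x)
        return x

model = LoopNet()
out = model(torch.randn(3, 10))
print(out.shape)  # torch.Size([3, 5])

# Skipping 'relu' leaves only the two linear layers
no_act = model(torch.randn(3, 10), skip=('relu',))
print(no_act.shape)  # torch.Size([3, 5])
```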
Key Takeaways
Use nn.ModuleDict to store layers with string keys so PyTorch tracks them properly.
Access modules by keys and call them like functions inside the forward method.
Do not use a normal Python dict for modules, or PyTorch won't register them.
Only string keys are allowed in nn.ModuleDict.
You can iterate over nn.ModuleDict like a normal dictionary.