
How to Use nn.ModuleDict in PyTorch: Syntax and Example

Use nn.ModuleDict in PyTorch to store submodules under string keys so that PyTorch registers and tracks them. Initialize it with a dictionary of modules, then access or iterate over them by key, just like a regular Python dictionary.
📐 Syntax

An nn.ModuleDict is initialized with a dictionary whose keys are strings and whose values are PyTorch modules. You can access modules by key and call them in your model's forward method.

Example parts:

  • nn.ModuleDict({'layer1': nn.Linear(10, 20), 'layer2': nn.ReLU()}): creates a ModuleDict with two layers.
  • Access with self.layers['layer1'] to use the module.
```python
import torch.nn as nn

layers = nn.ModuleDict({
    'fc1': nn.Linear(10, 20),
    'relu': nn.ReLU(),
    'fc2': nn.Linear(20, 5)
})
```
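Beyond construction, an nn.ModuleDict can be mutated like a regular dictionary: entries can be added by key assignment and removed with del. A minimal sketch (the fc2 key below is just illustrative):

```python
import torch.nn as nn

layers = nn.ModuleDict({
    'fc1': nn.Linear(10, 20),
    'relu': nn.ReLU()
})

# Add a module by key assignment
layers['fc2'] = nn.Linear(20, 5)

# Remove a module by key
del layers['relu']

print(list(layers.keys()))  # ['fc1', 'fc2']
```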
💻 Example

This example shows how to define a simple neural network using nn.ModuleDict to store layers by name and use them in the forward method.

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(10, 20),
            'relu': nn.ReLU(),
            'fc2': nn.Linear(20, 5)
        })

    def forward(self, x):
        x = self.layers['fc1'](x)
        x = self.layers['relu'](x)
        x = self.layers['fc2'](x)
        return x

model = SimpleNet()
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
```
Output (your values will differ, since the weights are randomly initialized):

```
tensor([[ 0.0497,  0.0807, -0.0223,  0.0106, -0.0400]], grad_fn=<AddmmBackward0>)
```
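Because nn.ModuleDict registers its contents, the Linear layers' weights and biases appear in model.parameters() and state_dict(). A minimal sketch checking this, reusing the layers from above (minus the parameter-free ReLU for brevity):

```python
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(10, 20),
            'fc2': nn.Linear(20, 5)
        })

# Each Linear layer contributes a weight and a bias tensor
model = SimpleNet()
print(len(list(model.parameters())))  # 4
print(sorted(model.state_dict().keys()))
# State-dict keys are prefixed with the ModuleDict attribute name:
# ['layers.fc1.bias', 'layers.fc1.weight', 'layers.fc2.bias', 'layers.fc2.weight']
```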
⚠️ Common Pitfalls

Common mistakes when using nn.ModuleDict include:

  • Using a plain Python dict instead of nn.ModuleDict, so the layers are never registered and their parameters are invisible to the optimizer and state_dict.
  • Using non-string keys, which nn.ModuleDict does not support.
  • Accessing a module in forward without calling it, e.g. self.layers['fc1'] instead of self.layers['fc1'](x).
```python
import torch.nn as nn

# Wrong: a plain dict - the modules won't be registered
layers_wrong = {
    'fc1': nn.Linear(10, 20),
    'relu': nn.ReLU()
}

# Right: nn.ModuleDict registers the modules
layers_right = nn.ModuleDict({
    'fc1': nn.Linear(10, 20),
    'relu': nn.ReLU()
})
```
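The registration difference is easy to verify by wrapping each version in a small module and counting parameters (WrongNet and RightNet are illustrative names, not part of any API):

```python
import torch.nn as nn

class WrongNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain dict: the Linear layer is never registered
        self.layers = {'fc1': nn.Linear(10, 20)}

class RightNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleDict: the Linear layer is registered as a submodule
        self.layers = nn.ModuleDict({'fc1': nn.Linear(10, 20)})

print(len(list(WrongNet().parameters())))  # 0 - an optimizer would have nothing to train
print(len(list(RightNet().parameters())))  # 2 - weight and bias
```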
📊 Quick Reference

nn.ModuleDict Cheat Sheet:

  • Initialize with a dict of modules: nn.ModuleDict({'name': module})
  • Access modules by key: self.layers['name'](input)
  • Supports iteration: for name, module in self.layers.items():
  • Only string keys allowed
  • Modules are properly registered for training and saving
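The iteration bullet above deserves a concrete sketch: nn.ModuleDict preserves insertion order, so a forward pass can loop over the entries instead of naming each layer (IterNet is an illustrative name):

```python
import torch
import torch.nn as nn

class IterNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(10, 20),
            'relu': nn.ReLU(),
            'fc2': nn.Linear(20, 5)
        })

    def forward(self, x):
        # Insertion order is preserved: fc1 -> relu -> fc2
        for name, module in self.layers.items():
            x = module(x)
        return x

model = IterNet()
out = model(torch.randn(1, 10))
print(out.shape)  # torch.Size([1, 5])
```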

Key Takeaways

  • Use nn.ModuleDict to store layers under string keys so PyTorch tracks them properly.
  • Access modules by key and call them like functions inside the forward method.
  • Do not use a plain Python dict for modules, or PyTorch won't register them.
  • Only string keys are allowed in nn.ModuleDict.
  • You can iterate over nn.ModuleDict like a normal dictionary.