How to Use model.train() in PyTorch for Training Mode
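Use model.train() to set your model to training mode. In this mode, dropout layers randomly zero activations, and batch normalization layers normalize with per-batch statistics while updating their running estimates. Call it before your training loop, and again after any evaluation pass, so the model behaves as intended during training.
Syntax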
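The model.train() method switches the model to training mode by setting the training attribute to True on the model and, recursively, on all of its submodules. This matters for layers such as dropout and batch normalization, which behave differently during training and evaluation. Layers without mode-dependent behavior, such as nn.Linear, are unaffected.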
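- model: your neural network instance (usually a subclass of torch.nn.Module).
- train(): a method that sets the model to training mode. It takes an optional mode argument (model.train(False) is equivalent to model.eval()) and returns the model itself.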
model.train()
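A quick way to see what the call actually changes is to inspect the training flag before and after. The sketch below uses a small nn.Sequential model chosen only for illustration; the flag, the recursive behavior, and the return value are standard nn.Module behavior.

```python
import torch.nn as nn

# Illustrative model with a dropout submodule
model = nn.Sequential(nn.Linear(10, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))

# Every nn.Module has a boolean `training` flag; new modules start with it set to True
print(model.training)  # True

# eval() is shorthand for train(False) and flips the flag on every submodule
model.eval()
print([m.training for m in model.modules()])  # [False, False, False, False]

# train() sets the flag back to True recursively and returns the module itself
same = model.train()
print(same is model)  # True
print(all(m.training for m in model.modules()))  # True
```

Because train() returns the module, it can be chained, for example model.train().to(device), though calling it as a standalone statement is the more common style.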
Example
This example shows a simple training loop where model.train() is called before training to enable training-specific behaviors like dropout.
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model with dropout
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)

# Create model, loss, optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy input and target
inputs = torch.randn(5, 10)
targets = torch.tensor([0, 1, 0, 1, 0])

# Set model to training mode
model.train()

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)

# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")
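The example above runs a single optimization step. In a real loop you usually alternate training with a validation pass, which means model.train() has to be called again at the start of every epoch. The sketch below reuses SimpleModel and assumes synthetic random data, three epochs, and full-batch updates purely for illustration; it also wraps validation in torch.no_grad(), a common companion to model.eval() that the original example does not use.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Same architecture as the example: dropout followed by a linear layer
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.linear(self.dropout(x))

model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Synthetic data (assumption: stands in for a real dataset)
train_x, train_y = torch.randn(32, 10), torch.randint(0, 2, (32,))
val_x, val_y = torch.randn(16, 10), torch.randint(0, 2, (16,))

val_losses = []
for epoch in range(3):
    # Re-enter training mode each epoch: the validation pass below
    # leaves the model in eval mode
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(train_x), train_y)
    loss.backward()
    optimizer.step()

    # Validation: eval mode disables dropout; no_grad skips building the autograd graph
    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(val_x), val_y).item()
    val_losses.append(val_loss)
    print(f"epoch {epoch}: train loss {loss.item():.4f}, val loss {val_loss:.4f}")
```

If the model.train() call at the top of the loop were removed, only the first epoch would train with dropout active; every later epoch would silently train in eval mode.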
Common Pitfalls
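One common mistake is forgetting to call model.train() again after switching to evaluation mode. Modules are created in training mode, so a freshly built model trains correctly even without the call. However, once model.eval() has been called (for example, for a validation pass at the end of an epoch), the model stays in evaluation mode until model.train() is called. Training in that state disables dropout and keeps batch normalization on its running statistics, which can degrade training results.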
Another mistake is calling model.eval() during training. Evaluation mode disables dropout and makes batch normalization use, but stop updating, its running statistics, which is not suitable for training.
import torch
import torch.nn as nn

model = nn.Dropout(p=0.5)

# Wrong: model is in eval mode during training
model.eval()
print(f"Output in eval mode: {model(torch.ones(5))}")  # No dropout applied: all ones

# Right: model in train mode during training
model.train()
print(f"Output in train mode: {model(torch.ones(5))}")  # Dropout applied: some zeros, survivors scaled to 2.0
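The dropout demo above covers one of the two layer types. Batch normalization differs in a less visible way: its running statistics are updated only in training mode. The sketch below assumes a BatchNorm1d layer with the default momentum of 0.1 and a synthetic batch shifted to have a mean near 5.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# BatchNorm1d starts with running_mean = 0 and running_var = 1 for each feature
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) + 5.0  # synthetic batch whose per-feature mean is far from zero

# Eval mode: normalizes with the running statistics and leaves them unchanged
bn.eval()
bn(x)
mean_after_eval = bn.running_mean.clone()
print(mean_after_eval)  # tensor([0., 0., 0.])

# Train mode: normalizes with this batch's statistics and updates the running estimate
# as running_mean = (1 - momentum) * running_mean + momentum * batch_mean
bn.train()
bn(x)
print(bn.running_mean)  # roughly 0.1 * x.mean(dim=0), about 0.5 per feature
```

This is why accidentally evaluating in training mode can also corrupt a model: every validation batch would nudge the running statistics.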
Quick Reference
Remember these tips when using model.train():
- Call model.train() before your training loop, and again after any evaluation pass.
- Use model.eval() when evaluating or testing your model.
- Training mode makes dropout randomly zero activations and makes batch normalization normalize with batch statistics while updating its running estimates.
- Evaluation mode disables dropout and makes batch normalization use its stored running statistics.
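One way to avoid forgetting to switch back is to scope evaluation so the previous mode is restored automatically. The sketch below uses a hypothetical helper, evaluating, that is not part of PyTorch; it relies only on the standard training flag and train(mode).

```python
from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def evaluating(model: nn.Module):
    """Hypothetical helper: run a block in eval mode, then restore the previous mode."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        # Restore the top-level mode (applied recursively to all submodules)
        model.train(was_training)

model = nn.Sequential(nn.Linear(10, 2), nn.Dropout(p=0.5))
model.train()

with evaluating(model), torch.no_grad():
    preds = model(torch.randn(4, 10))
    inside = model.training
    print(inside)  # False inside the block

print(model.training)  # True again afterwards
```

One design caveat: the helper restores a single flag for the whole model, so any submodule that was deliberately left in a different mode beforehand ends up in the top-level mode afterwards.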