How to Use collate_fn in PyTorch DataLoader
Use collate_fn in PyTorch's DataLoader to customize how individual data samples are combined into a batch. Pass a function to collate_fn that takes a list of samples and returns a batch in the desired format.
Syntax
The collate_fn parameter in DataLoader expects a function that takes a list of data samples and returns a batch. This function controls how samples are merged into a batch.
- collate_fn: A callable that processes a list of samples into a batch.
- DataLoader(..., collate_fn=your_function): Pass your custom function here.
```python
def collate_fn(batch):
    # batch is a list of samples
    # process and return a batch
    processed_batch = batch  # example placeholder
    return processed_batch

loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
```
Example
This example shows how to use collate_fn to pad sequences of different lengths in a batch so they can be processed together.
```python
import torch
from torch.utils.data import DataLoader

# Sample dataset with variable-length sequences
dataset = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]),
           torch.tensor([6]), torch.tensor([7, 8, 9, 10])]

# Custom collate_fn to pad sequences to the max length in the batch
def pad_collate(batch):
    max_len = max(x.size(0) for x in batch)
    padded_batch = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, seq in enumerate(batch):
        padded_batch[i, :seq.size(0)] = seq
    return padded_batch

loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)
for batch in loader:
    print(batch)
```
Output
tensor([[1, 2, 3],
        [4, 5, 0]])
tensor([[ 6, 0, 0, 0],
[ 7, 8, 9, 10]])
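PyTorch also ships a helper that performs this padding for you: torch.nn.utils.rnn.pad_sequence. A sketch of an equivalent loader built on it (same dataset as above):

```python
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

dataset = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]),
           torch.tensor([6]), torch.tensor([7, 8, 9, 10])]

def pad_collate(batch):
    # batch_first=True gives shape (batch, max_len); pads with 0 by default
    return pad_sequence(batch, batch_first=True)

loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 3]) then torch.Size([2, 4])
```

Note that each batch is padded only to the longest sequence it contains, which keeps the padding (and wasted computation) minimal.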
Common Pitfalls
- Not returning the correct batch format causes errors during training.
- Forgetting to handle variable-length data properly can break batch processing.
- Using the default collate_fn when samples contain custom data types or structures it cannot handle typically raises a TypeError.
Always ensure your collate_fn returns a batch compatible with your model input.
```python
import torch
from torch.utils.data import DataLoader

dataset = [(1, 2), (3, 4)]

# Wrong collate_fn: returns a list instead of a tensor batch
def wrong_collate(batch):
    return batch

# Correct collate_fn: converts the list of tuples to a tensor batch
def right_collate(batch):
    return torch.tensor(batch)

loader_wrong = DataLoader(dataset, batch_size=2, collate_fn=wrong_collate)
loader_right = DataLoader(dataset, batch_size=2, collate_fn=right_collate)
print('Wrong batch:', next(iter(loader_wrong)))
print('Right batch:', next(iter(loader_right)))
```
Output
Wrong batch: [(1, 2), (3, 4)]
Right batch: tensor([[1, 2],
[3, 4]])
Quick Reference
- Purpose: Customize how samples are combined into batches.
- Input: List of samples from dataset.
- Output: Batch in desired format (tensor, dict, padded sequences, etc.).
- Use cases: Padding sequences, combining dicts, handling complex data.
- Pass to DataLoader: DataLoader(dataset, collate_fn=your_function).
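For the "combining dicts" use case, one possible sketch: each sample is a dict with a feature tensor and an integer label (the keys "x" and "label" here are hypothetical, chosen for illustration), and the collate_fn stacks each field across the batch:

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical dataset: each sample is a dict with a feature tensor and a label
dataset = [{"x": torch.tensor([1.0, 2.0]), "label": 0},
           {"x": torch.tensor([3.0, 4.0]), "label": 1}]

def dict_collate(batch):
    # Stack each field across the batch into a single tensor
    return {
        "x": torch.stack([sample["x"] for sample in batch]),
        "label": torch.tensor([sample["label"] for sample in batch]),
    }

loader = DataLoader(dataset, batch_size=2, collate_fn=dict_collate)
batch = next(iter(loader))
print(batch["x"].shape, batch["label"])  # torch.Size([2, 2]) tensor([0, 1])
```

Returning a dict keeps field names attached to their tensors, so the training loop can index the batch by key instead of by position.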
Key Takeaways
- Use collate_fn to control how DataLoader batches samples together.
- collate_fn receives a list of samples and returns a processed batch.
- A custom collate_fn is essential for variable-length or complex data.
- Always return batches compatible with your model input format.
- Test collate_fn separately to avoid batch processing errors.
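Because collate_fn is just a function that takes a list of samples, you can call it directly on a hand-built list and assert on the result before wiring it into a DataLoader. A minimal sketch, reusing the padding logic from the example above:

```python
import torch

def pad_collate(batch):
    max_len = max(x.size(0) for x in batch)
    padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, seq in enumerate(batch):
        padded[i, :seq.size(0)] = seq
    return padded

# Call the function directly on a hand-built list of samples
samples = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
out = pad_collate(samples)
assert out.shape == (2, 3)
assert out.tolist() == [[1, 2, 3], [4, 5, 0]]
```

A few assertions like these catch shape and padding mistakes immediately, rather than as opaque errors deep inside a training loop.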