
How to Use collate_fn in PyTorch DataLoader

Use collate_fn in PyTorch's DataLoader to customize how individual data samples are combined into a batch. Pass a function to collate_fn that takes a list of samples and returns a batch in the desired format.
📐

Syntax

The collate_fn parameter of DataLoader expects a callable that receives the list of samples drawn for one batch and merges them into a single batch object.

  • collate_fn: A callable that processes a list of samples into a batch.
  • DataLoader(..., collate_fn=your_function): Pass your custom function here.
python
from torch.utils.data import DataLoader

def collate_fn(batch):
    # batch is a list of individual samples from the dataset
    processed_batch = batch  # placeholder: return samples unchanged
    return processed_batch

loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
💻

Example

This example shows how to use collate_fn to pad sequences of different lengths in a batch so they can be processed together.

python
import torch
from torch.utils.data import DataLoader

# Sample dataset with variable-length sequences
dataset = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]),
           torch.tensor([6]), torch.tensor([7, 8, 9, 10])]

# Custom collate_fn to pad sequences to the max length in batch
def pad_collate(batch):
    max_len = max(x.size(0) for x in batch)
    padded_batch = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, seq in enumerate(batch):
        padded_batch[i, :seq.size(0)] = seq
    return padded_batch

loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)

for batch in loader:
    print(batch)
Output
tensor([[1, 2, 3],
        [4, 5, 0]])
tensor([[ 6,  0,  0,  0],
        [ 7,  8,  9, 10]])

Note that padding is per batch: the first batch only pads to length 3 (its longest sequence), while the second pads to length 4.
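A common extension of this pattern, sketched below, is to return each sequence's true length alongside the padded tensor so downstream code (masking, `torch.nn.utils.rnn.pack_padded_sequence`, etc.) can ignore the padding positions. The function name `pad_collate_with_lengths` is illustrative, not a PyTorch API.

```python
import torch
from torch.utils.data import DataLoader

dataset = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]),
           torch.tensor([6]), torch.tensor([7, 8, 9, 10])]

# Hypothetical variant of pad_collate: also return the original
# sequence lengths so padding positions can be masked out later.
def pad_collate_with_lengths(batch):
    lengths = torch.tensor([x.size(0) for x in batch])
    max_len = int(lengths.max())
    padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, seq in enumerate(batch):
        padded[i, :seq.size(0)] = seq
    return padded, lengths

loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate_with_lengths)
padded, lengths = next(iter(loader))
print(padded)   # tensor([[1, 2, 3], [4, 5, 0]])
print(lengths)  # tensor([3, 2])
```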
⚠️

Common Pitfalls

  • Not returning the correct batch format causes errors during training.
  • Forgetting to handle variable-length data properly can break batch processing.
  • Relying on the default collate_fn with unsupported sample types, or with tensors of unequal shapes, raises errors (e.g., stacking tensors of different sizes fails).

Always ensure your collate_fn returns a batch compatible with your model input.

python
import torch
from torch.utils.data import DataLoader

dataset = [(1, 2), (3, 4)]

# Wrong collate_fn: returns a list instead of tensor batch
def wrong_collate(batch):
    return batch

# Correct collate_fn: converts list of tuples to tensor batch
def right_collate(batch):
    return torch.tensor(batch)

loader_wrong = DataLoader(dataset, batch_size=2, collate_fn=wrong_collate)
loader_right = DataLoader(dataset, batch_size=2, collate_fn=right_collate)

print('Wrong batch:', next(iter(loader_wrong)))
print('Right batch:', next(iter(loader_right)))
Output
Wrong batch: [(1, 2), (3, 4)]
Right batch: tensor([[1, 2],
        [3, 4]])
📊

Quick Reference

  • Purpose: Customize how samples are combined into batches.
  • Input: List of samples from dataset.
  • Output: Batch in desired format (tensor, dict, padded sequences, etc.).
  • Use cases: Padding sequences, combining dicts, handling complex data.
  • Pass to DataLoader: DataLoader(dataset, collate_fn=your_function).
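The "combining dicts" use case from the list above can be sketched as follows. The field names `image` and `label` are made up for illustration; note that PyTorch's default collate already handles dicts of equal-shaped tensors, so a custom version like this is mainly useful when you need explicit control over how each field is merged.

```python
import torch
from torch.utils.data import DataLoader

# Illustrative dataset: each sample is a dict (field names are hypothetical)
dataset = [{"image": torch.randn(3, 8, 8), "label": 0},
           {"image": torch.randn(3, 8, 8), "label": 1}]

# Merge a list of dicts into one dict of batched tensors
def dict_collate(batch):
    return {
        "image": torch.stack([sample["image"] for sample in batch]),
        "label": torch.tensor([sample["label"] for sample in batch]),
    }

loader = DataLoader(dataset, batch_size=2, collate_fn=dict_collate)
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([2, 3, 8, 8])
print(batch["label"])        # tensor([0, 1])
```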

Key Takeaways

  • Use collate_fn to control how DataLoader batches samples together.
  • collate_fn receives a list of samples and returns a processed batch.
  • A custom collate_fn is essential for variable-length or complex data.
  • Always return batches compatible with your model's input format.
  • Test collate_fn separately to avoid batch processing errors.
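Per the last takeaway: because a collate_fn is just a plain function, you can unit-test it on a hand-built list of samples before wiring it into a DataLoader. A minimal sketch, reusing the pad_collate function from the example above:

```python
import torch

def pad_collate(batch):
    max_len = max(x.size(0) for x in batch)
    padded_batch = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, seq in enumerate(batch):
        padded_batch[i, :seq.size(0)] = seq
    return padded_batch

# Call the collate_fn directly on a hand-built "batch"
# (a plain list of samples) -- no DataLoader needed.
samples = [torch.tensor([1, 2, 3]), torch.tensor([4])]
out = pad_collate(samples)
assert out.shape == (2, 3)
assert out.tolist() == [[1, 2, 3], [4, 0, 0]]
print(out)
```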