How to Concatenate Tensors in PyTorch: Syntax and Examples
In PyTorch, you concatenate tensors using the torch.cat function, which joins a sequence of tensors along a specified dimension. You must ensure the tensors have the same shape except in the concatenation dimension.
Syntax
The basic syntax to concatenate tensors in PyTorch is:
```python
torch.cat(tensors, dim=0, *, out=None)
```
Here:
- tensors is a sequence (such as a list or tuple) of tensors to join.
- dim is the dimension along which to concatenate. It defaults to 0.
- out is an optional, keyword-only tensor to write the result into.
All tensors must have the same shape except in the dim dimension.
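The sketch below illustrates two less obvious parts of the signature. A negative dim counts backward from the last dimension, and the keyword-only out argument writes the result into a tensor you allocated beforehand. The tensor names and shapes are arbitrary choices for illustration.

```python
import torch

a = torch.ones(2, 3)
b = torch.zeros(2, 2)

# dim=-1 means "last dimension", which is dim=1 for these 2D tensors
c = torch.cat([a, b], dim=-1)
print(c.shape)  # torch.Size([2, 5])

# out= must be passed by keyword; here it is preallocated with the result's shape
buf = torch.empty(2, 5)
torch.cat([a, b], dim=1, out=buf)
print(torch.equal(buf, c))  # True
```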
Example
This example shows how to concatenate two 2D tensors along rows (dim=0) and columns (dim=1).
```python
import torch

# Create two 2x3 tensors
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])

# Concatenate along rows (dim=0)
concat_dim0 = torch.cat((tensor1, tensor2), dim=0)

# Concatenate along columns (dim=1)
concat_dim1 = torch.cat((tensor1, tensor2), dim=1)

print('Concatenate along dim=0 (rows):')
print(concat_dim0)
print('\nConcatenate along dim=1 (columns):')
print(concat_dim1)
```
Output
```text
Concatenate along dim=0 (rows):
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

Concatenate along dim=1 (columns):
tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]])
```
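The same rule applies to more than two tensors and to higher-dimensional tensors. Only the size along dim may differ, and the result's size in that dimension is the sum of the inputs' sizes. The shapes below, including the (batch, channels, length) layout, are illustrative assumptions.

```python
import torch

# Three tensors that agree in every dimension except dim 0 (sizes 1, 2, 3)
x = torch.zeros(1, 4)
y = torch.ones(2, 4)
z = torch.full((3, 4), 2.0)

rows = torch.cat([x, y, z], dim=0)
print(rows.shape)  # torch.Size([6, 4]), i.e. 1 + 2 + 3 rows

# 3D example: shape (batch, channels, length), joined along channels (dim=1)
p = torch.randn(8, 3, 16)
q = torch.randn(8, 5, 16)
channels = torch.cat([p, q], dim=1)
print(channels.shape)  # torch.Size([8, 8, 16])
```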
Common Pitfalls
Common mistakes when concatenating tensors include:
- Trying to concatenate tensors with mismatched shapes in non-concatenation dimensions.
- Using a dim value outside the valid range, which is -n to n - 1 for n-dimensional tensors.
- Passing a single tensor instead of a sequence of tensors.
Always check tensor shapes before concatenation.
```python
import torch

t1 = torch.randn(2, 3)
t2 = torch.randn(2, 4)

# Incorrect: dim=0 requires every other dimension to match,
# but the dim=1 sizes differ (3 vs 4)
try:
    torch.cat((t1, t2), dim=0)
except RuntimeError as e:
    print('Error:', e)

# Correct: concatenate along dim=1, the only dimension whose sizes differ
result = torch.cat((t1, t2), dim=1)
print('Concatenation along dim=1 works:', result.shape)
```
Output (the exact error wording varies between PyTorch versions)
```text
Error: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 4 for tensor number 1 in the list.
Concatenation along dim=1 works: torch.Size([2, 7])
```
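To act on the advice to check shapes first, you could write a small validation helper that mirrors the three pitfalls listed above. check_cat_compatible is a hypothetical name and is not part of PyTorch. Its logic is a simplified version of torch.cat's rules, so it may be stricter than torch.cat in some edge cases, such as empty tensors.

```python
import torch
from typing import Sequence


def check_cat_compatible(tensors: Sequence[torch.Tensor], dim: int = 0) -> None:
    """Hypothetical helper: raise an error if torch.cat(tensors, dim) is likely to fail.

    Simplified check; it does not model every edge case torch.cat accepts.
    """
    # Pitfall: passing a single tensor instead of a list or tuple
    if isinstance(tensors, torch.Tensor):
        raise TypeError("pass a list or tuple of tensors, not a single tensor")
    if len(tensors) == 0:
        raise ValueError("need at least one tensor to concatenate")

    ndim = tensors[0].dim()
    # Pitfall: dim outside the valid range [-ndim, ndim - 1]
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim={dim} is out of range for {ndim}-D tensors")
    dim = dim % ndim  # normalize a negative dim, e.g. -1 -> ndim - 1

    # Pitfall: mismatched sizes in any dimension other than dim
    ref = tensors[0].shape
    for i, t in enumerate(tensors[1:], start=1):
        if t.dim() != ndim:
            raise ValueError(f"tensor {i} has {t.dim()} dims, expected {ndim}")
        for d in range(ndim):
            if d != dim and t.shape[d] != ref[d]:
                raise ValueError(
                    f"tensor {i} has size {t.shape[d]} in dimension {d}, expected {ref[d]}"
                )


t1 = torch.randn(2, 3)
t2 = torch.randn(2, 4)
check_cat_compatible([t1, t2], dim=1)  # passes: only dim 1 differs
try:
    check_cat_compatible([t1, t2], dim=0)
except ValueError as e:
    print("Caught before calling torch.cat:", e)
```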
Quick Reference
Tips for concatenating tensors in PyTorch:
- Use torch.cat to join tensors along an existing dimension.
- All tensors must have the same shape except in the concatenation dimension.
- Check tensor shapes with .shape before concatenation.
- Use dim to specify the axis (0 for rows, 1 for columns in 2D tensors).
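Here is a concrete example of "existing dimension". torch.cat never adds a dimension, so the result has as many dimensions as its inputs. For contrast, the sketch also uses torch.stack, which inserts a new dimension. It then shows that stacking behaves like unsqueezing each tensor and concatenating the results.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

catted = torch.cat([a, b], dim=0)     # joins along the existing dim 0
stacked = torch.stack([a, b], dim=0)  # inserts a new dim 0

print(catted.shape)   # torch.Size([4, 3])
print(stacked.shape)  # torch.Size([2, 2, 3])

# stack(ts, dim) behaves like cat after unsqueezing each tensor at dim
manual = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)
print(torch.equal(manual, stacked))  # True
```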
Key Takeaways
Use torch.cat to concatenate tensors along a specified dimension.
Ensure all tensors have matching shapes except in the concatenation dimension.
Specify the dimension with the dim parameter, starting from 0.
Check tensor shapes before concatenation to avoid runtime errors.
Pass a sequence of tensors, not a single tensor, to torch.cat.
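Here is the last takeaway in practice. A common pattern is to collect tensors in a Python list, such as per-batch results in a loop, and then call torch.cat once on that list. Passing a bare tensor instead of a list or tuple is rejected with a TypeError. The loop and the torch.full placeholders are illustrative stand-ins for real data.

```python
import torch

# Collect per-batch results in a list, then concatenate once at the end
chunks = []
for step in range(3):
    batch = torch.full((2, 4), float(step))  # stand-in for a real per-batch result
    chunks.append(batch)

all_results = torch.cat(chunks, dim=0)
print(all_results.shape)  # torch.Size([6, 4])

# Mistake: a single tensor is not a sequence of tensors, so this raises TypeError
try:
    torch.cat(all_results, dim=0)
except TypeError as e:
    print('TypeError:', e)
```

Concatenating once after the loop also avoids re-copying the growing result on every iteration. That re-copying is what happens if you call torch.cat inside the loop.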