
How to Concatenate Tensors in PyTorch: Syntax and Examples

In PyTorch, you concatenate tensors with the torch.cat function, which joins a sequence of tensors along an existing dimension. The tensors must have the same shape in every dimension except the one you concatenate along.

Syntax

The signature of torch.cat is:

```python
torch.cat(tensors, dim=0, *, out=None)
```

Here, tensors is a sequence (such as a list or tuple) of tensors to join, and dim is the dimension along which to concatenate. dim defaults to 0, and negative values count backward from the last dimension. The optional out argument is keyword-only (everything after * must be passed by name) and receives the result.

All tensors must have the same shape except in the dim dimension.
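As a quick check of the shape rule, here is a minimal sketch with two 3D tensors whose shapes were picked arbitrarily for illustration. They agree in every dimension except the last, so they can be joined along dim=-1 (equivalent to dim=2 for 3D tensors), and the result's size in that dimension is the sum of the inputs' sizes, 4 + 5 = 9.

```python
import torch

# Two 3D tensors that agree in dims 0 and 1 but differ in dim 2
a = torch.zeros(2, 3, 4)
b = torch.ones(2, 3, 5)

# Concatenate along the last dimension; dim=-1 means dim=2 for 3D tensors
c = torch.cat([a, b], dim=-1)
print(c.shape)  # torch.Size([2, 3, 9])

# The first 4 entries along dim 2 come from a, the remaining 5 from b
print(torch.equal(c[..., :4], a), torch.equal(c[..., 4:], b))  # True True
```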

Example

This example shows how to concatenate two 2D tensors along rows (dim=0) and columns (dim=1).

```python
import torch

# Create two 2x3 tensors
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])

# Concatenate along rows (dim=0)
concat_dim0 = torch.cat((tensor1, tensor2), dim=0)

# Concatenate along columns (dim=1)
concat_dim1 = torch.cat((tensor1, tensor2), dim=1)

print('Concatenate along dim=0 (rows):')
print(concat_dim0)
print('\nConcatenate along dim=1 (columns):')
print(concat_dim1)
```

Output

```
Concatenate along dim=0 (rows):
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

Concatenate along dim=1 (columns):
tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]])
```
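torch.cat also accepts more than two tensors. One common pattern, shown here as a sketch with made-up chunk sizes, is to collect tensors in a Python list (for example, one per loop iteration) and join them with a single call at the end:

```python
import torch

# Collect chunks that all have 3 columns but different numbers of rows
chunks = []
for n_rows in (1, 2, 3):
    # Each chunk is filled with its own row count so the origin is visible
    chunks.append(torch.full((n_rows, 3), n_rows, dtype=torch.int64))

# A single call joins every chunk along dim=0
combined = torch.cat(chunks, dim=0)
print(combined.shape)  # torch.Size([6, 3])
print(combined[:, 0])  # tensor([1, 2, 2, 3, 3, 3])
```

Calling torch.cat once at the end, rather than inside the loop, avoids allocating and copying a progressively larger tensor on every iteration.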

Common Pitfalls

Common mistakes when concatenating tensors include:

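  • Concatenating tensors whose shapes differ in a dimension other than the concatenation dimension.
  • Passing a dim outside the valid range: for n-dimensional tensors, dim must lie between -n and n-1.
  • Passing a single tensor, or several tensors as separate arguments such as torch.cat(t1, t2), instead of one sequence such as torch.cat((t1, t2)).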

Always check tensor shapes before concatenation.

```python
import torch

t1 = torch.randn(2, 3)
t2 = torch.randn(2, 4)

# Incorrect: with dim=0, every other dimension must match,
# but the dim=1 sizes are 3 and 4
try:
    torch.cat((t1, t2), dim=0)
except RuntimeError as e:
    print('Error:', e)

# Correct: with dim=1, only the dim=0 sizes must match (both are 2)
result = torch.cat((t1, t2), dim=1)
print('Concatenation along dim=1 works:', result.shape)
```

Output

```
Error: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 4 for tensor number 1 in the list.
Concatenation along dim=1 works: torch.Size([2, 7])
```

The exact wording of the error message may differ between PyTorch versions.
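The other two pitfalls from the list can be reproduced the same way. The sketch below triggers each one and catches the exception; because this article does not pin down the exact exception types, it catches a small set of plausible ones rather than relying on a single type.

```python
import torch

t1 = torch.randn(2, 3)
t2 = torch.randn(2, 3)

# Pitfall: dim out of range. 2D tensors only have dims 0 and 1 (or -2 and -1)
try:
    torch.cat((t1, t2), dim=2)
except (IndexError, RuntimeError) as e:
    print('Bad dim:', type(e).__name__)

# Pitfall: the first argument is a single tensor, not a sequence,
# and t2 ends up in the position where dim is expected
try:
    torch.cat(t1, t2)
except (TypeError, RuntimeError) as e:
    print('Bad arguments:', type(e).__name__)

# Correct: wrap the tensors in a tuple (or list) and use a valid dim
print(torch.cat((t1, t2), dim=1).shape)  # torch.Size([2, 6])
```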

Quick Reference

Tips for concatenating tensors in PyTorch:

  • Use torch.cat to join tensors along an existing dimension.
  • All tensors must have the same shape except in the concatenation dimension.
  • Check tensor shapes with .shape before concatenation.
  • Use dim to specify the axis (0 for rows, 1 for columns in 2D tensors).
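The tip about checking shapes can be wrapped in a small pre-flight function. check_cat_shapes below is a hypothetical helper written for illustration, not a PyTorch API: it verifies that every tensor has the same number of dimensions as the first one and the same size in every dimension except dim, and raises a ValueError naming the first mismatch.

```python
import torch


def check_cat_shapes(tensors, dim=0):
    """Hypothetical helper (not part of PyTorch): raise ValueError if the
    tensors cannot be concatenated along dim due to a rank or shape mismatch."""
    if not isinstance(tensors, (list, tuple)) or len(tensors) == 0:
        raise ValueError('expected a non-empty list or tuple of tensors')
    ref = tensors[0].shape
    ndim = len(ref)
    if not -ndim <= dim < ndim:
        raise ValueError(f'dim={dim} is out of range for {ndim}-D tensors')
    dim = dim % ndim  # normalize a negative dim to its positive index
    for i, t in enumerate(tensors[1:], start=1):
        if t.dim() != ndim:
            raise ValueError(f'tensor {i} has {t.dim()} dims, expected {ndim}')
        for d in range(ndim):
            # Every dimension except the concatenation one must match
            if d != dim and t.shape[d] != ref[d]:
                raise ValueError(
                    f'tensor {i} has size {t.shape[d]} in dim {d}, expected {ref[d]}'
                )


t1 = torch.randn(2, 3)
t2 = torch.randn(2, 4)

check_cat_shapes([t1, t2], dim=1)  # passes: only dim 1 differs
try:
    check_cat_shapes([t1, t2], dim=0)
except ValueError as e:
    print(e)  # tensor 1 has size 4 in dim 1, expected 3
```

torch.cat already performs these checks itself, so a helper like this is mainly useful when you want a custom error message or want to validate inputs before doing expensive work. It is intentionally strict and may reject edge cases that torch.cat accepts.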

Key Takeaways

  • Use torch.cat to concatenate tensors along a specified dimension.
  • Ensure all tensors have matching shapes except in the concatenation dimension.
  • Specify the dimension with the dim parameter, starting from 0.
  • Check tensor shapes before concatenation to avoid runtime errors.
  • Pass a sequence of tensors, not a single tensor, to torch.cat.