
How to Stack Tensors in PyTorch: Syntax and Examples

In PyTorch, you can stack tensors with torch.stack(tensors, dim=0), which joins a sequence of tensors along a new dimension inserted at position dim. All tensors must have the same shape for the stack to succeed.

Syntax

The basic syntax to stack tensors in PyTorch is:

torch.stack(tensors, dim=0)
  • tensors: A sequence (like a list or tuple) of tensors to stack.
  • dim: The index at which the new dimension is inserted. Default is 0. For inputs with k dimensions, valid values range from -(k+1) to k; negative values count from the end.

All tensors must have the same shape for stacking to work.
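To see how dim controls the output shape, the sketch below stacks two 2D tensors at every valid position. The (3, 4) shape is an arbitrary choice for illustration; because the inputs have k = 2 dimensions, dim can range from -3 to 2.

```python
import torch

# Two 2D tensors with the same (arbitrary) shape (3, 4)
a = torch.zeros(3, 4)
b = torch.ones(3, 4)

# For k = 2 input dims, valid dim values are -(k+1)..k, i.e. -3..2
for dim in range(-3, 3):
    out = torch.stack([a, b], dim=dim)
    # A new axis of length 2 (the number of stacked tensors) appears at `dim`
    print(f"dim={dim:>2}: shape {tuple(out.shape)}")

# Expected output:
# dim=-3: shape (2, 3, 4)
# dim=-2: shape (3, 2, 4)
# dim=-1: shape (3, 4, 2)
# dim= 0: shape (2, 3, 4)
# dim= 1: shape (3, 2, 4)
# dim= 2: shape (3, 4, 2)
```

A negative dim d behaves like d + k + 1, so dim=-1 here is the same as dim=2.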


Example

This example stacks three 1D tensors, first along a new dimension at index 0 and then along a new dimension at index 1.

python
import torch

# Create three 1D tensors
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
tensor3 = torch.tensor([7, 8, 9])

# Stack along dimension 0 (new outer dimension)
stacked_dim0 = torch.stack([tensor1, tensor2, tensor3], dim=0)

# Stack along dimension 1 (new inner dimension)
stacked_dim1 = torch.stack([tensor1, tensor2, tensor3], dim=1)

print('Stacked along dim=0:')
print(stacked_dim0)
print('Shape:', stacked_dim0.shape)

print('\nStacked along dim=1:')
print(stacked_dim1)
print('Shape:', stacked_dim1.shape)
Output
Stacked along dim=0:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
Shape: torch.Size([3, 3])

Stacked along dim=1:
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
Shape: torch.Size([3, 3])
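One way to see what "along a new dimension" means: stacking along dim gives the same result as first inserting a size-1 axis at dim in each tensor (with unsqueeze) and then concatenating along that axis with torch.cat. This is a conceptual equivalence check, not a description of how torch.stack is implemented internally.

```python
import torch

tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]

for dim in (0, 1):
    stacked = torch.stack(tensors, dim=dim)
    # Give every tensor a size-1 axis at `dim`, then join along that axis
    manual = torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)
    print(dim, torch.equal(stacked, manual))  # prints True for both dims
```

With dim=1 each input becomes a (3, 1) column, which is why the tensors appear as columns in the dim=1 result.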

Common Pitfalls

1. Shape mismatch: All tensors must have exactly the same shape. Stacking tensors of different shapes raises a RuntimeError.

2. Confusing torch.stack with torch.cat: torch.stack adds a new dimension, while torch.cat concatenates along an existing dimension.

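3. Wrong dimension index: dim must lie in the range [-(k+1), k], where k is the number of dimensions of each input tensor. Negative values are allowed and count from the end; a value outside this range raises an IndexError.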

python
import torch

# Wrong: tensors have different shapes
try:
    t1 = torch.tensor([1, 2])
    t2 = torch.tensor([3, 4, 5])
    torch.stack([t1, t2])
except RuntimeError as e:  # torch.stack raises RuntimeError on a shape mismatch
    print('Error stacking tensors with different shapes:', e)

# Correct: tensors have same shape
t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4])
stacked = torch.stack([t1, t2], dim=0)
print('Correct stacking result:', stacked)
Output
Error stacking tensors with different shapes: stack expects each tensor to be equal size, but got [2] at entry 0 and [3] at entry 1
Correct stacking result: tensor([[1, 2],
        [3, 4]])
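Pitfalls 2 and 3 can be checked side by side. The sketch below redefines the same two 1D tensors, compares the shapes produced by torch.stack and torch.cat, and shows that a negative dim is valid while an out-of-range dim is not.

```python
import torch

t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4])

# torch.stack adds a new dimension: two (2,) tensors -> shape (2, 2)
print(torch.stack([t1, t2]).shape)  # torch.Size([2, 2])

# torch.cat joins along an existing dimension: two (2,) tensors -> shape (4,)
print(torch.cat([t1, t2]).shape)    # torch.Size([4])

# Negative dims are valid: for 1D inputs, dim may be -2, -1, 0 or 1
print(torch.stack([t1, t2], dim=-1))  # same result as dim=1: columns [1, 2] and [3, 4]

# A dim outside [-2, 1] raises an IndexError
try:
    torch.stack([t1, t2], dim=2)
except IndexError as e:
    print('Out-of-range dim:', e)
```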

Quick Reference

Tips for stacking tensors in PyTorch:

  • Use torch.stack to add a new dimension and combine tensors.
  • All tensors must have the same shape.
  • The dim parameter controls where the new dimension is inserted.
  • Use torch.cat to join tensors along an existing dimension instead.
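A common practical use of torch.stack is turning a list of equally shaped samples into a single batch tensor. The sample count (4) and the image-like (3, 32, 32) shape below are made up for illustration.

```python
import torch

# Hypothetical data: four "images", each shaped (channels, height, width)
samples = [torch.rand(3, 32, 32) for _ in range(4)]

# Stack along a new leading dimension to form a batch: (batch, C, H, W)
batch = torch.stack(samples, dim=0)
print(batch.shape)  # torch.Size([4, 3, 32, 32])

# torch.cat on the same list merges into the existing channel axis instead
print(torch.cat(samples, dim=0).shape)  # torch.Size([12, 32, 32])
```

Stacking keeps each sample addressable as batch[i], which is usually what downstream code expects.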

Key Takeaways

Use torch.stack(tensors, dim) to join tensors along a new dimension.
All tensors must have the same shape to stack successfully.
The dim argument sets the position of the new dimension in the output tensor.
torch.stack adds a new dimension; torch.cat joins along an existing dimension.
Check tensor shapes carefully to avoid stacking errors.
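When tensors come from different sources, it can help to validate their shapes before stacking. The sketch below uses stack_checked, a hypothetical helper (not part of PyTorch) that collects every mismatching entry into one message before delegating to torch.stack.

```python
import torch

def stack_checked(tensors, dim=0):
    """Hypothetical helper: validate shapes, then call torch.stack."""
    if not tensors:
        raise ValueError("need at least one tensor to stack")
    expected = tensors[0].shape
    # Record (index, shape) for every tensor whose shape differs from the first
    bad = [(i, tuple(t.shape)) for i, t in enumerate(tensors) if t.shape != expected]
    if bad:
        raise ValueError(f"expected shape {tuple(expected)}, mismatches at (index, shape): {bad}")
    return torch.stack(tensors, dim=dim)

print(stack_checked([torch.zeros(2), torch.ones(2)]).shape)  # torch.Size([2, 2])

try:
    stack_checked([torch.zeros(2), torch.ones(3), torch.ones(4)])
except ValueError as e:
    print(e)  # lists both mismatching entries, at index 1 and index 2
```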