
How to Unsqueeze a Tensor in PyTorch: Syntax and Examples

In PyTorch, tensor.unsqueeze(dim) inserts a new axis of size 1 at position dim. The result has one more dimension than the input and shares its underlying data, which makes unsqueeze a cheap way to adjust a tensor's shape for batching or broadcasting.

Syntax

The unsqueeze method adds a dimension of size 1 at the specified dim index in the tensor shape.

  • tensor: The original PyTorch tensor.
  • dim: Integer index at which the new dimension is inserted. Valid values run from -tensor.dim()-1 to tensor.dim() inclusive; negative values count from the end.
  • Returns a tensor with one more dimension that shares data with the original.
python
new_tensor = tensor.unsqueeze(dim)
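
To make the dim parameter concrete, here is a minimal sketch on a 2D tensor showing positive and negative indices side by side. The variable names are illustrative only, and the shapes in the comments follow from the documented rule that a negative dim is interpreted as dim + tensor.dim() + 1.

```python
import torch

x = torch.zeros(2, 3)          # shape (2, 3), so x.dim() == 2

# Valid dim values for a 2D tensor run from -3 to 2 inclusive.
front = x.unsqueeze(0)         # (1, 2, 3)
middle = x.unsqueeze(1)        # (2, 1, 3)
back = x.unsqueeze(-1)         # (2, 3, 1), same as x.unsqueeze(2)
also_front = x.unsqueeze(-3)   # (1, 2, 3), same as x.unsqueeze(0)

# The functional form is equivalent to the method form.
same_as_back = torch.unsqueeze(x, -1)

shapes = [tuple(t.shape) for t in (front, middle, back, also_front, same_as_back)]
print(shapes)
```

In practice, unsqueeze(-1) is a convenient way to append a trailing axis without knowing how many dimensions the tensor already has.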

Example

This example shows how to add a new dimension to a 1D tensor to make it 2D by unsqueezing at dimension 0 and dimension 1.

python
import torch

# Original 1D tensor
x = torch.tensor([10, 20, 30])
print('Original tensor:', x)
print('Shape:', x.shape)

# Unsqueeze at dim=0 (adds new outer dimension)
x_unsq0 = x.unsqueeze(0)
print('\nAfter unsqueeze(0):')
print(x_unsq0)
print('Shape:', x_unsq0.shape)

# Unsqueeze at dim=1 (adds new inner dimension)
x_unsq1 = x.unsqueeze(1)
print('\nAfter unsqueeze(1):')
print(x_unsq1)
print('Shape:', x_unsq1.shape)
Output
Original tensor: tensor([10, 20, 30])
Shape: torch.Size([3])

After unsqueeze(0):
tensor([[10, 20, 30]])
Shape: torch.Size([1, 3])

After unsqueeze(1):
tensor([[10],
        [20],
        [30]])
Shape: torch.Size([3, 1])
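
One common reason to create row and column versions of the same vector is broadcasting. The sketch below is an illustration rather than part of the original example: it combines a (3, 1) and a (1, 3) tensor to get every pairwise sum.

```python
import torch

x = torch.tensor([10, 20, 30])

col = x.unsqueeze(1)   # shape (3, 1)
row = x.unsqueeze(0)   # shape (1, 3)

# Broadcasting expands (3, 1) and (1, 3) to a common shape of (3, 3),
# so entry [i, j] holds x[i] + x[j].
pairwise_sum = col + row
print(pairwise_sum.shape)
print(pairwise_sum)
```

Without the two unsqueeze calls, x + x would simply add the vector to itself element by element and stay 1D.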

Common Pitfalls

Common mistakes when using unsqueeze include:

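  • Passing a dim outside the half-open range [-tensor.dim()-1, tensor.dim()+1). For a 1D tensor, the valid values are -2, -1, 0, and 1.
  • Confusing unsqueeze with reshape or view. Those take the full target shape, whereas unsqueeze only inserts a single size-1 axis at a given position.
  • Forgetting that unsqueeze does not modify the original tensor in place, so the result must be assigned to a variable. The result is a view that shares data with the original; unsqueeze_ is the in-place variant.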
python
import torch
x = torch.tensor([1, 2, 3])

# Wrong: dim too large
try:
    x.unsqueeze(3)
except IndexError as e:
    print('Error:', e)

# Correct usage
x_unsq = x.unsqueeze(1)
print('Correct unsqueeze:', x_unsq.shape)
Output
Error: Dimension out of range (expected to be in range of [-2, 1], but got 3)
Correct unsqueeze: torch.Size([3, 1])
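
The point about in-place behavior is easy to misread, so here is a small sketch of what is and is not shared. It assumes nothing beyond standard tensor methods (unsqueeze, unsqueeze_, clone); the variable names are illustrative.

```python
import torch

x = torch.tensor([1, 2, 3])

# unsqueeze leaves x's shape alone but returns a view of the same data.
y = x.unsqueeze(0)
y[0, 0] = 100
print(x)            # x sees the write made through y
print(x.shape)      # still torch.Size([3])

# clone() gives an independent copy when that sharing is not wanted.
z = x.unsqueeze(0).clone()
z[0, 1] = -1
print(x)            # unchanged by the write to z

# The trailing-underscore variant changes the tensor's shape in place.
w = torch.tensor([1, 2, 3])
w.unsqueeze_(1)
print(w.shape)
```

So the original tensor keeps its shape, but writes through the unsqueezed result are still visible in it unless a copy is made.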

Quick Reference

| Function | Description | Example |
|---|---|---|
| unsqueeze(dim) | Adds a dimension of size 1 at index dim | x.unsqueeze(0) adds an outer dimension |
| squeeze(dim) | Removes the dimension at index dim if its size is 1 | x.squeeze(1) removes dim 1 if its size is 1 |
| reshape(shape) | Changes the tensor to the given shape | x.reshape(1, 3) changes the shape to (1, 3) |
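
As a rough comparison of these functions, the sketch below produces the same (3, 1) tensor three ways and then undoes it with squeeze. Indexing with None is standard PyTorch syntax that is not covered elsewhere in this article and is included only for comparison.

```python
import torch

x = torch.tensor([10, 20, 30])

a = x.unsqueeze(1)      # insert an axis at position 1
b = x.reshape(3, 1)     # spell out the full target shape
c = x[:, None]          # indexing with None also inserts an axis

print(a.shape, b.shape, c.shape)
all_equal = torch.equal(a, b) and torch.equal(a, c)
print(all_equal)

# squeeze(1) removes the size-1 axis, recovering the original shape.
restored = a.squeeze(1)
print(restored.shape)

# squeeze on an axis whose size is not 1 leaves the shape unchanged.
untouched = a.squeeze(0)
print(untouched.shape)
```

unsqueeze is usually the clearest choice when only one axis is being added, because it does not require restating the sizes of the existing dimensions.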

Key Takeaways

  • Use tensor.unsqueeze(dim) to add a new dimension of size 1 at position dim.
  • The dim index can be negative to count from the end of the tensor's dimensions.
  • unsqueeze returns a view with the new shape and leaves the original tensor's shape unchanged; the two share the same data.
  • An out-of-range dim raises an IndexError; keep dim between -tensor.dim()-1 and tensor.dim() inclusive.
  • unsqueeze is useful for preparing tensors for operations that require specific shapes, such as adding a batch dimension or setting up broadcasting.