How to Unsqueeze Tensor in PyTorch: Syntax and Examples
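In PyTorch, you can add a dimension to a tensor using `tensor.unsqueeze(dim)`, where `dim` is the position at which to insert the new axis. It returns a new tensor (a view that shares data with the original) with one more dimension. This is useful for giving a tensor the shape an operation expects, such as adding a batch dimension or lining up axes for broadcasting.
Syntax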
The unsqueeze method adds a dimension of size 1 at the specified dim index in the tensor shape.
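- `tensor`: the original PyTorch tensor.
- `dim`: integer index at which the new dimension is inserted. It must lie in `[-tensor.dim()-1, tensor.dim()]`. A negative value counts from the end and is equivalent to `dim + tensor.dim() + 1`.
- Returns a tensor with one more dimension than `tensor`.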
```python
new_tensor = tensor.unsqueeze(dim)
```
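The following minimal sketch checks how negative indices map to positions. It uses a 2D tensor of zeros, and that shape is only an illustration. For a tensor with `x.dim() == 2`, the valid range is `-3..2`, and `dim=-1` resolves to `-1 + 2 + 1 = 2`. The sketch also shows the functional form `torch.unsqueeze(input, dim)`, which performs the same operation.

```python
import torch

# Illustrative 2D tensor with shape (2, 3); x.dim() == 2, so valid dims are -3..2
x = torch.zeros(2, 3)

# dim=-1 inserts the new axis last, the same position as dim=2
last = x.unsqueeze(-1)
print(last.shape)              # torch.Size([2, 3, 1])
print(x.unsqueeze(2).shape)    # torch.Size([2, 3, 1])

# dim=-3 is the same as dim=0: the new axis goes first
first = x.unsqueeze(-3)
print(first.shape)             # torch.Size([1, 2, 3])

# Functional form of the same operation
print(torch.unsqueeze(x, 1).shape)  # torch.Size([2, 1, 3])
```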
Example
This example shows how to add a new dimension to a 1D tensor to make it 2D by unsqueezing at dimension 0 and dimension 1.
```python
import torch

# Original 1D tensor
x = torch.tensor([10, 20, 30])
print('Original tensor:', x)
print('Shape:', x.shape)

# Unsqueeze at dim=0 (adds a new outer dimension)
x_unsq0 = x.unsqueeze(0)
print('\nAfter unsqueeze(0):')
print(x_unsq0)
print('Shape:', x_unsq0.shape)

# Unsqueeze at dim=1 (adds a new inner dimension)
x_unsq1 = x.unsqueeze(1)
print('\nAfter unsqueeze(1):')
print(x_unsq1)
print('Shape:', x_unsq1.shape)
```
Output
```
Original tensor: tensor([10, 20, 30])
Shape: torch.Size([3])

After unsqueeze(0):
tensor([[10, 20, 30]])
Shape: torch.Size([1, 3])

After unsqueeze(1):
tensor([[10],
        [20],
        [30]])
Shape: torch.Size([3, 1])
```
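Indexing with `None` is a common shorthand that inserts a size-1 axis in the same way. The sketch below is a quick check using the same values as the example above. It relies only on standard PyTorch indexing semantics.

```python
import torch

x = torch.tensor([10, 20, 30])

# None in an index inserts a size-1 axis at that position
a = x.unsqueeze(0)
b = x[None, :]
print(a.shape, b.shape)   # torch.Size([1, 3]) torch.Size([1, 3])
print(torch.equal(a, b))  # True

c = x.unsqueeze(1)
d = x[:, None]
print(c.shape, d.shape)   # torch.Size([3, 1]) torch.Size([3, 1])
print(torch.equal(c, d))  # True
```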
Common Pitfalls
Common mistakes when using unsqueeze include:
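- Using an invalid `dim` index outside the range `[-tensor.dim()-1, tensor.dim()]`.
- Confusing `unsqueeze` with `reshape` or `view`. Those take a complete target shape and can regroup elements, whereas `unsqueeze` only inserts a single size-1 axis.
- Forgetting that `unsqueeze` returns a new tensor and does not modify the original tensor in place (the in-place variant is `unsqueeze_()`).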
```python
import torch

x = torch.tensor([1, 2, 3])

# Wrong: dim too large (for a 1D tensor the valid range is [-2, 1])
try:
    x.unsqueeze(3)
except IndexError as e:
    print('Error:', e)

# Correct usage
x_unsq = x.unsqueeze(1)
print('Correct unsqueeze:', x_unsq.shape)
```
Output
```
Error: Dimension out of range (expected to be in range of [-2, 1], but got 3)
Correct unsqueeze: torch.Size([3, 1])
```
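The other two pitfalls are easy to hit silently because no error is raised. The sketch below shows three things. Discarding the result leaves the original unchanged. The returned tensor is a view, so writing to it also changes the original. `unsqueeze_()` changes the shape in place. The variable names are illustrative.

```python
import torch

x = torch.tensor([1, 2, 3])

# Pitfall: the result is discarded, so x keeps its original shape
x.unsqueeze(0)
print(x.shape)        # torch.Size([3])

# Keep the returned tensor instead
y = x.unsqueeze(0)
print(y.shape)        # torch.Size([1, 3])

# y is a view of x: writing through y also changes x
y[0, 0] = 100
print(x.tolist())     # [100, 2, 3]

# unsqueeze_ (trailing underscore) changes the tensor's own shape in place
z = torch.tensor([1, 2, 3])
z.unsqueeze_(1)
print(z.shape)        # torch.Size([3, 1])

# reshape can reach the same shape, but every dimension must be spelled out
m = torch.arange(6).reshape(2, 3)
print(torch.equal(m.unsqueeze(1), m.reshape(2, 1, 3)))  # True
```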
Quick Reference
| Function | Description | Example |
|---|---|---|
| unsqueeze(dim) | Adds a dimension of size 1 at index dim | x.unsqueeze(0) adds outer dimension |
| squeeze(dim) | Removes dimension dim if it has size 1; otherwise returns the tensor unchanged | x.squeeze(1) removes dim 1 only if its size is 1 |
| reshape(shape) | Returns a tensor with the given shape; the total number of elements must not change | x.reshape(1, 3) gives shape (1, 3) |
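The following sketch shows how the three functions in the table relate on a small tensor. `unsqueeze` and `squeeze` on the same `dim` undo each other, while `squeeze` on a dimension that is not size 1 is a no-op.

```python
import torch

x = torch.tensor([1, 2, 3])

# unsqueeze and squeeze on the same dim undo each other
u = x.unsqueeze(1)             # shape (3, 1)
print(u.squeeze(1).shape)      # torch.Size([3])

# squeeze(dim) leaves the tensor unchanged if that dim is not size 1
print(u.squeeze(0).shape)      # torch.Size([3, 1])

# reshape needs the full target shape; the element count must match
print(x.reshape(1, 3).shape)   # torch.Size([1, 3])
```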
Key Takeaways
Use tensor.unsqueeze(dim) to add a new dimension of size 1 at position dim.
The dim index can be negative to count from the end of tensor dimensions.
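unsqueeze returns a new tensor (a view sharing data with the original) and does not change the original tensor's shape; use unsqueeze_() for the in-place version.
Invalid dim values raise an IndexError; keep dim within [-tensor.dim()-1, tensor.dim()].
unsqueeze is useful for preparing tensors for operations that require specific shapes, such as adding a batch dimension or aligning axes for broadcasting.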
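Here are two typical shape-preparation cases, sketched with made-up values. The first builds a pairwise-difference table by broadcasting. The second adds a leading batch axis to a single sample. The `(3, 32, 32)` image shape is only an assumed example.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])   # shape (3,)
b = torch.tensor([10.0, 20.0])      # shape (2,)

# (3, 1) - (1, 2) broadcasts to (3, 2): entry [i, j] is a[i] - b[j]
diff = a.unsqueeze(1) - b.unsqueeze(0)
print(diff.shape)     # torch.Size([3, 2])
print(diff.tolist())  # [[-9.0, -19.0], [-8.0, -18.0], [-7.0, -17.0]]

# Add a leading batch axis to one sample of an assumed (C, H, W) shape
img = torch.rand(3, 32, 32)
batch = img.unsqueeze(0)
print(batch.shape)    # torch.Size([1, 3, 32, 32])
```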