How to Use Reshape in PyTorch: Syntax and Examples
In PyTorch, you use tensor.reshape(new_shape) to change the shape of a tensor without changing its data. The new_shape can be a tuple or list specifying the desired dimensions, and one dimension can be set to -1 to infer its size automatically.
Syntax
The basic syntax to reshape a tensor in PyTorch is:
- tensor.reshape(new_shape): Returns a tensor with the same data but a different shape. The result is a view of the original tensor when the memory layout allows it, and a copy otherwise.
- new_shape: A tuple or list of integers specifying the desired dimensions.
- Use -1 in new_shape to let PyTorch automatically calculate that dimension size.
```python
reshaped_tensor = tensor.reshape(new_shape)
```
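The short sketch below illustrates the syntax rules above. The tensor size (24 elements) and the variable names are chosen only for illustration. It shows that the target shape can be passed as a tuple, a list, or separate integers, and how -1 is resolved from the remaining dimensions.

```python
import torch

t = torch.arange(24)  # 24 elements: 0..23

# The target shape can be given as a tuple, a list, or separate integers.
a = t.reshape((4, 6))
b = t.reshape([4, 6])
c = t.reshape(4, 6)
print(a.shape, b.shape, c.shape)

# -1 asks PyTorch to infer that dimension: 24 / (2 * 4) = 3.
d = t.reshape(2, -1, 4)
print(d.shape)  # torch.Size([2, 3, 4])
```

The inferred size is the total element count divided by the product of the other dimensions, so that product must divide the element count evenly.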
Example
This example shows how to reshape a 1D tensor of 6 elements into a 2D tensor with 2 rows and 3 columns.
```python
import torch

tensor = torch.arange(6)  # Creates tensor([0, 1, 2, 3, 4, 5])
reshaped = tensor.reshape((2, 3))

print("Original tensor:", tensor)
print("Reshaped tensor (2x3):\n", reshaped)
```
Output
Original tensor: tensor([0, 1, 2, 3, 4, 5])
Reshaped tensor (2x3):
tensor([[0, 1, 2],
[3, 4, 5]])
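As a small follow-up to the example above, the sketch below checks whether the reshaped tensor shares memory with the original. It assumes a contiguous input such as the one created by torch.arange, in which case reshape is expected to return a view rather than a copy.

```python
import torch

tensor = torch.arange(6)
reshaped = tensor.reshape(2, 3)

# For a contiguous input, reshape returns a view on the same storage.
print(reshaped.data_ptr() == tensor.data_ptr())  # True

# Writing through the view is visible in the original tensor.
reshaped[0, 0] = 100
print(tensor.tolist())  # [100, 1, 2, 3, 4, 5]
```

If you need an independent copy regardless of memory layout, calling clone() on the result is one way to get it.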
Common Pitfalls
Common mistakes when using reshape include:
- Trying to reshape to a shape that does not match the total number of elements.
- Forgetting that -1 can only be used once in the new shape.
- Confusing reshape with view; reshape can handle non-contiguous tensors (copying the data if needed), while view raises an error when the tensor's memory layout is incompatible with the requested shape.
```python
import torch

# Wrong: total elements mismatch
try:
    tensor = torch.arange(6)
    tensor.reshape((4, 2))  # 4*2=8 elements, but tensor has 6
except Exception as e:
    print("Error:", e)

# Right: using -1 to infer dimension
tensor = torch.arange(6)
reshaped = tensor.reshape((2, -1))  # Automatically infers 3 columns
print("Reshaped with -1:", reshaped)
```
Output
Error: shape '[4, 2]' is invalid for input of size 6
Reshaped with -1: tensor([[0, 1, 2],
[3, 4, 5]])
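The second pitfall, using -1 more than once, can be sketched in the same way. The exact error text may differ between PyTorch versions, so the example prints whatever message is raised instead of assuming its wording.

```python
import torch

tensor = torch.arange(6)

# Wrong: two -1 entries leave the shape ambiguous (1x6, 2x3, 3x2, ...).
raised = False
try:
    tensor.reshape(-1, -1)
except RuntimeError as e:
    raised = True
    print("Error:", e)

# Right: a single -1 is enough, for example to flatten the tensor.
flat = tensor.reshape(-1)
print(flat.shape)  # torch.Size([6])
```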
Quick Reference
| Usage | Description |
|---|---|
| tensor.reshape(new_shape) | Change tensor shape without changing data |
| -1 in new_shape | Automatically infer dimension size |
| new_shape must match total elements | Total elements before and after reshape must be equal |
| Use reshape over view for non-contiguous tensors | reshape works more generally than view |
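To make the last row of the table concrete, here is a minimal sketch of reshape and view on a non-contiguous tensor. The transposed tensor is just one convenient way to get a non-contiguous layout, and the variable names are illustrative.

```python
import torch

base = torch.arange(6).reshape(2, 3)
transposed = base.t()  # shape (3, 2), shares storage with base, not contiguous
print(transposed.is_contiguous())  # False

# view cannot express this flattening without moving data, so it raises.
view_failed = False
try:
    transposed.view(6)
except RuntimeError as e:
    view_failed = True
    print("view failed:", e)

# reshape falls back to copying the data into a new contiguous tensor.
flat = transposed.reshape(6)
print(flat.tolist())  # [0, 3, 1, 4, 2, 5]
print(flat.data_ptr() == base.data_ptr())  # False: flat is a copy
```

Because the result here is a copy, later writes to flat do not affect base, unlike the contiguous case where reshape returns a view.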
Key Takeaways
Use tensor.reshape(new_shape) to change tensor shape without altering data.
One dimension in new_shape can be -1 to let PyTorch infer its size automatically.
The total number of elements must remain the same before and after reshape.
reshape works on non-contiguous tensors, copying the data when necessary, unlike view, which fails when the memory layout is incompatible with the new shape.
Common errors come from mismatched shapes or multiple -1 dimensions.