How to Reshape a Tensor in PyTorch: Syntax and Examples
In PyTorch, you can reshape a tensor with tensor.reshape(new_shape) or tensor.view(new_shape). Both change the tensor's shape without altering its data. reshape() is more flexible and safer for non-contiguous tensors, where view() can fail.
Syntax
The main ways to reshape a tensor in PyTorch are:
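- tensor.reshape(new_shape): Returns a tensor with the specified shape. It also handles non-contiguous tensors, returning a view of the original data when possible and a copy otherwise.
- tensor.view(new_shape): Returns a view with the specified shape that always shares memory with the original tensor. The new shape must be compatible with the tensor's memory layout, which in practice usually means the tensor must be contiguous. Otherwise it raises a RuntimeError.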
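Here, new_shape is the desired shape. You can pass it either as a single tuple (or list) of integers or as separate integer arguments, for example tensor.reshape(3, 2). You can use -1 for at most one dimension, and PyTorch infers that dimension's size from the total number of elements.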
```python
reshaped_tensor = tensor.reshape(new_shape)  # new_shape as a tuple, e.g. (3, 2)
reshaped_tensor = tensor.view(*new_shape)    # same shape, unpacked into separate ints
```
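The following sketch illustrates the view-versus-copy behavior described above. The variable names (base, v, r, t, flat) are illustrative, and data_ptr() is used only to check whether two tensors start at the same memory address. The PyTorch documentation warns that code should not rely on whether reshape() copies or returns a view. This only shows what happens in these specific cases.

```python
import torch

base = torch.arange(6)  # contiguous 1D tensor: [0, 1, 2, 3, 4, 5]

# view() always shares storage with the original tensor
v = base.view(2, 3)
v[0, 0] = 100
print(base[0].item())  # 100: the write through the view is visible in base

# reshape() on a contiguous tensor can return a view as well
r = base.reshape(3, 2)
print(r.data_ptr() == base.data_ptr())  # True: same underlying memory

# reshape() on a non-contiguous (transposed) tensor has to copy the data
t = base.view(2, 3).t()
flat = t.reshape(-1)
print(t.is_contiguous())                   # False
print(flat.data_ptr() == base.data_ptr())  # False: flat lives in new memory
```

Because flat is a copy, later writes to flat do not affect base. Writes through v or r would affect it.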
Example
This example uses reshape() to turn a 2D tensor of shape (2, 3) into a 1D tensor. It then uses view() to turn that 1D tensor into a 2D tensor of shape (3, 2).
```python
import torch

# Create a 2D tensor of shape (2, 3)
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print('Original tensor shape:', tensor.shape)

# Reshape to a 1D tensor with 6 elements (-1 infers the size)
reshaped = tensor.reshape(-1)
print('Reshaped tensor:', reshaped)
print('Reshaped tensor shape:', reshaped.shape)

# Reshape to (3, 2) using view; reshaped is contiguous, so view works
reshaped_back = reshaped.view(3, 2)
print('Reshaped back tensor:', reshaped_back)
print('Reshaped back tensor shape:', reshaped_back.shape)
```
Output
```
Original tensor shape: torch.Size([2, 3])
Reshaped tensor: tensor([1, 2, 3, 4, 5, 6])
Reshaped tensor shape: torch.Size([6])
Reshaped back tensor: tensor([[1, 2],
        [3, 4],
        [5, 6]])
Reshaped back tensor shape: torch.Size([3, 2])
```
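The example above uses -1 only to flatten a tensor. The -1 placeholder can also fill any single position in a multi-dimensional shape. The following is a short sketch, with the 24-element tensor x chosen arbitrarily for illustration.

```python
import torch

x = torch.arange(24)  # 24 elements

# -1 can stand in for any one dimension; its size is 24 / (product of the others)
a = x.reshape(2, -1)      # -1 becomes 12
b = x.reshape(-1, 3, 4)   # -1 becomes 2
c = x.view(4, -1, 2)      # -1 becomes 3; view() accepts -1 too

# The shape can also be passed as a single tuple
d = x.reshape((2, 3, -1))  # -1 becomes 4

print(a.shape)  # torch.Size([2, 12])
print(b.shape)  # torch.Size([2, 3, 4])
print(c.shape)  # torch.Size([4, 3, 2])
print(d.shape)  # torch.Size([2, 3, 4])
```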
Common Pitfalls
Common mistakes when reshaping tensors include:
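- Calling view() on a non-contiguous tensor (for example, a transposed one) raises a RuntimeError when the requested shape is incompatible with the tensor's memory layout. Use reshape() instead.
- Specifying a new shape whose total number of elements differs from the original raises a RuntimeError.
- Using -1 for more than one dimension raises a RuntimeError, because PyTorch can infer only one dimension size.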
```python
import torch

# Create a non-contiguous tensor by transposing a (2, 3) tensor
tensor = torch.arange(6).reshape(2, 3).t()
print('Tensor shape:', tensor.shape)
print('Is contiguous:', tensor.is_contiguous())

# Flattening with view() raises an error: the transposed tensor is not
# contiguous, so its elements cannot be read as one flat sequence without a copy
try:
    tensor.view(-1)
except RuntimeError as e:
    print('Error with view:', e)

# reshape() works because it copies the data when a view is impossible
reshaped = tensor.reshape(-1)
print('Reshaped tensor with reshape:', reshaped)
```
Output
```
Tensor shape: torch.Size([3, 2])
Is contiguous: False
Error with view: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Reshaped tensor with reshape: tensor([0, 3, 1, 4, 2, 5])
```
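The other two pitfalls in the list, a mismatched element count and more than one -1, can be reproduced the same way. In this sketch, the errors list exists only to collect the messages. The exact message text may differ between PyTorch versions, so no output is shown.

```python
import torch

x = torch.arange(6)  # 6 elements
errors = []

# Pitfall: the new shape must hold exactly 6 elements, but (4, 2) holds 8
try:
    x.reshape(4, 2)
except RuntimeError as e:
    errors.append(str(e))

# Pitfall: only one dimension may be -1, because two unknowns cannot be inferred
try:
    x.reshape(-1, -1)
except RuntimeError as e:
    errors.append(str(e))

for msg in errors:
    print('Error:', msg)
```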
Quick Reference
Summary tips for reshaping tensors in PyTorch:
- Use reshape() for flexible reshaping, especially with non-contiguous tensors.
- Use view() when you need a result that is guaranteed to share memory with the original, and only on tensors you know are contiguous.
- Use -1 for one dimension of the new shape to let PyTorch infer its size automatically.
- Always ensure the total number of elements stays the same before and after reshaping.
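If you are not sure a tensor is contiguous but still want to use view(), one option is to call contiguous() first. This makes a contiguous copy when needed, so view() can then succeed. The sketch below compares that route with reshape() on a transposed tensor. The variable names are illustrative.

```python
import torch

# A transposed tensor is non-contiguous, so t.view(-1) would raise an error
t = torch.arange(6).reshape(2, 3).t()

# Option 1: make the tensor contiguous first, then view() can flatten it
flat_view = t.contiguous().view(-1)

# Option 2: let reshape() decide whether a copy is needed
flat_reshape = t.reshape(-1)

print(flat_view)                             # tensor([0, 3, 1, 4, 2, 5])
print(torch.equal(flat_view, flat_reshape))  # True
```

Both options give the same values. reshape() is simply the shorter way to express "view if possible, copy if necessary".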
Key Takeaways
Use tensor.reshape(new_shape) to safely reshape tensors, even if they are non-contiguous.
tensor.view(new_shape) generally requires the tensor to be contiguous and raises a RuntimeError when the new shape is incompatible with its memory layout.
Only one dimension in new_shape can be -1 to automatically infer its size.
The total number of elements must remain the same before and after reshaping.
reshape() is generally preferred over view() for its flexibility.