How to Use view() in PyTorch for Tensor Reshaping
In PyTorch, view() reshapes a tensor without copying or changing its data: the returned tensor shares memory with the original. You provide the new shape as arguments, and one dimension can be set to -1 to infer its size automatically.
Syntax
The view() method returns a tensor with a new shape. You call it on a tensor and pass the new shape as separate integers or as a single tuple. Use -1 for at most one dimension to let PyTorch calculate it from the total number of elements.
- tensor.view(new_shape): Returns a tensor with the specified shape.
- new_shape: A tuple or list of integers representing the desired shape.
- -1: Automatically infers the size of this dimension.
```python
tensor.view(shape)

# Example: reshape tensor to 2 rows and 3 columns
tensor.view(2, 3)
```
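To make "without changing its data" concrete, here is a minimal sketch showing that the tensor returned by view() shares memory with the original. The variable names `base` and `grid` are illustrative and do not come from the article.

```python
import torch

# Six elements: 0..5
base = torch.arange(6)

# view() returns a new tensor object that shares storage with `base`
grid = base.view(2, 3)

# Writing through the view is visible in the original tensor
grid[0, 0] = 100

print(base)                                # the first element is now 100
print(grid.data_ptr() == base.data_ptr())  # both point at the same memory
```

Because no data is copied, view() is cheap, but an in-place change to either tensor affects the other. Call clone() on the result if an independent copy is needed.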
Example
This example shows how to reshape a 1D tensor of 6 elements into a 2D tensor with 2 rows and 3 columns using view(). It also demonstrates using -1 to automatically infer one dimension.
```python
import torch

# Create a 1D tensor with 6 elements
x = torch.arange(6)
print('Original tensor:', x)

# Reshape to 2 rows and 3 columns
x_reshaped = x.view(2, 3)
print('Reshaped tensor (2,3):\n', x_reshaped)

# Use -1 to infer the second dimension automatically
x_reshaped_auto = x.view(2, -1)
print('Reshaped tensor with -1 (2, -1):\n', x_reshaped_auto)
```
Output
```
Original tensor: tensor([0, 1, 2, 3, 4, 5])
Reshaped tensor (2,3):
 tensor([[0, 1, 2],
        [3, 4, 5]])
Reshaped tensor with -1 (2, -1):
 tensor([[0, 1, 2],
        [3, 4, 5]])
```
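The arithmetic behind -1 is simple: the inferred size is the total number of elements divided by the product of the other dimensions. The sketch below reproduces that arithmetic in plain Python. The helper `infer_view_shape` is hypothetical, written only for illustration, and is not how PyTorch implements the check internally.

```python
from math import prod


def infer_view_shape(numel, shape):
    """Hypothetical helper: resolve a single -1 in `shape` for `numel` elements."""
    shape = tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")

    # Product of the explicitly given dimensions (1 if there are none)
    known = prod(d for d in shape if d != -1)

    if -1 in shape:
        # The missing dimension must divide the element count exactly
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {list(shape)} is invalid for input of size {numel}")
        return tuple(numel // known if d == -1 else d for d in shape)

    if known != numel:
        raise ValueError(f"shape {list(shape)} is invalid for input of size {numel}")
    return shape


print(infer_view_shape(6, (2, -1)))      # (2, 3)
print(infer_view_shape(24, (2, -1, 4)))  # (2, 3, 4)
```

For the 6-element tensor above, view(2, -1) resolves to (2, 3) because 6 / 2 = 3. A shape such as (4, -1) would fail, since 6 is not divisible by 4.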
Common Pitfalls
Common mistakes when using view() include:
- Trying to reshape a tensor to a shape that does not match the total number of elements.
- Using more than one -1 in the shape, which is not allowed.
- Calling view() on a tensor that is not contiguous in memory (for example, after a transpose), which raises a RuntimeError.
Always check that the total number of elements matches, and call tensor.contiguous() first if the tensor is not contiguous.
```python
import torch

x = torch.arange(6)

# Wrong: shape does not match total elements
try:
    x.view(4, 2)
except RuntimeError as e:
    print('Error:', e)

# Wrong: more than one -1
try:
    x.view(-1, -1)
except RuntimeError as e:
    print('Error:', e)

# Correct: make tensor contiguous before view if needed
# Note: t() on a 1D tensor is a no-op, so transpose a 2D tensor instead
x_t = x.view(2, 3).t()  # transposing a 2D tensor makes it non-contiguous
try:
    x_t.view(6)
except RuntimeError as e:
    print('Error:', e)

x_t_contig = x_t.contiguous()
print('Contiguous tensor view:', x_t_contig.view(6))
```
Output
```
Error: shape '[4, 2]' is invalid for input of size 6
Error: only one dimension can be inferred
Error: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces)
Contiguous tensor view: tensor([0, 3, 1, 4, 2, 5])
```
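To see whether contiguous() is needed before calling view(), you can inspect a tensor with is_contiguous() and stride(). The sketch below also shows reshape(), which returns a view when the strides allow it and copies the data otherwise. The variable names are illustrative, and reshape() is offered here as one alternative rather than a recommendation from the article.

```python
import torch

m = torch.arange(6).view(2, 3)
m_t = m.t()  # transposed view: same storage, different strides

print(m.is_contiguous(), m.stride())      # True (3, 1)
print(m_t.is_contiguous(), m_t.stride())  # False (1, 3)

# reshape() handles the non-contiguous case by copying when it has to
flat = m_t.reshape(6)
print(flat)

# Equivalent explicit form: make a contiguous copy, then view it
flat_explicit = m_t.contiguous().view(6)
print(torch.equal(flat, flat_explicit))
```

The explicit contiguous().view() form makes the copy visible in the code, while reshape() is shorter but hides whether a copy happened.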
Quick Reference
| Usage | Description |
|---|---|
| tensor.view(shape) | Reshape tensor to given shape |
| -1 in shape | Automatically infer dimension size |
| Only one -1 allowed | Multiple -1s cause error |
| tensor.contiguous() | Make tensor memory contiguous before view |
| Total elements must match | Shape must fit total number of elements |
Key Takeaways
- Use tensor.view(new_shape) to reshape tensors without copying or changing their data.
- Only one dimension can be set to -1 to infer its size automatically.
- Ensure the total number of elements matches the new shape.
- Call tensor.contiguous() if the tensor is not contiguous before using view().
- Common errors come from incompatible shapes, multiple -1s, or non-contiguous tensors.