PyTorch · How-To · Beginner · 4 min read

How to Reshape Tensor in PyTorch: Syntax and Examples

In PyTorch, you can reshape a tensor using tensor.reshape(new_shape) or tensor.view(new_shape). Both change the tensor's shape without altering its data, but reshape is more flexible and safer for non-contiguous tensors.

Syntax

The main ways to reshape a tensor in PyTorch are:

  • tensor.reshape(new_shape): Returns a tensor with the specified shape. It works on non-contiguous tensors, returning a view of the data when possible and a copy otherwise.
  • tensor.view(new_shape): Returns a view with the specified shape, but requires the tensor's memory layout to be compatible (in practice, contiguous); it never copies data.

Here, new_shape is a tuple or list of integers defining the desired dimensions. You can use -1 for one dimension to infer its size automatically.

python
reshaped_tensor = tensor.reshape(new_shape)
reshaped_tensor = tensor.view(*new_shape)
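For instance, a minimal sketch of the -1 inference rule, using torch.arange to build a 12-element tensor:

```python
import torch

# A 1D tensor with 12 elements
tensor = torch.arange(12)

# Infer the second dimension: 12 / 3 = 4, so the shape becomes (3, 4)
a = tensor.reshape(3, -1)
print(a.shape)  # torch.Size([3, 4])

# Infer the first dimension: 12 / 2 = 6, so the shape becomes (6, 2)
b = tensor.reshape(-1, 2)
print(b.shape)  # torch.Size([6, 2])
```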

Example

This example shows how to reshape a 2D tensor into a 1D tensor and back using reshape and view.

python
import torch

# Create a 2D tensor of shape (2, 3)
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print('Original tensor shape:', tensor.shape)

# Reshape to 1D tensor with 6 elements
reshaped = tensor.reshape(-1)
print('Reshaped tensor:', reshaped)
print('Reshaped tensor shape:', reshaped.shape)

# Reshape back to (3, 2) using view
reshaped_back = reshaped.view(3, 2)
print('Reshaped back tensor:', reshaped_back)
print('Reshaped back tensor shape:', reshaped_back.shape)
Output
Original tensor shape: torch.Size([2, 3])
Reshaped tensor: tensor([1, 2, 3, 4, 5, 6])
Reshaped tensor shape: torch.Size([6])
Reshaped back tensor: tensor([[1, 2],
        [3, 4],
        [5, 6]])
Reshaped back tensor shape: torch.Size([3, 2])

Common Pitfalls

Common mistakes when reshaping tensors include:

  • Using view() on a non-contiguous tensor causes an error. Use reshape() instead.
  • Specifying a new shape that does not match the total number of elements causes a runtime error.
  • Forgetting that only one dimension can be -1 to infer size automatically.
python
import torch

# Create a non-contiguous tensor by transposing
tensor = torch.arange(6).reshape(2, 3).t()
print('Tensor shape:', tensor.shape)

# This will raise an error because tensor is not contiguous
try:
    tensor.view(3, 2)
except RuntimeError as e:
    print('Error with view:', e)

# Correct way using reshape
reshaped = tensor.reshape(3, 2)
print('Reshaped tensor with reshape:', reshaped)
Output
Tensor shape: torch.Size([3, 2])
Error with view: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape() instead.
Reshaped tensor with reshape: tensor([[0, 3],
        [1, 4],
        [2, 5]])
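The other two pitfalls can be triggered the same way. A short sketch, assuming a 6-element tensor:

```python
import torch

tensor = torch.arange(6)

# Pitfall: the new shape must match the total element count (6 != 4 * 2)
try:
    tensor.reshape(4, 2)
except RuntimeError as e:
    print('Shape mismatch:', e)

# Pitfall: at most one dimension may be -1
try:
    tensor.reshape(-1, -1)
except RuntimeError as e:
    print('Multiple -1 dimensions:', e)
```

Both calls raise a RuntimeError rather than silently producing a wrong shape.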

Quick Reference

Summary tips for reshaping tensors in PyTorch:

  • Use reshape() for flexible reshaping, especially with non-contiguous tensors.
  • Use view() only if you are sure the tensor is contiguous.
  • Use -1 in shape to let PyTorch infer the dimension size automatically.
  • Always ensure the total number of elements remains the same before and after reshaping.
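If you do need view() on a non-contiguous tensor, one common workaround is to call .contiguous() first, which copies the data into a contiguous layout. A minimal sketch:

```python
import torch

# Transposing makes the tensor non-contiguous
tensor = torch.arange(6).reshape(2, 3).t()
print(tensor.is_contiguous())  # False

# .contiguous() returns a contiguous copy, so view() works again
flat = tensor.contiguous().view(-1)
print(flat)  # tensor([0, 3, 1, 4, 2, 5])
```

Note that this copies the data, which is effectively what reshape() would do for you in this case anyway.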

Key Takeaways

  • Use tensor.reshape(new_shape) to safely reshape tensors, even if they are non-contiguous.
  • tensor.view(new_shape) requires the tensor to be contiguous and can raise an error otherwise.
  • Only one dimension in new_shape can be -1 to automatically infer its size.
  • The total number of elements must remain the same before and after reshaping.
  • reshape() is generally preferred over view() for its flexibility.