How to Use plt.subplots in Matplotlib: Simple Guide
Use `plt.subplots()` to create a figure and one or more axes (plots) in Matplotlib. It returns a tuple containing the figure and the axes, which you can use to customize and display your plots.
Syntax
The basic syntax of plt.subplots() is:
`fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(width, height))`
- `fig` is the figure object that holds everything.
- `ax` is the axes object, or array of axes, where you draw your plots.
- `nrows` and `ncols` specify the grid layout of subplots.
- `figsize` sets the size of the figure in inches, as `(width, height)`.
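```python
import matplotlib.pyplot as plt

# A 2x3 grid in a 10x6-inch figure; ax is a 2-D array of Axes with shape (2, 3)
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10, 6))
```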
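The type of `ax` depends on the grid size, and that difference causes most `plt.subplots()` confusion. The sketch below checks it for three layouts. The variable names are illustrative. The non-interactive `Agg` backend is an assumption, chosen only so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# 1x1 (the default): ax is a single Axes object, not an array
fig1, ax_single = plt.subplots()

# 1x3: a 1-D NumPy array holding three Axes
fig2, ax_row = plt.subplots(nrows=1, ncols=3)

# 2x3: a 2-D NumPy array, indexed as ax_grid[row, col]
fig3, ax_grid = plt.subplots(nrows=2, ncols=3, figsize=(10, 6))

print(isinstance(ax_single, Axes))  # True
print(ax_row.shape)   # (3,)
print(ax_grid.shape)  # (2, 3)
```

A single row or column therefore gives a 1-D array, not a 2-D one. This is the default `squeeze=True` behaviour.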
Example
This example creates a 2x2 grid of subplots and plots simple line charts on each.
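```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, axs = plt.subplots(2, 2, figsize=(8, 6))

# axs.flat walks the 2x2 array row by row, so i runs 0..3
for i, ax in enumerate(axs.flat):
    ax.plot(x, np.sin(x + i))  # each subplot gets a sine wave shifted by i
    ax.set_title(f'Subplot {i+1}')

plt.tight_layout()  # adjust spacing so titles and labels don't overlap
plt.show()
```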
Output
A window with 4 line plots arranged in 2 rows and 2 columns, each showing a sine wave shifted by subplot index.
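You can also address each subplot directly instead of looping over `axs.flat`. `axs` is a 2-D NumPy array here, so `axs[row, col]` picks one subplot and slices like `axs[1, :]` select a whole row. The sketch below plots different data in each cell and labels only the outer edges. The plotted functions are illustrative, and the `Agg` backend is assumed so it runs headless. Call `plt.show()` instead when working interactively.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use plt.show() interactively instead
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, axs = plt.subplots(2, 2, figsize=(8, 6))

# axs is 2-D, so pick a subplot with [row, column]
axs[0, 0].plot(x, np.sin(x))
axs[0, 1].plot(x, np.cos(x))
axs[1, 0].plot(x, x ** 2)
axs[1, 1].plot(x, np.exp(-x))

# Label only the bottom row's x-axes and the left column's y-axes
for ax in axs[1, :]:
    ax.set_xlabel("x")
for ax in axs[:, 0]:
    ax.set_ylabel("y")

fig.tight_layout()
```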
Common Pitfalls
Common mistakes when using plt.subplots() include:
- Not unpacking the returned tuple correctly. For example, `ax = plt.subplots()` binds the whole `(fig, ax)` tuple to `ax`, so a call like `ax.plot()` then fails.
- For a single subplot, `ax` is not an array but a single object, so trying to iterate over it causes errors.
- Not using `plt.tight_layout()` to avoid overlapping labels and titles.
```python
import matplotlib.pyplot as plt

# Wrong: with a single subplot, ax is one Axes object, not an array,
# so iterating over it raises TypeError
fig, ax = plt.subplots()
# for a in ax:
#     a.plot([1, 2, 3], [4, 5, 6])

# Right: call plotting methods on ax directly
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 5, 6])
plt.show()
```
Output
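The wrong loop (`for a in ax:`) raises a `TypeError` because the axes object is not iterable. The message names `Axes` or `AxesSubplot`, depending on the Matplotlib version. The right code opens a single plot window.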
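Two habits avoid both pitfalls. The first is to unpack the axes straight into named variables. The second is to pass `squeeze=False` when code must handle any grid size, because the axes then always come back as a 2-D array, even for one subplot. The sketch below assumes the headless `Agg` backend. The names `left`, `right`, `a`…`d` are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Unpack a 1x2 grid straight into named axes
fig, (left, right) = plt.subplots(1, 2)
left.plot([1, 2, 3], [4, 5, 6])
right.plot([1, 2, 3], [6, 5, 4])

# Unpack a 2x2 grid with nested tuples, one tuple per row
fig2, ((a, b), (c, d)) = plt.subplots(2, 2)

# squeeze=False always returns a 2-D array, even for a single subplot
fig3, axs = plt.subplots(squeeze=False)
for ax in axs.flat:  # safe to iterate: axs has shape (1, 1)
    ax.plot([1, 2, 3], [4, 5, 6])
```

The unpacking pattern must match the grid shape exactly. Unpacking a 2x2 grid into two names, for instance, raises a `ValueError`.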
Quick Reference
| Parameter | Description | Default |
|---|---|---|
| nrows | Number of rows of subplots | 1 |
| ncols | Number of columns of subplots | 1 |
| figsize | Size of the figure in inches (width, height) | (6.4, 4.8) |
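| sharex | Share the x-axis among subplots: True/'all', False/'none', 'row', or 'col' | False |
| sharey | Share the y-axis among subplots: True/'all', False/'none', 'row', or 'col' | False |
| squeeze | If True, drop extra dimensions of the returned axes (a single Axes for 1x1, a 1-D array for one row or column) | True |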
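Shared axes stay linked after creation. Changing the limits of one subplot updates every subplot it shares with. The sketch below shares the x-axis across the whole grid and the y-axis only within each row. The data and limits are arbitrary, and the `Agg` backend is assumed so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# sharex=True links every x-axis; sharey="row" links y only within each row
fig, axs = plt.subplots(2, 2, sharex=True, sharey="row")
axs[0, 0].plot(x, np.sin(x))
axs[0, 1].plot(x, 2 * np.sin(x))
axs[1, 0].plot(x, x ** 2)
axs[1, 1].plot(x, x)

# Setting limits on one subplot propagates to the axes it shares with
axs[0, 0].set_xlim(0, 5)   # every subplot now shows x from 0 to 5
axs[0, 0].set_ylim(-3, 3)  # only the other top-row subplot follows

print(axs[1, 1].get_xlim())
print(axs[0, 1].get_ylim(), axs[1, 0].get_ylim())
```

`sharey="row"` suits rows that plot quantities on different scales, like the sine and polynomial curves here. `sharey=True` would force all four subplots onto one y-range.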
Key Takeaways
Use plt.subplots() to create the figure and its axes objects together in one call.
Unpack the returned tuple correctly: fig, ax = plt.subplots().
For multiple subplots, ax is a NumPy array of axes; for one subplot, ax is a single object (unless you pass squeeze=False).
Use plt.tight_layout() to prevent overlapping plot elements.
Set nrows and ncols to arrange multiple plots in a grid.