How to Plot a Confusion Matrix in Python with sklearn
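To plot a confusion matrix in Python, use sklearn.metrics.ConfusionMatrixDisplay.from_estimator with your trained model and test data. It draws a matrix of true versus predicted labels, which helps you evaluate classification performance. The older sklearn.metrics.plot_confusion_matrix function was deprecated in scikit-learn 1.0 and removed in 1.2, so code that imports it fails on current versions.
Syntax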
The basic syntax to plot a confusion matrix using sklearn is:
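ConfusionMatrixDisplay.from_estimator(estimator, X, y, *, labels=None, normalize=None, display_labels=None, cmap='viridis')

- estimator: Your fitted classifier.
- X: Test features.
- y: True labels for the test data.
- labels: Optional list of class labels that sets the order of the rows and columns. Use the label values that appear in y (for iris, [0, 1, 2]).
- normalize: Optional normalization over the true labels ('true', each row sums to 1), the predicted labels ('pred', each column sums to 1), or all samples ('all'). The default None shows raw counts.
- display_labels: Optional names shown on the axes, such as iris.target_names.
- cmap: Matplotlib colormap for the plot (default 'viridis').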
```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

# estimator is a fitted classifier; X_test and y_test are held-out data
ConfusionMatrixDisplay.from_estimator(estimator, X_test, y_test, cmap='Blues', normalize=None)
plt.show()
```
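ConfusionMatrixDisplay.from_estimator accepts the same normalize values as sklearn.metrics.confusion_matrix, so the arithmetic is easy to check on the array directly. A small sketch on made-up binary labels (the data is hypothetical, chosen only so the numbers are easy to verify by hand):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical binary labels; rows = true class, columns = predicted class
y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]

counts = confusion_matrix(y_true, y_pred)                     # raw counts
by_true = confusion_matrix(y_true, y_pred, normalize="true")  # each row sums to 1
by_pred = confusion_matrix(y_true, y_pred, normalize="pred")  # each column sums to 1
by_all = confusion_matrix(y_true, y_pred, normalize="all")    # all cells sum to 1

print(counts)
print(by_true)

# normalize="true" is the same as dividing each row by its row total
print(np.allclose(by_true, counts / counts.sum(axis=1, keepdims=True)))
```

Row normalization ('true') reads as per-class recall along the diagonal, while column normalization ('pred') reads as per-class precision.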
Example
This example trains a simple logistic regression model on the iris dataset, predicts test labels, and plots the confusion matrix to visualize classification results.
```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split

# Load data
iris = load_iris()
X, y = iris.data, iris.target

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train model
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)

# Plot confusion matrix, with class names on the axes instead of 0, 1, 2
ConfusionMatrixDisplay.from_estimator(
    model, X_test, y_test,
    display_labels=iris.target_names,
    cmap='Blues',
    normalize=None,
)
plt.title('Confusion Matrix')
plt.show()
```
Output
A window showing a confusion matrix plot with blue color shading representing counts of true vs predicted labels for iris classes.
Common Pitfalls
- Not fitting the model before plotting causes errors.
- Passing training data instead of test data gives an overly optimistic picture, because the model has already seen those samples.
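- For multiclass problems, the rows and columns follow the sorted class labels by default, and the axes show raw label values such as 0, 1, 2. Pass labels to set the order and display_labels to show readable class names.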
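- plot_confusion_matrix is deprecated and was removed in scikit-learn 1.2; importing it on current versions raises an ImportError.
- confusion_matrix alone only computes the array, so plotting it requires extra steps.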
Use ConfusionMatrixDisplay.from_estimator with a fitted model and test data to compute and plot the matrix in one step.
```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# Assumes model, X_test, and y_test from the example above

# Wrong (for plotting): confusion_matrix only returns an array, no plot
cm = confusion_matrix(y_test, model.predict(X_test))
print(cm)

# Right: ConfusionMatrixDisplay computes and plots the matrix in one call
ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)
plt.show()
```
Output
The wrong way prints the confusion matrix as an array in the console; the right way displays it as a plot.
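If you already have an array from confusion_matrix, the extra steps are small: wrap it in a ConfusionMatrixDisplay and call its plot method. A minimal sketch with hypothetical cat/dog labels (the labels and the class_names list are illustrative, not from the original example):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit it to get a window
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# Hypothetical string labels
y_true = ["cat", "dog", "dog", "cat", "dog"]
y_pred = ["cat", "dog", "cat", "cat", "dog"]

# Step 1: compute the raw count array (rows = true, columns = predicted)
class_names = ["cat", "dog"]
cm = confusion_matrix(y_true, y_pred, labels=class_names)
print(cm)

# Step 2: wrap the array in a display object and draw it
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap="Blues")
disp.ax_.set_title("Confusion Matrix")
plt.close(disp.figure_)
```

To see the plot interactively, drop the matplotlib.use("Agg") line and call plt.show() instead of plt.close().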
Quick Reference
| Parameter | Description | Example |
|---|---|---|
| estimator | Fitted classifier | A fitted LogisticRegression() |
| X | Test features | X_test array |
| y | True labels for the test data | y_test array |
| labels | Class labels that set the row and column order (optional) | [0, 1, 2] |
| display_labels | Names shown on the axes (optional) | ['setosa', 'versicolor', 'virginica'] |
| cmap | Colormap for the plot | 'Blues' |
| normalize | Normalize over 'true', 'pred', or 'all'; None for raw counts | 'true' |
Key Takeaways
Use sklearn's ConfusionMatrixDisplay.from_estimator with a trained model and test data to visualize classification results easily.
Always fit your model before plotting the confusion matrix to avoid errors.
Use the normalize parameter to see proportions instead of raw counts if needed.
Use confusion_matrix when you only need the array; ConfusionMatrixDisplay handles the visualization directly.
In multiclass problems, set labels to fix the class order and display_labels to show readable class names.