MLOps · Concept · Beginner · 4 min read

What is Overfitting in Machine Learning with Python | Sklearn Guide

In machine learning, overfitting happens when a model learns the training data too well, including noise and details that don't apply to new data. This causes the model to perform very well on training data but poorly on unseen data, reducing its ability to generalize.
⚙️ How It Works

Imagine you are trying to learn how to recognize dogs in photos. If you memorize every single photo you saw, including the background and lighting, you might fail to recognize a new dog in a different setting. This is similar to overfitting in machine learning.

When a model overfits, it captures not only the main patterns but also the random noise or small details in the training data. This makes it too specialized and less flexible to handle new data. The model becomes like a student who memorizes answers instead of understanding the concepts.

In Python, with libraries like scikit-learn (sklearn), overfitting can occur when a model is too complex or is trained for too long without safeguards such as validation checks. The goal is to learn enough from the data without memorizing it.
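As a concrete sketch of this balance (the dataset and depth values below are invented for illustration, not from the article): an unconstrained decision tree can memorize a noisy training set perfectly, while a depth-limited tree is forced to keep only the broad patterns. Comparing training and test accuracy makes the memorization visible.

```python
# Hypothetical illustration: training vs. test accuracy of an
# unconstrained tree and a depth-limited tree on noisy synthetic data.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic dataset with ~10% label noise (flip_y), so memorizing hurts
X, y = make_classification(n_samples=300, n_features=10, flip_y=0.1,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0)

# No depth limit: the tree grows until it classifies every training sample
deep = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# Depth capped at 3: the tree can only capture the broad structure
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

print("deep    train:", deep.score(X_train, y_train), "test:", deep.score(X_test, y_test))
print("shallow train:", shallow.score(X_train, y_train), "test:", shallow.score(X_test, y_test))
```

A large gap between training and test accuracy for the deep tree is the classic symptom of overfitting.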

💻 Example

This example shows overfitting using a polynomial regression model on simple data. A high-degree polynomial fits the training points perfectly but fails to predict new points well.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# Create simple data: a noisy sine curve
np.random.seed(0)
x = np.linspace(0, 1, 10).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel() + np.random.normal(0, 0.1, x.shape[0])

# Fit low-degree polynomial (good fit)
poly_low = PolynomialFeatures(degree=3)
x_poly_low = poly_low.fit_transform(x)
model_low = LinearRegression().fit(x_poly_low, y)
y_pred_low = model_low.predict(x_poly_low)

# Fit high-degree polynomial (overfitting)
poly_high = PolynomialFeatures(degree=15)
x_poly_high = poly_high.fit_transform(x)
model_high = LinearRegression().fit(x_poly_high, y)
y_pred_high = model_high.predict(x_poly_high)

# Calculate training errors
mse_low = mean_squared_error(y, y_pred_low)
mse_high = mean_squared_error(y, y_pred_high)

# Plot results
plt.scatter(x, y, color='black', label='Data')
plt.plot(x, y_pred_low, label='Degree 3 fit')
plt.plot(x, y_pred_high, label='Degree 15 fit (Overfit)')
plt.legend()
plt.title(f'MSE Low Degree: {mse_low:.3f}, High Degree: {mse_high:.3f}')
plt.show()
```
Output
A plot with black dots for the data points: the degree-3 curve follows the underlying sine trend smoothly, while the degree-15 curve passes almost exactly through every point and wiggles between them, a telltale sign of overfitting. The title shows the training MSE for both fits.
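Training error alone cannot reveal overfitting; the overfit model looks best on the data it memorized. A possible follow-up (the held-out points below are an assumption added for illustration, not part of the original example) is to evaluate both models on fresh points from the same sine curve:

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# Recreate the training data from the example above
np.random.seed(0)
x = np.linspace(0, 1, 10).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel() + np.random.normal(0, 0.1, x.shape[0])

# Refit both models
poly_low, poly_high = PolynomialFeatures(degree=3), PolynomialFeatures(degree=15)
model_low = LinearRegression().fit(poly_low.fit_transform(x), y)
model_high = LinearRegression().fit(poly_high.fit_transform(x), y)

# Training errors
mse_train_low = mean_squared_error(y, model_low.predict(poly_low.transform(x)))
mse_train_high = mean_squared_error(y, model_high.predict(poly_high.transform(x)))

# Held-out points drawn from the same underlying sine curve (noise-free truth)
x_test = np.linspace(0.03, 0.97, 50).reshape(-1, 1)
y_test = np.sin(2 * np.pi * x_test).ravel()
mse_test_low = mean_squared_error(y_test, model_low.predict(poly_low.transform(x_test)))
mse_test_high = mean_squared_error(y_test, model_high.predict(poly_high.transform(x_test)))

print(f"Train MSE  degree 3: {mse_train_low:.4f}  degree 15: {mse_train_high:.4f}")
print(f"Test MSE   degree 3: {mse_test_low:.4f}  degree 15: {mse_test_high:.4f}")
```

The degree-15 model wins on training error but typically loses badly on the held-out points, which is exactly the gap that defines overfitting.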
🎯 When to Use

Understanding overfitting is important when building machine learning models to ensure they work well on new data, not just the training set.

Use this knowledge when:

  • Training models on small datasets where memorizing noise is easy.
  • Choosing model complexity, like deciding how deep a decision tree should be.
  • Evaluating model performance using separate validation or test data.
  • Applying regularization techniques or early stopping to prevent overfitting.

Real-world cases include predicting house prices, detecting spam emails, or recognizing images, where generalization is key.
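Regularization, mentioned in the bullets above, can be sketched on the same polynomial setup (the `alpha` value here is an arbitrary choice for illustration, not a recommendation): Ridge regression penalizes large coefficients, so the degree-15 model is pulled toward a smoother curve.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression, Ridge

# Same noisy sine data as in the example
np.random.seed(0)
x = np.linspace(0, 1, 10).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel() + np.random.normal(0, 0.1, x.shape[0])

# Degree-15 features for both models
X15 = PolynomialFeatures(degree=15).fit_transform(x)

ols = LinearRegression().fit(X15, y)      # unregularized fit
ridge = Ridge(alpha=0.01).fit(X15, y)     # alpha chosen for illustration

# The penalty shrinks the coefficients, damping the wild oscillations
print("OLS coefficient norm:  ", np.linalg.norm(ols.coef_))
print("Ridge coefficient norm:", np.linalg.norm(ridge.coef_))
```

Smaller coefficients mean a less wiggly polynomial, which usually generalizes better even though its training error is slightly higher.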

Key Points

  • Overfitting means the model learns training data too exactly, including noise.
  • This causes poor performance on new, unseen data.
  • It often happens with very complex models or too much training.
  • Use validation data and simpler models to avoid overfitting.
  • Techniques like regularization and cross-validation help prevent it.
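Cross-validation, from the last bullet, can be sketched as a way to pick model complexity (the dataset size and the degrees tried below are assumptions for illustration): each candidate degree is scored on held-out folds, so the memorizing model is penalized.

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

# Noisy sine data, a few more points so 5-fold CV has enough per fold
np.random.seed(0)
x = np.linspace(0, 1, 30).reshape(-1, 1)
y = np.sin(2 * np.pi * x).ravel() + np.random.normal(0, 0.1, x.shape[0])

# Score each candidate complexity on held-out folds
results = {}
for degree in (1, 3, 15):
    model = make_pipeline(PolynomialFeatures(degree), LinearRegression())
    scores = cross_val_score(model, x, y, cv=5, scoring="neg_mean_squared_error")
    results[degree] = -scores.mean()  # convert back to a positive MSE
    print(f"degree {degree:2d}: mean CV MSE = {results[degree]:.3f}")
```

The overfitting degree-15 model scores far worse under cross-validation than on its own training data, which is how the problem gets caught before deployment.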

Key Takeaways

  • Overfitting occurs when a model learns noise and details from the training data, hurting performance on new data.
  • Balancing model complexity and training time helps prevent overfitting.
  • Use validation sets and metrics to detect overfitting early.
  • Regularization and simpler models improve generalization.
  • Visualizing model predictions can reveal overfitting patterns.