
Callbacks (EarlyStopping, ModelCheckpoint) in TensorFlow

Introduction

Callbacks automate parts of training: EarlyStopping halts training once the model stops improving, and ModelCheckpoint saves the best version of the model for you automatically.

When you want to stop training if the model stops getting better to save time and avoid overfitting.
When you want to save the best model during training to use it later without retraining.
When training takes a long time and you want to keep the best checkpoint in case of interruptions.
When you want to monitor validation performance and act on it automatically during training.
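The core idea behind EarlyStopping's patience can be sketched in plain Python. This is a toy helper for illustration, not the Keras API; the function name and loss values are made up:

```python
def should_stop_early(val_losses, patience=3):
    """Return True if the last `patience` epochs showed no improvement
    over the best validation loss seen before them."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return all(loss >= best_before for loss in val_losses[-patience:])

# Loss improves at first, then plateaus for 3 epochs -> stop
print(should_stop_early([0.9, 0.7, 0.6, 0.65, 0.64, 0.66], patience=3))  # True
# Loss is still improving -> keep training
print(should_stop_early([0.9, 0.8, 0.7, 0.6], patience=3))  # False
```

The real callback tracks the same kind of state across epochs, which is why a larger patience tolerates noisier validation curves.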
Syntax
TensorFlow
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1)
model_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, verbose=1)

monitor: The metric to watch, such as 'val_loss' or 'val_accuracy'.

patience: How many epochs with no improvement to allow before stopping.

save_best_only: Overwrite the checkpoint file only when the monitored metric improves.
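The behavior of save_best_only can be mimicked with a tiny stand-in class. This is a hypothetical sketch, not the real ModelCheckpoint (which writes the model to disk at each improvement):

```python
class BestSaver:
    """Mimics ModelCheckpoint(save_best_only=True): record only improvements."""
    def __init__(self):
        self.best = float('inf')
        self.saved_epochs = []

    def on_epoch_end(self, epoch, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.saved_epochs.append(epoch)  # a real callback saves the model here

saver = BestSaver()
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.6]):
    saver.on_epoch_end(epoch, loss)
print(saver.saved_epochs)  # [0, 1, 3] -> epochs where the model would be saved
```

Epoch 2 is skipped because its loss (0.8) is worse than the best seen so far (0.7), so the checkpoint file is never overwritten with a worse model.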

Examples
Stops training if validation accuracy does not improve for 5 epochs.
TensorFlow
early_stop = EarlyStopping(monitor='val_accuracy', patience=5, verbose=1)
Saves the model weights only when validation loss improves.
TensorFlow
model_checkpoint = ModelCheckpoint('best_weights.h5', monitor='val_loss', save_best_only=True, save_weights_only=True, verbose=1)
Stops early and restores the best weights found during training.
TensorFlow
early_stop = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)
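What restore_best_weights does can also be sketched in plain Python. The numbers and the scalar "weights" here are made-up stand-ins for the model's parameters:

```python
# Hypothetical sketch of restore_best_weights: remember the parameters from
# the best epoch and restore them when training stops.
best_loss, best_weights = float('inf'), None
weights = 0.0                      # stand-in for the model's parameters
for val_loss in [0.8, 0.5, 0.6, 0.7]:
    weights += 1.0                 # pretend each epoch updates the weights
    if val_loss < best_loss:
        best_loss, best_weights = val_loss, weights
weights = best_weights             # restore the best snapshot after stopping
print(weights)  # 2.0 -> weights from the epoch with val_loss 0.5
```

Without this restore step, training would end with the weights from the last epoch, which may be worse than the best ones seen.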
Sample Model

This code trains a simple neural network on random data. It uses EarlyStopping to stop if validation loss doesn't improve for 3 epochs and ModelCheckpoint to save the best model. After training, it prints validation loss and accuracy.

TensorFlow
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Create a simple binary classifier
model = Sequential([
    Dense(16, activation='relu', input_shape=(10,)),
    Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Generate dummy data: 10 random features, binary labels
x_train = np.random.random((100, 10))
y_train = np.random.randint(2, size=(100, 1))
x_val = np.random.random((20, 10))
y_val = np.random.randint(2, size=(20, 1))

# Setup callbacks
early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
model_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, verbose=1)

# Train model with callbacks
history = model.fit(
    x_train, y_train,
    epochs=20,
    batch_size=10,
    validation_data=(x_val, y_val),
    callbacks=[early_stop, model_checkpoint],
    verbose=2
)

# Evaluate model
loss, accuracy = model.evaluate(x_val, y_val, verbose=0)
print(f'Validation loss: {loss:.4f}')
print(f'Validation accuracy: {accuracy:.4f}')
Important Notes

EarlyStopping helps avoid wasting time training when the model stops improving.

ModelCheckpoint saves your best model automatically, so you don't lose progress.

Use restore_best_weights=True in EarlyStopping to keep the best model after stopping.

Summary

Callbacks like EarlyStopping and ModelCheckpoint make training more efficient by cutting wasted epochs and preserving the best model.

EarlyStopping stops training when no improvement happens for a set number of epochs (the patience).

ModelCheckpoint saves the best model during training automatically.