Callbacks let you run extra logic during training: EarlyStopping halts training once the model stops improving, and ModelCheckpoint automatically saves the best version of the model.
Callbacks (EarlyStopping, ModelCheckpoint) in TensorFlow
```python
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1)
model_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss',
                                   save_best_only=True, verbose=1)
```
monitor: The metric to watch, like 'val_loss' or 'val_accuracy'.
patience: How many consecutive epochs without improvement to wait before stopping.
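Whether an epoch counts as "without improvement" depends on the metric's direction: a loss should go down, an accuracy should go up. Both callbacks accept a mode argument ('min', 'max', or 'auto') for this. The pure-Python sketch below is a simplified stand-in, not Keras's implementation; in particular, the hypothetical infer_mode rule (names ending in "acc" or "accuracy" are maximized, everything else minimized) is an assumption that only approximates how mode='auto' decides.

```python
def infer_mode(monitor: str) -> str:
    """Hypothetical, simplified stand-in for mode='auto'.

    Accuracy-like names are maximized; everything else (e.g. losses) is minimized.
    Keras's real 'auto' logic covers more cases than this.
    """
    return "max" if monitor.endswith(("acc", "accuracy")) else "min"


def is_improvement(current: float, best: float, mode: str) -> bool:
    """Return True if `current` beats `best` in the direction given by `mode`."""
    return current > best if mode == "max" else current < best


# Made-up values: a lower val_loss is better, a lower val_accuracy is worse.
print(infer_mode("val_loss"), is_improvement(0.40, 0.45, infer_mode("val_loss")))  # min True
print(infer_mode("val_accuracy"), is_improvement(0.80, 0.85, infer_mode("val_accuracy")))  # max False
```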
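```python
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Stop if validation accuracy has not improved for 5 consecutive epochs
early_stop = EarlyStopping(monitor='val_accuracy', patience=5, verbose=1)

# Save only the weights (not the full model) whenever val_loss improves.
# Keras 3 requires weights-only checkpoint paths to end in '.weights.h5'.
model_checkpoint = ModelCheckpoint('best_weights.weights.h5', monitor='val_loss',
                                   save_best_only=True, save_weights_only=True,
                                   verbose=1)

# Stop after 2 epochs without improvement and roll back to the best weights
early_stop = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)
```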
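With save_best_only=True, ModelCheckpoint writes a file only when the monitored value beats the best value seen so far. The sketch below is a simplified pure-Python model of that rule, not Keras's code, and the val_loss history in it is made up. Because the file paths in these examples contain no epoch placeholder, each save overwrites the previous file, so the file ends up holding the model from the last epoch that was saved.

```python
def checkpoint_epochs(values, mode="min"):
    """Return the 1-indexed epochs at which a save_best_only checkpoint is written.

    Simplified model of ModelCheckpoint(save_best_only=True): a save happens only
    when the monitored value beats the best value seen so far.
    """
    best = float("inf") if mode == "min" else float("-inf")
    saved = []
    for epoch, value in enumerate(values, start=1):
        better = value < best if mode == "min" else value > best
        if better:
            best = value
            saved.append(epoch)
    return saved


# Made-up val_loss history, for illustration only.
val_loss = [0.70, 0.62, 0.65, 0.58, 0.60, 0.61]
print(checkpoint_epochs(val_loss))  # [1, 2, 4]
```

Only epochs 1, 2 and 4 trigger a save here, so the checkpoint file would end up holding the epoch 4 model.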
This code trains a simple neural network on random data. It uses EarlyStopping to stop if validation loss doesn't improve for 3 epochs and ModelCheckpoint to save the best model. After training, it prints validation loss and accuracy.
```python
import numpy as np
from tensorflow.keras import Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Create a simple binary classifier
model = Sequential([
    Input(shape=(10,)),
    Dense(16, activation='relu'),
    Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Generate random dummy data (labels are random, so don't expect real learning)
x_train = np.random.random((100, 10))
y_train = np.random.randint(2, size=(100, 1))
x_val = np.random.random((20, 10))
y_val = np.random.randint(2, size=(20, 1))

# Set up callbacks
early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1,
                           restore_best_weights=True)
model_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss',
                                   save_best_only=True, verbose=1)

# Train the model with both callbacks
history = model.fit(
    x_train, y_train,
    epochs=20,
    batch_size=10,
    validation_data=(x_val, y_val),
    callbacks=[early_stop, model_checkpoint],
    verbose=2,
)

# Evaluate the model on the validation data
loss, accuracy = model.evaluate(x_val, y_val, verbose=0)
print(f'Validation loss: {loss:.4f}')
print(f'Validation accuracy: {accuracy:.4f}')
```
EarlyStopping helps avoid wasting time training when the model stops improving.
ModelCheckpoint saves your best model automatically, so you don't lose progress.
Use restore_best_weights=True in EarlyStopping so that, when training stops early, the model keeps the weights from its best epoch instead of the last one.
Callbacks like EarlyStopping and ModelCheckpoint improve training by saving time and best models.
EarlyStopping stops training when the monitored metric has not improved for the number of epochs set by patience.
ModelCheckpoint saves the best model during training automatically.
Practice
What is the purpose of the EarlyStopping callback in TensorFlow training?
Solution
Step 1: Understand EarlyStopping's role
EarlyStopping monitors a metric such as validation loss and stops training if it does not improve for a set number of epochs.
Step 2: Compare the options with EarlyStopping's behavior
Only "To stop training when the model stops improving to save time" describes stopping training early when no improvement happens.
Final Answer:
To stop training when the model stops improving to save time
Quick Check:
EarlyStopping stops training early [OK]
Common mistakes:
- Confusing EarlyStopping with saving models
- Thinking EarlyStopping changes learning rate
- Assuming EarlyStopping shuffles data
Which code creates a ModelCheckpoint callback that saves only the best model based on validation accuracy?
Solution
Step 1: Identify the correct parameters for ModelCheckpoint
To save only the best model, save_best_only=True is needed; to monitor validation accuracy, monitor='val_accuracy' is needed.
Step 2: Check the options for matching parameters
tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_accuracy') matches both requirements.
Final Answer:
tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_accuracy')
Quick Check:
Best model saved by val_accuracy [OK]
Common mistakes:
- Using monitor='accuracy' instead of 'val_accuracy'
- Setting save_best_only=False by mistake
- Confusing save_weights_only with saving full model
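Monitoring 'accuracy' instead of 'val_accuracy' matters because training and validation metrics can peak at different epochs; Keras records them under separate keys (for example in history.history). The numbers below are invented to show a typical overfitting pattern, and best_epoch is a hypothetical helper that picks the epoch a 'max' monitor would keep.

```python
# Invented per-epoch metrics with an overfitting pattern: training accuracy
# keeps rising while validation accuracy peaks and then falls.
metrics = {
    "accuracy": [0.70, 0.78, 0.85, 0.91, 0.95],
    "val_accuracy": [0.68, 0.74, 0.77, 0.75, 0.73],
}


def best_epoch(values):
    """Return the 1-indexed epoch with the highest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i]) + 1


print(best_epoch(metrics["accuracy"]))      # 5: the last, most overfit epoch
print(best_epoch(metrics["val_accuracy"]))  # 3: the best epoch on held-out data
```

A checkpoint keyed on 'accuracy' would keep the epoch 5 model, while one keyed on 'val_accuracy' keeps epoch 3.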
```python
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)
model.fit(x_train, y_train, epochs=10,
          validation_data=(x_val, y_val), callbacks=[callback])
```
If the validation loss stops improving after epoch 4, at which epoch will training stop?
Solution
Step 1: Understand the patience parameter in EarlyStopping
patience=2 means training continues for at most 2 more epochs after the last improvement; if neither of them improves on the best val_loss, training stops.
Step 2: Calculate the stopping epoch
If the last improvement is at epoch 4, epochs 5 and 6 run without improvement. At the end of epoch 6 the patience of 2 is used up, so training stops there and epoch 7 never starts.
Final Answer:
Training stops at the end of epoch 6, before epoch 7 starts
Quick Check:
Last improvement at epoch 4 + patience 2 = training ends after epoch 6 [OK]
Common mistakes:
- Stopping immediately at last improvement epoch
- Stopping one epoch too early or too late
- Confusing patience with number of total epochs
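The epoch arithmetic can be checked with a small pure-Python simulation of the patience counter. It is a simplified model of EarlyStopping with monitor='val_loss' (lower is better), not Keras's source, and the loss values are made up so that the last improvement happens at epoch 4. The returned best epoch is also the epoch whose weights restore_best_weights=True would bring back when training stops early.

```python
def simulate_early_stopping(val_losses, patience):
    """Simplified model of EarlyStopping(monitor='val_loss', mode='min').

    Returns (last_epoch_run, best_epoch), both 1-indexed. The wait counter resets
    on every improvement, and training stops once it reaches `patience`.
    """
    best = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Patience never ran out: training used every epoch.
    return len(val_losses), best_epoch


# Made-up losses: improve through epoch 4, then stall (10 epochs as in the quiz).
losses = [0.90, 0.80, 0.70, 0.60, 0.65, 0.66, 0.67, 0.68, 0.69, 0.70]
print(simulate_early_stopping(losses, patience=2))  # (6, 4)
```

With patience=2 the last epoch that runs is 6, so epoch 7 never starts; with patience=3 the same history would stop after epoch 7.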
```python
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
model.fit(x_train, y_train, epochs=20,
          validation_data=(x_val, y_val), callbacks=[callback])
```
What is the most likely reason training does not stop early?
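Solution
Step 1: Check that the validation data is passed correctly
EarlyStopping monitors a validation metric, so if the validation data is missing or incorrect, val_loss is never computed and stopping cannot trigger. In that case Keras logs a warning that the monitored metric is not available, and training runs for all 20 epochs.
Step 2: Evaluate the other options
patience=3 is reasonable, save_best_only belongs to ModelCheckpoint rather than EarlyStopping, and the callbacks argument is present. (Training can also legitimately run all 20 epochs if val_loss keeps improving at least once every 3 epochs.)
Final Answer:
The validation data is not passed correctly, so val_loss is not computed
Quick Check:
EarlyStopping needs a valid val_loss metric [OK]
Common mistakes: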
- Confusing ModelCheckpoint's save_best_only with EarlyStopping
- Ignoring validation_data argument
- Setting patience too high and expecting early stop
Which list of callbacks stops training if validation accuracy does not improve for 4 epochs and saves the best model based on validation accuracy?
Solution
Step 1: Match the EarlyStopping parameters to the requirement
To stop if validation accuracy does not improve for 4 epochs, use monitor='val_accuracy' and patience=4.
Step 2: Match the ModelCheckpoint parameters
To save the best model based on validation accuracy, use save_best_only=True and monitor='val_accuracy'.
Step 3: Check the options for both callbacks
Only [tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=4), tf.keras.callbacks.ModelCheckpoint('best.h5', save_best_only=True, monitor='val_accuracy')] configures both callbacks correctly.
Final Answer:
[tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=4), tf.keras.callbacks.ModelCheckpoint('best.h5', save_best_only=True, monitor='val_accuracy')]
Quick Check:
EarlyStopping and ModelCheckpoint both monitor val_accuracy [OK]
Common mistakes:
- Using 'accuracy' instead of 'val_accuracy' for validation monitoring
- Setting save_best_only=False when saving best model
- Mismatching patience with requirement
