Callbacks
In addition to overridable hooks, the Trainer also includes a callback system.
It is recommended that callbacks are used to contain 'infrastructure' code that is not essential to the operation of the training loop, such as logging; however, this decision is left to the judgement of the user, based on the specific use case.
Warning
Callbacks are executed sequentially, so if a callback is used to modify state, such as updating a metric, it is the responsibility of the user to ensure that this callback is placed before any callback which will read this state (e.g. for logging purposes)!
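For example, a callback which updates metrics should be passed to the Trainer before LogMetricsCallback. The sketch below assumes that model, loss_func and optimizer are already defined, and uses a hypothetical MetricUpdaterCallback to stand in for any state-modifying callback:

from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import LogMetricsCallback

trainer = Trainer(
    model=model,
    loss_func=loss_func,
    optimizer=optimizer,
    callbacks=(
        MetricUpdaterCallback(),  # hypothetical: updates metrics in the run history
        LogMetricsCallback(),  # reads the updated metrics, so it must come second
    ),
)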
Note
Callbacks are called after their corresponding hooks, e.g., a callback's on_train_epoch_end method is called after the method pytorch_accelerated.trainer.Trainer.train_epoch_end(). This is done to support the pattern of updating the trainer's state in a method before reading this state in a callback.
For more info on execution order within the training loop, see: What goes on inside the Trainer?.
Implemented Callbacks
- class pytorch_accelerated.callbacks.TerminateOnNaNCallback[source]
Bases: TrainerCallback
A callback that terminates the training run if a NaN loss is observed during either training or evaluation.
- class pytorch_accelerated.callbacks.LogMetricsCallback[source]
Bases: TrainerCallback
A callback that logs the latest values of any metric which has been updated in the trainer’s run history. By default, this just prints to the command line once per machine.
Metrics prefixed with ‘train’ are logged at the end of a training epoch, all other metrics are logged after evaluation.
This can be subclassed to create loggers for different platforms by overriding the log_metrics() method.
- class pytorch_accelerated.callbacks.PrintProgressCallback[source]
Bases: TrainerCallback
A callback which prints a message at the start and end of a run, as well as at the start of each epoch.
- class pytorch_accelerated.callbacks.ProgressBarCallback[source]
Bases: TrainerCallback
A callback which visualises the state of each training and evaluation epoch using a progress bar.
- class pytorch_accelerated.callbacks.SaveBestModelCallback(save_path='best_model.pt', watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True, save_optimizer: bool = True)[source]
Bases: TrainerCallback
A callback which saves the best model during a training run, according to a given metric. The best model weights are loaded at the end of the training run.
- __init__(save_path='best_model.pt', watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True, save_optimizer: bool = True)[source]
- Parameters:
save_path – the path to save the checkpoint to. This should end in .pt
watch_metric – the metric used to compare model performance. This should be accessible from the trainer's run history.
greater_is_better – whether an increase in the watch_metric should be interpreted as the model performing better.
reset_on_train – whether to reset the best metric on subsequent training runs. If True, only the metrics observed during the current training run will be compared.
save_optimizer – whether to also save the optimizer as part of the model checkpoint
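As a usage sketch, assuming that model, loss_func and optimizer are already defined (note that passing callbacks explicitly replaces the trainer's default callbacks):

from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import PrintProgressCallback, SaveBestModelCallback

trainer = Trainer(
    model=model,
    loss_func=loss_func,
    optimizer=optimizer,
    callbacks=(
        # keep the checkpoint with the lowest evaluation loss seen so far
        SaveBestModelCallback(save_path="best_model.pt", watch_metric="eval_loss_epoch"),
        PrintProgressCallback(),
    ),
)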
- class pytorch_accelerated.callbacks.ModelEmaCallback(decay: float = 0.99, evaluate_during_training: bool = True, save_path: str = 'ema_model.pt', watch_metric: str = 'ema_model_eval_loss_epoch', greater_is_better: bool = False, model_ema=<class 'pytorch_accelerated.utils.ModelEma'>, callbacks=())[source]
Bases: SaveBestModelCallback
A callback which maintains and saves an exponential moving average of the weights of the model that is currently being trained.
This callback offers the option of evaluating the EMA model during training. If enabled, this is done by running an additional validation epoch after each training epoch, which will use additional GPU resources. During this additional epoch, only the provided callbacks will be executed.
Note
This callback is sensitive to the order in which it is executed. It should be placed after any callbacks that modify state (e.g. metrics) but before any callbacks that read state (e.g. loggers) or ConvertSyncBatchNormCallback.
- __init__(decay: float = 0.99, evaluate_during_training: bool = True, save_path: str = 'ema_model.pt', watch_metric: str = 'ema_model_eval_loss_epoch', greater_is_better: bool = False, model_ema=<class 'pytorch_accelerated.utils.ModelEma'>, callbacks=())[source]
- Parameters:
decay – the amount of decay to use, which determines how much of the previous state will be maintained.
evaluate_during_training – whether to evaluate the EMA model during training. If True, an additional validation epoch will be conducted after each training epoch, which will use additional GPU resources, and the best model will be saved. If False, the saved EMA model checkpoint will be updated at the end of each epoch.
watch_metric – the metric used to compare model performance. This should be accessible from the trainer's run history. This is only used when evaluate_during_training is enabled.
greater_is_better – whether an increase in the watch_metric should be interpreted as the model performing better.
model_ema – the class which is responsible for maintaining the moving average of the model.
callbacks – an iterable of callbacks that will be executed during the evaluation loop of the EMA model
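To illustrate the ordering constraint from the note above, here is a sketch, assuming that model, loss_func and optimizer are already defined:

from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import (
    LogMetricsCallback,
    ModelEmaCallback,
    PrintProgressCallback,
    ProgressBarCallback,
)

trainer = Trainer(
    model=model,
    loss_func=loss_func,
    optimizer=optimizer,
    callbacks=(
        # modifies state, so it is placed before any callbacks that read state
        ModelEmaCallback(decay=0.99, callbacks=(ProgressBarCallback(),)),
        LogMetricsCallback(),
        PrintProgressCallback(),
    ),
)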
- class pytorch_accelerated.callbacks.EarlyStoppingCallback(early_stopping_patience: int = 1, early_stopping_threshold: float = 0.01, watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True)[source]
Bases: TrainerCallback
A callback which stops training early if progress is not being observed.
- __init__(early_stopping_patience: int = 1, early_stopping_threshold: float = 0.01, watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True)[source]
- Parameters:
early_stopping_patience – the number of epochs with no improvement after which training will be stopped.
early_stopping_threshold – the minimum change in the watch_metric to qualify as an improvement, i.e. an absolute change of less than this threshold will count as no improvement.
watch_metric – the metric used to compare model performance. This should be accessible from the trainer's run history.
greater_is_better – whether an increase in the watch_metric should be interpreted as the model performing better.
reset_on_train – whether to reset the best metric on subsequent training runs. If True, only the metrics observed during the current training run will be compared.
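As a sketch, the callback below would stop training after three consecutive epochs in which the evaluation loss fails to improve by at least 0.001:

from pytorch_accelerated.callbacks import EarlyStoppingCallback

# stop training if eval_loss_epoch does not improve by >= 0.001 for 3 epochs
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.001,
    watch_metric="eval_loss_epoch",
    greater_is_better=False,
)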
- class pytorch_accelerated.callbacks.MoveModulesToDeviceCallback[source]
Bases: TrainerCallback
A callback which moves any Trainer attributes which are instances of torch.nn.Module onto the appropriate device at the start of a training or evaluation run.
Note
This does not include the model, as this will be prepared separately by the Trainer's internal instance of accelerate.Accelerator.
- class pytorch_accelerated.callbacks.ConvertSyncBatchNormCallback[source]
Bases: TrainerCallback
A callback which converts all BatchNorm*D layers in the model to torch.nn.SyncBatchNorm layers.
Creating New Callbacks
To create a new callback containing custom behaviour, e.g. logging to an external platform, it is recommended to subclass TrainerCallback. To avoid confusion with the Trainer's methods, all callback methods are prefixed with on_.
Warning
For maximum flexibility, the current instance of the Trainer is available in every callback method. However, changing the trainer's state within a callback can have unintended consequences, as this may affect other parts of the training run. If a callback is used to modify Trainer state, it is the responsibility of the user to ensure that everything continues to work as intended.
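As a minimal sketch, a new callback only needs to override the events it is interested in. The hypothetical example below records how long each training epoch takes and stores it in the run history, where a logging callback such as LogMetricsCallback would pick it up:

import time

from pytorch_accelerated.callbacks import TrainerCallback


class EpochTimerCallback(TrainerCallback):
    # hypothetical callback: records the duration of each training epoch
    def on_train_epoch_start(self, trainer, **kwargs):
        self.epoch_start_time = time.perf_counter()

    def on_train_epoch_end(self, trainer, **kwargs):
        duration = time.perf_counter() - self.epoch_start_time
        trainer.run_history.update_metric("epoch_duration_secs", duration)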
- class pytorch_accelerated.callbacks.TrainerCallback[source]
The abstract base class to be subclassed when creating new callbacks.
- on_train_step_end(trainer, batch, batch_output, **kwargs)[source]
Event called at the end of a training step.
- Parameters:
batch – the current batch of training data
batch_output – the outputs returned by pytorch_accelerated.trainer.Trainer.calculate_train_batch_loss()
- on_eval_epoch_start(trainer, **kwargs)[source]
Event called at the beginning of an evaluation epoch.
- on_eval_step_end(trainer, batch, batch_output, **kwargs)[source]
Event called at the end of an evaluation step.
- Parameters:
batch – the current batch of evaluation data
batch_output – the outputs returned by pytorch_accelerated.trainer.Trainer.calculate_eval_batch_loss()
Stopping Training Early
A training run may be stopped early by raising a StopTrainingError, as sketched below.
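For example, a callback could end the run once a tracked metric reaches a target value. This is a sketch; it assumes that an 'accuracy' metric is being updated in the run history by an earlier callback (such as the one defined in the next example), and that StopTrainingError is importable from pytorch_accelerated.callbacks, where it is also used by EarlyStoppingCallback:

from pytorch_accelerated.callbacks import StopTrainingError, TrainerCallback


class StopAtTargetAccuracyCallback(TrainerCallback):
    # hypothetical callback: ends the run once a target accuracy is reached
    def __init__(self, target_accuracy):
        self.target_accuracy = target_accuracy

    def on_eval_epoch_end(self, trainer, **kwargs):
        accuracy = trainer.run_history.get_latest_metric("accuracy")
        if accuracy >= self.target_accuracy:
            raise StopTrainingError(
                f"Stopping training: accuracy reached {self.target_accuracy}"
            )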
Example: Tracking metrics using a callback
By default, the only metrics that are recorded by the pytorch_accelerated.trainer.Trainer are the losses observed during training and evaluation. To track additional metrics, we can extend this behaviour using a callback.
Here is an example of how we can define a callback and use the RunHistory to track metrics computed using TorchMetrics:
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

from pytorch_accelerated.callbacks import TrainerCallback


class ClassificationMetricsCallback(TrainerCallback):
    def __init__(self, num_classes):
        self.metrics = MetricCollection(
            {
                "accuracy": Accuracy(task="multiclass", num_classes=num_classes),
                "precision": Precision(task="multiclass", num_classes=num_classes),
                "recall": Recall(task="multiclass", num_classes=num_classes),
            }
        )

    def _move_to_device(self, trainer):
        self.metrics.to(trainer.device)

    def on_training_run_start(self, trainer, **kwargs):
        self._move_to_device(trainer)

    def on_evaluation_run_start(self, trainer, **kwargs):
        self._move_to_device(trainer)

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        preds = batch_output["model_outputs"].argmax(dim=-1)
        self.metrics.update(preds, batch[1])

    def on_eval_epoch_end(self, trainer, **kwargs):
        metrics = self.metrics.compute()
        trainer.run_history.update_metric("accuracy", metrics["accuracy"].cpu())
        trainer.run_history.update_metric("precision", metrics["precision"].cpu())
        trainer.run_history.update_metric("recall", metrics["recall"].cpu())
        self.metrics.reset()
Note
If you feel that it would be clearer to compute metrics as part of the training loop, this could also be done by subclassing the pytorch_accelerated.trainer.Trainer, as demonstrated in Recording metrics.
Example: Create a custom logging callback
It is recommended that callbacks are used to handle logging, to keep the training loop focused on the ML-related code.
It is easy to create loggers for other platforms by subclassing the LogMetricsCallback callback. For example, we can create a logger for AzureML (which uses the MLflow API) as demonstrated below:
import os

import mlflow

from pytorch_accelerated.callbacks import LogMetricsCallback


class AzureMLLoggerCallback(LogMetricsCallback):
    def __init__(self):
        mlflow.set_tracking_uri(os.environ['MLFLOW_TRACKING_URI'])

    def on_training_run_start(self, trainer, **kwargs):
        mlflow.set_tags(trainer.run_config.to_dict())

    def log_metrics(self, trainer, metrics):
        if trainer.run_config.is_world_process_zero:
            mlflow.log_metrics(metrics)
Example: Create a custom callback to save predictions on evaluation
Here is an example custom callback which records predictions during evaluation and then saves them to a csv file at the end of the evaluation epoch:
from collections import defaultdict

import pandas as pd

from pytorch_accelerated.callbacks import TrainerCallback


class SavePredictionsCallback(TrainerCallback):
    def __init__(self, out_filename='./outputs/valid_predictions.csv') -> None:
        super().__init__()
        self.predictions = defaultdict(list)
        self.out_filename = out_filename

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        input_features, targets = batch
        # gather across processes so that predictions and targets stay aligned
        class_preds = trainer.gather(batch_output['model_outputs']).argmax(dim=-1)
        gathered_targets = trainer.gather(targets)
        self.predictions['prediction'].extend(class_preds.cpu().tolist())
        self.predictions['targets'].extend(gathered_targets.cpu().tolist())

    def on_eval_epoch_end(self, trainer, **kwargs):
        trainer._accelerator.wait_for_everyone()
        if trainer.run_config.is_local_process_zero:
            df = pd.DataFrame.from_dict(self.predictions)
            df.to_csv(self.out_filename, index=False)
Callback handler
The execution of any callbacks passed to the Trainer is handled by an instance of an internal callback handler class.
- class pytorch_accelerated.callbacks.CallbackHandler(callbacks)[source]
The CallbackHandler is responsible for calling a list of callbacks. This class calls the callbacks in the order that they are given.
- add_callback(callback)[source]
Add a callback to the callback handler
- Parameters:
callback – an instance of a subclass of TrainerCallback
- add_callbacks(callbacks)[source]
Add a list of callbacks to the callback handler
- Parameters:
callbacks – a list of TrainerCallback instances
- call_event(event, *args, **kwargs)[source]
For each callback which has been registered, sequentially call the method corresponding to the given event.
- Parameters:
event – The event corresponding to the method to call on each callback
args – a list of arguments to be passed to each callback
kwargs – keyword arguments to be passed to each callback
Creating new callback events
To add even more flexibility, it is relatively simple to define custom callback events and use them in the training loop:
from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import TrainerCallback


class VerifyBatchCallback(TrainerCallback):
    def verify_train_batch(self, trainer, xb, yb):
        assert xb.shape[0] == trainer.run_config["train_per_device_batch_size"]
        assert xb.shape[1] == 1
        assert xb.shape[2] == 28
        assert xb.shape[3] == 28
        assert yb.shape[0] == trainer.run_config["train_per_device_batch_size"]


class TrainerWithCustomCallbackEvent(Trainer):
    def calculate_train_batch_loss(self, batch) -> dict:
        xb, yb = batch
        self.callback_handler.call_event(
            "verify_train_batch", trainer=self, xb=xb, yb=yb
        )
        return super().calculate_train_batch_loss(batch)
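The custom callback can then be passed to the trainer like any other; a sketch, assuming model, loss_func and optimizer are already defined:

trainer = TrainerWithCustomCallbackEvent(
    model=model,
    loss_func=loss_func,
    optimizer=optimizer,
    callbacks=(VerifyBatchCallback(),),
)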