Callbacks

In addition to overridable hooks, the Trainer also includes a callback system.

It is recommended that callbacks are used to contain ‘infrastructure’ code that is not essential to the operation of the training loop, such as logging; however, this decision is left to the judgement of the user based on the specific use case.

Warning

Callbacks are executed sequentially, so if a callback is used to modify state, such as updating a metric, it is the responsibility of the user to ensure that this callback is placed before any callback which will read this state (e.g. for logging purposes)!
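The ordering concern can be illustrated without the library. In the sketch below, a plain dict stands in for the run history and a list of callbacks is iterated in order. Both class names are hypothetical. The logger only sees the metric when the callback that writes it comes first.

```python
class UpdateMetricCallback:
    """Hypothetical callback that writes a metric into the shared run history."""

    def on_eval_epoch_end(self, run_history):
        run_history["accuracy"] = 0.9


class LogAccuracyCallback:
    """Hypothetical callback that reads the metric (None if not yet written)."""

    def __init__(self):
        self.logged = []

    def on_eval_epoch_end(self, run_history):
        self.logged.append(run_history.get("accuracy"))


def run_eval_epoch_end(callbacks):
    # Callbacks are invoked one after another, in list order
    run_history = {}
    for callback in callbacks:
        callback.on_eval_epoch_end(run_history)


good_logger = LogAccuracyCallback()
run_eval_epoch_end([UpdateMetricCallback(), good_logger])

bad_logger = LogAccuracyCallback()
run_eval_epoch_end([bad_logger, UpdateMetricCallback()])

print(good_logger.logged)  # [0.9]
print(bad_logger.logged)   # [None]
```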

Note

Callbacks are called after their corresponding hooks, e.g., a callback’s on_train_epoch_end method is called after the method pytorch_accelerated.trainer.Trainer.train_epoch_end(). This is done to support the pattern of updating the trainer’s state in a method before reading this state in a callback.

For more info on execution order within the training loop, see What goes on inside the Trainer?

Implemented Callbacks

class pytorch_accelerated.callbacks.TerminateOnNaNCallback

Bases: TrainerCallback

A callback that terminates the training run if a NaN loss is observed during either training or evaluation.

class pytorch_accelerated.callbacks.LogMetricsCallback

Bases: TrainerCallback

A callback that logs the latest values of any metric which has been updated in the trainer’s run history. By default, this just prints to the command line once per machine.

Metrics prefixed with ‘train’ are logged at the end of a training epoch; all other metrics are logged after evaluation.

This can be subclassed to create loggers for different platforms by overriding the log_metrics() method.
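The hypothetical JsonLinesLoggerCallback below is a minimal sketch of such a subclass. It overrides log_metrics() to append each set of metrics to a local file. The sketch assumes log_metrics receives the trainer and a dict mapping metric names to values, which matches the signature used in the AzureML example on this page. If pytorch_accelerated is not installed, it falls back to a tiny stand-in base class.

```python
import json

try:
    from pytorch_accelerated.callbacks import LogMetricsCallback
except ImportError:
    # Minimal stand-in so the sketch runs without the library installed
    class LogMetricsCallback:
        def log_metrics(self, trainer, metrics):
            print(metrics)


class JsonLinesLoggerCallback(LogMetricsCallback):
    """Hypothetical logger: appends each set of metrics as one JSON line."""

    def __init__(self, path="metrics.jsonl"):
        super().__init__()
        self.path = path

    def log_metrics(self, trainer, metrics):
        # default=float converts values json cannot serialise natively
        # (e.g. 0-d tensors) into plain floats
        with open(self.path, "a") as f:
            f.write(json.dumps(metrics, default=float) + "\n")
```

In a multi-process run every process would append to the file. Guarding the write with trainer.run_config.is_world_process_zero, as the AzureML example does, avoids duplicate lines.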

class pytorch_accelerated.callbacks.PrintProgressCallback

Bases: TrainerCallback

A callback which prints a message at the start and end of a run, as well as at the start of each epoch.

class pytorch_accelerated.callbacks.ProgressBarCallback

Bases: TrainerCallback

A callback which visualises the state of each training and evaluation epoch using a progress bar.

class pytorch_accelerated.callbacks.SaveBestModelCallback(save_path='best_model.pt', watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True)

Bases: TrainerCallback

A callback which saves the best model during a training run, according to a given metric. The best model weights are loaded at the end of the training run.

__init__(save_path='best_model.pt', watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True)
Parameters
  • save_path – The path to save the checkpoint to. This should end in .pt.

  • watch_metric – the metric used to compare model performance. This should be accessible from the trainer’s run history.

  • greater_is_better – whether an increase in the watch_metric should be interpreted as the model performing better.

  • reset_on_train – whether to reset the best metric on subsequent training runs. If True, only the metrics observed during the current training run will be compared.
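The hypothetical tracker class below sketches how watch_metric, greater_is_better and reset_on_train interact. It is not the library's implementation and only illustrates the documented behaviour. A new value counts as the best if it beats the previous best in the chosen direction, and resetting forgets values from earlier runs.

```python
import operator


class BestMetricTracker:
    """Hypothetical tracker for the best value of a watched metric."""

    def __init__(self, greater_is_better=False):
        # Choose the comparison once, based on the direction of improvement
        self.is_better = operator.gt if greater_is_better else operator.lt
        self.best = None

    def reset(self):
        # Corresponds to reset_on_train=True: forget values from earlier runs
        self.best = None

    def update(self, value):
        """Return True if value is a new best (i.e. weights would be saved)."""
        if self.best is None or self.is_better(value, self.best):
            self.best = value
            return True
        return False


tracker = BestMetricTracker()  # lower is better, as for eval_loss_epoch
print([tracker.update(v) for v in [0.8, 0.6, 0.7, 0.5]])
```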

class pytorch_accelerated.callbacks.EarlyStoppingCallback(early_stopping_patience: int = 1, early_stopping_threshold: float = 0.01, watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True)

Bases: TrainerCallback

A callback which stops training early if progress is not being observed.

__init__(early_stopping_patience: int = 1, early_stopping_threshold: float = 0.01, watch_metric='eval_loss_epoch', greater_is_better: bool = False, reset_on_train: bool = True)
Parameters
  • early_stopping_patience – the number of epochs with no improvement after which training will be stopped.

  • early_stopping_threshold – the minimum change in the watch_metric to qualify as an improvement, i.e. an absolute change of less than this threshold will count as no improvement.

  • watch_metric – the metric used to compare model performance. This should be accessible from the trainer’s run history.

  • greater_is_better – whether an increase in the watch_metric should be interpreted as the model performing better.

  • reset_on_train – whether to reset the best metric on subsequent training runs. If True, only the metrics observed during the current training run will be compared.
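The patience and threshold semantics described above can be sketched as follows. EarlyStoppingTracker is a hypothetical name. The boundary handling, where a change exactly equal to the threshold counts as an improvement, is an assumption, and the library's own callback may differ in such details.

```python
class EarlyStoppingTracker:
    """Hypothetical sketch of patience/threshold logic for early stopping."""

    def __init__(self, patience=1, threshold=0.01, greater_is_better=False):
        self.patience = patience
        self.threshold = threshold
        self.greater_is_better = greater_is_better
        self.best = None
        self.epochs_without_improvement = 0

    def should_stop(self, value):
        """Record this epoch's metric and return True if training should stop."""
        if self.best is None:
            self.best = value
            return False
        # Positive change means movement in the "better" direction
        change = value - self.best if self.greater_is_better else self.best - value
        if change >= self.threshold:
            # Large enough improvement: update the best and reset the counter
            self.best = value
            self.epochs_without_improvement = 0
        else:
            # Change below the threshold (or in the wrong direction)
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


tracker = EarlyStoppingTracker(patience=2, threshold=0.01)
print([tracker.should_stop(v) for v in [1.0, 0.5, 0.495, 0.3, 0.299, 0.298]])
```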

class pytorch_accelerated.callbacks.MoveModulesToDeviceCallback

Bases: TrainerCallback

A callback which moves any Trainer attributes which are instances of torch.nn.Module onto the appropriate device at the start of a training or evaluation run.

Note

This does not include the model, as this will be prepared separately by the Trainer’s internal instance of accelerate.Accelerator.

Creating New Callbacks

To create a new callback containing custom behaviour, e.g. logging to an external platform, it is recommended to subclass TrainerCallback. To avoid confusion with the Trainer’s methods, all callback methods are prefixed with on_.

Warning

For maximum flexibility, the current instance of the Trainer is available in every callback method. However, changing the trainer state within a callback can have unintended consequences, as this may affect other parts of the training run. If a callback is used to modify Trainer state, it is the responsibility of the user to ensure that everything continues to work as intended.

class pytorch_accelerated.callbacks.TrainerCallback

The abstract base class to be subclassed when creating new callbacks.

on_init_end(trainer, **kwargs)

Event called at the end of trainer initialisation.

on_training_run_start(trainer, **kwargs)

Event called at the start of a training run.

on_train_epoch_start(trainer, **kwargs)

Event called at the beginning of a training epoch.

on_train_step_start(trainer, **kwargs)

Event called at the beginning of a training step.

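on_train_step_end(trainer, batch, batch_output, **kwargs)

Event called at the end of a training step.

Parameters
  • batch – the current batch of training data

  • batch_output – the outputs returned by pytorch_accelerated.trainer.Trainer.calculate_train_batch_loss()

on_train_epoch_end(trainer, **kwargs)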

Event called at the end of a training epoch.

on_eval_epoch_start(trainer, **kwargs)

Event called at the beginning of an evaluation epoch.

on_eval_step_start(trainer, **kwargs)

Event called at the beginning of an evaluation step.

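on_eval_step_end(trainer, batch, batch_output, **kwargs)

Event called at the end of an evaluation step.

Parameters
  • batch – the current batch of evaluation data

  • batch_output – the outputs returned by pytorch_accelerated.trainer.Trainer.calculate_eval_batch_loss()

on_eval_epoch_end(trainer, **kwargs)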

Event called at the end of an evaluation epoch.

on_training_run_end(trainer, **kwargs)

Event called at the end of a training run.

on_stop_training_error(trainer, **kwargs)

Event called when a StopTrainingError is raised.
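Putting these events together, a new callback subclasses TrainerCallback and overrides only the events it needs. The hypothetical EpochTimerCallback below records how long each training epoch takes. It assumes the base class's remaining event methods do nothing, and it falls back to a stand-in base class if pytorch_accelerated is not installed.

```python
import time

try:
    from pytorch_accelerated.callbacks import TrainerCallback
except ImportError:
    # Minimal stand-in so the sketch runs without the library installed
    class TrainerCallback:
        pass


class EpochTimerCallback(TrainerCallback):
    """Hypothetical callback recording the wall-clock time of each training epoch."""

    def __init__(self):
        super().__init__()
        self.epoch_durations = []
        self._start = None

    def on_train_epoch_start(self, trainer, **kwargs):
        self._start = time.perf_counter()

    def on_train_epoch_end(self, trainer, **kwargs):
        self.epoch_durations.append(time.perf_counter() - self._start)
```

An instance would then be passed to the Trainer alongside any other callbacks.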

Stopping Training Early

A training run may be stopped early by raising a StopTrainingError.
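For example, a callback can raise the error from one of its events. The sketch below assumes StopTrainingError is importable from pytorch_accelerated.callbacks, and falls back to stand-ins if it is not. MaxTrainStepsCallback is a hypothetical name used only for illustration.

```python
try:
    from pytorch_accelerated.callbacks import StopTrainingError, TrainerCallback
except ImportError:
    # Minimal stand-ins so the sketch runs without the library installed
    class StopTrainingError(Exception):
        pass

    class TrainerCallback:
        pass


class MaxTrainStepsCallback(TrainerCallback):
    """Hypothetical callback that stops the run after a fixed number of steps."""

    def __init__(self, max_steps):
        super().__init__()
        self.max_steps = max_steps
        self.steps_taken = 0

    def on_training_run_start(self, trainer, **kwargs):
        # Start counting afresh for each training run
        self.steps_taken = 0

    def on_train_step_end(self, trainer, batch, batch_output, **kwargs):
        self.steps_taken += 1
        if self.steps_taken >= self.max_steps:
            raise StopTrainingError(f"Reached {self.max_steps} training steps")
```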

Example: Tracking metrics using a callback

By default, the only metrics that are recorded by the pytorch_accelerated.trainer.Trainer are the losses observed during training and evaluation. To track additional metrics, we can extend this behaviour using a callback.

Here is an example of how we can define a callback and use the RunHistory to track metrics computed using TorchMetrics:

from pytorch_accelerated.callbacks import TrainerCallback
from torchmetrics import MetricCollection, Accuracy, Precision, Recall


class ClassificationMetricsCallback(TrainerCallback):
    def __init__(self, num_classes):
        super().__init__()
        # Recent torchmetrics releases require the `task` argument;
        # older releases accepted `num_classes` on its own
        self.metrics = MetricCollection(
            {
                "accuracy": Accuracy(task="multiclass", num_classes=num_classes),
                "precision": Precision(task="multiclass", num_classes=num_classes),
                "recall": Recall(task="multiclass", num_classes=num_classes),
            }
        )

    def _move_to_device(self, trainer):
        # Metric states must live on the same device as the model outputs
        self.metrics.to(trainer.device)

    def on_training_run_start(self, trainer, **kwargs):
        self._move_to_device(trainer)

    def on_evaluation_run_start(self, trainer, **kwargs):
        self._move_to_device(trainer)

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        # Convert logits to predicted class indices; batch[1] holds the targets
        preds = batch_output["model_outputs"].argmax(dim=-1)
        self.metrics.update(preds, batch[1])

    def on_eval_epoch_end(self, trainer, **kwargs):
        metrics = self.metrics.compute()
        trainer.run_history.update_metric("accuracy", metrics["accuracy"].cpu())
        trainer.run_history.update_metric("precision", metrics["precision"].cpu())
        trainer.run_history.update_metric("recall", metrics["recall"].cpu())

        # Clear accumulated state so the next evaluation epoch starts fresh
        self.metrics.reset()

Note

If you feel that it would be clearer to compute metrics as part of the training loop, this could also be done by subclassing the pytorch_accelerated.trainer.Trainer as demonstrated in Recording metrics.

Example: Create a custom logging callback

It is recommended that callbacks are used to handle logging, to keep the training loop focused on the ML-related code. Loggers for other platforms can easily be created by subclassing LogMetricsCallback. For example, we can create a logger for AzureML (which uses the MLflow API) as demonstrated below:

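import os

import mlflow
from pytorch_accelerated.callbacks import LogMetricsCallback


class AzureMLLoggerCallback(LogMetricsCallback):
    def __init__(self):
        super().__init__()
        # Read the tracking server URI from the environment
        mlflow.set_tracking_uri(os.environ['MLFLOW_TRACKING_URI'])

    def on_training_run_start(self, trainer, **kwargs):
        # Record the run configuration as tags on the MLflow run
        mlflow.set_tags(trainer.run_config.to_dict())

    def log_metrics(self, trainer, metrics):
        # Only log from the main process to avoid duplicate entries
        if trainer.run_config.is_world_process_zero:
            mlflow.log_metrics(metrics)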

Example: Create a custom callback to save predictions on evaluation

Here is an example of a custom callback that records predictions during evaluation and then saves them to a CSV file at the end of the evaluation epoch:

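import os
from collections import defaultdict

import pandas as pd
from pytorch_accelerated.callbacks import TrainerCallback


class SavePredictionsCallback(TrainerCallback):

    def __init__(self, out_filename='./outputs/valid_predictions.csv') -> None:
        super().__init__()
        self.predictions = defaultdict(list)
        self.out_filename = out_filename

    def on_eval_epoch_start(self, trainer, **kwargs):
        # Discard predictions from any previous evaluation epoch
        self.predictions = defaultdict(list)

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        input_features, targets = batch
        # Gather predictions and targets from all processes so they stay aligned
        class_preds = trainer.gather(batch_output['model_outputs']).argmax(dim=-1)
        targets = trainer.gather(targets)
        self.predictions['prediction'].extend(class_preds.cpu().tolist())
        self.predictions['targets'].extend(targets.cpu().tolist())

    def on_eval_epoch_end(self, trainer, **kwargs):
        trainer._accelerator.wait_for_everyone()
        if trainer.run_config.is_local_process_zero:
            # Make sure the output directory exists before writing
            os.makedirs(os.path.dirname(self.out_filename) or '.', exist_ok=True)
            df = pd.DataFrame.from_dict(self.predictions)
            df.to_csv(self.out_filename, index=False)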

Callback handler

The execution of any callbacks passed to the Trainer is handled by an instance of an internal callback handler class.

class pytorch_accelerated.callbacks.CallbackHandler(callbacks)

The CallbackHandler is responsible for calling a list of callbacks. This class calls the callbacks in the order that they are given.

add_callback(callback)

Add a callback to the callback handler

Parameters

callback – an instance of a subclass of TrainerCallback

add_callbacks(callbacks)

Add a list of callbacks to the callback handler

Parameters

callbacks – a list of TrainerCallback

call_event(event, *args, **kwargs)

For each callback which has been registered, sequentially call the method corresponding to the given event.

Parameters
  • event – The event corresponding to the method to call on each callback

  • args – positional arguments to be passed to each callback

  • kwargs – keyword arguments to be passed to each callback
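The following stand-in approximates the dispatch behaviour described above. MiniCallbackHandler and RecordingCallback are hypothetical. The sketch only shows callbacks being called in registration order with the same arguments. It omits any extra handling the real CallbackHandler may perform, for example around StopTrainingError.

```python
class MiniCallbackHandler:
    """Hypothetical stand-in mirroring the documented CallbackHandler methods."""

    def __init__(self, callbacks=()):
        self.callbacks = []
        self.add_callbacks(callbacks)

    def add_callback(self, callback):
        self.callbacks.append(callback)

    def add_callbacks(self, callbacks):
        for callback in callbacks:
            self.add_callback(callback)

    def call_event(self, event, *args, **kwargs):
        # Look up the method named `event` on each callback, in order,
        # and forward the same positional and keyword arguments to it
        for callback in self.callbacks:
            getattr(callback, event)(*args, **kwargs)


class RecordingCallback:
    """Hypothetical callback that records when it is called."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_train_epoch_end(self, trainer, **kwargs):
        self.log.append((self.name, kwargs))
```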

Creating new callback events

To add even more flexibility, it is relatively simple to define custom callback events, and use them in the training loop:

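from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import TrainerCallback


class VerifyBatchCallback(TrainerCallback):
    # A custom event: the method name is arbitrary, as long as the trainer calls it
    def verify_train_batch(self, trainer, xb, yb):
        batch_size = trainer.run_config.train_per_device_batch_size
        # The final batch may be smaller unless drop_last=True, so use <=
        assert xb.shape[0] <= batch_size
        # Expect single-channel 28x28 images (e.g. MNIST)
        assert xb.shape[1:] == (1, 28, 28)
        assert yb.shape[0] == xb.shape[0]


class TrainerWithCustomCallbackEvent(Trainer):
    def calculate_train_batch_loss(self, batch) -> dict:
        xb, yb = batch
        # Trigger the custom event via the callback handler before computing the loss
        self.callback_handler.call_event(
            "verify_train_batch", trainer=self, xb=xb, yb=yb
        )
        return super().calculate_train_batch_loss(batch)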