Trainer

class pytorch_accelerated.trainer.Trainer(model, loss_func, optimizer, callbacks=(<class 'pytorch_accelerated.callbacks.MoveModulesToDeviceCallback'>, <class 'pytorch_accelerated.callbacks.TerminateOnNaNCallback'>, <class 'pytorch_accelerated.callbacks.PrintProgressCallback'>, <class 'pytorch_accelerated.callbacks.ProgressBarCallback'>, <class 'pytorch_accelerated.callbacks.LogMetricsCallback'>), run_history=None)

The Trainer is designed to encapsulate an entire training loop for a specific task, bringing together the model, loss function and optimizer, and providing a specification of the behaviour to execute for each step of the training process.

The trainer has been implemented such that it provides (overridable) implementations of the parts of training that rarely change after they have been defined – such as creating a data loader, or how a batch of data is fed to the model – whilst remaining decoupled from components that are likely to change, such as the model, dataset, loss function and optimizer.

__init__(model, loss_func, optimizer, callbacks=(<class 'pytorch_accelerated.callbacks.MoveModulesToDeviceCallback'>, <class 'pytorch_accelerated.callbacks.TerminateOnNaNCallback'>, <class 'pytorch_accelerated.callbacks.PrintProgressCallback'>, <class 'pytorch_accelerated.callbacks.ProgressBarCallback'>, <class 'pytorch_accelerated.callbacks.LogMetricsCallback'>), run_history=None)

Create a new trainer object which can be used to train the given model using the provided loss function and optimizer.

Parameters
  • model – a subclass of nn.Module to be trained

  • loss_func – the loss function to use when training the model

  • optimizer – the optimizer to update the model’s parameters

  • callbacks – a list of callbacks to use during training runs. If a list of callbacks is not provided, the default selection will be used.

  • run_history – an instance of a RunHistory subclass to track training runs. If this is not provided, a new one will be created.

The callbacks used by default are MoveModulesToDeviceCallback, TerminateOnNaNCallback, PrintProgressCallback, ProgressBarCallback and LogMetricsCallback.

class pytorch_accelerated.trainer.TrainerWithTimmScheduler(*args, **kwargs)

Subclass of the Trainer that works with timm schedulers instead of standard PyTorch learning rate schedulers

Training a model

The main entrypoint for the Trainer is the train() method, which is used to launch a training run.

Trainer.train(train_dataset, num_epochs, eval_dataset=None, per_device_batch_size=8, max_num_train_steps=None, gradient_accumulation_steps=1, gradient_clip_value=None, create_scheduler_fn: Optional[Callable] = None, train_dataloader_kwargs: Optional[dict] = None, eval_dataloader_kwargs: Optional[dict] = None, reset_run_history=True, collate_fn=None)

Start a training run. If an evaluation dataset is provided, this routine will include both training and evaluation epochs.

Note

As the optimizer needs to be internally prepared prior to training, in order to use a learning rate scheduler, a factory function must be provided to create_scheduler_fn. This must be a function which accepts the optimizer as a single parameter and returns an instance of a learning rate scheduler. Passing an instance of a learning rate scheduler will not work here.

Parameters
  • train_dataset – the dataset to use during training epochs

  • num_epochs – the number of epochs to train for

  • eval_dataset – the dataset to use during evaluation epochs; if this is not provided, evaluation is skipped.

  • per_device_batch_size – the batch size to use per device

  • max_num_train_steps – the maximum number of steps to train for. If provided, this will override num_epochs

  • gradient_accumulation_steps – accumulate gradients over the specified number of steps to simulate a larger batch size. By default, this is set to 1

  • gradient_clip_value – if specified, the gradients of the model’s parameters will be clipped to the range [-gradient_clip_value, gradient_clip_value]

  • create_scheduler_fn – a function which accepts an optimizer as an argument and returns a learning rate scheduler

  • train_dataloader_kwargs – a dictionary of keyword arguments to pass to the training dataloader constructor, for details see torch.utils.data.DataLoader

  • eval_dataloader_kwargs – a dictionary of keyword arguments to pass to the evaluation dataloader constructor, for details see torch.utils.data.DataLoader

  • reset_run_history – reset any run history saved by the trainer from previous training runs

  • collate_fn – the collate function to be used by the training and evaluation dataloaders
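The arithmetic connecting per_device_batch_size, gradient_accumulation_steps and max_num_train_steps can be worked through in plain Python. The sketch below is a hypothetical helper and is not part of PyTorch-accelerated. It makes two assumptions: the number of update steps per epoch is the length of the prepared per-process training dataloader divided by gradient_accumulation_steps, rounded up; and num_epochs is recomputed from max_num_train_steps when the latter is provided.

```python
import math


def summarise_run(per_device_batch_size, num_processes, gradient_accumulation_steps,
                  batches_per_epoch, num_epochs, max_num_train_steps=None):
    """Hypothetical helper: derive run-level quantities from train() arguments.

    batches_per_epoch is assumed to be the length of the prepared,
    per-process training dataloader.
    """
    # Samples contributing to each optimizer update, across all devices
    effective_batch_size = per_device_batch_size * num_processes * gradient_accumulation_steps
    # Assumed: a partial accumulation window at the end of an epoch still triggers an update
    update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    if max_num_train_steps is None:
        max_num_train_steps = num_epochs * update_steps_per_epoch
    else:
        # max_num_train_steps overrides num_epochs
        num_epochs = math.ceil(max_num_train_steps / update_steps_per_epoch)
    return {
        "effective_batch_size": effective_batch_size,
        "update_steps_per_epoch": update_steps_per_epoch,
        "num_epochs": num_epochs,
        "max_num_train_steps": max_num_train_steps,
    }


print(summarise_run(per_device_batch_size=8, num_processes=2, gradient_accumulation_steps=4,
                    batches_per_epoch=250, num_epochs=10, max_num_train_steps=1000))
# {'effective_batch_size': 64, 'update_steps_per_epoch': 63, 'num_epochs': 16, 'max_num_train_steps': 1000}
```

With these numbers, 15 full epochs give 945 updates. The 16th epoch is therefore only needed for the remaining 55 updates.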

Using learning rate schedulers

As PyTorch schedulers are not all called in the same way, PyTorch-accelerated’s Trainer, to enable maximum flexibility, expects by default that a given scheduler should be stepped after each optimizer update.

Note that, as the optimizer and dataloaders need to be internally prepared prior to training, in order to use a learning rate scheduler, a factory function must be provided to train() as the create_scheduler_fn argument. This must be a function which accepts the optimizer as a single parameter and returns an instance of a learning rate scheduler.

Note

Passing an instance of a PyTorch learning rate scheduler as the create_scheduler_fn argument to train() will not work as intended.

A simple way to create a scheduler factory function is to use functools.partial(), like so:

from functools import partial

from torch.optim import lr_scheduler

create_scheduler_fn = partial(lr_scheduler.StepLR, step_size=7, gamma=0.1)

Note

The Trainer calls a step on the provided scheduler after every optimizer update, which means after every batch when gradient accumulation is not used. This can lead to unexpected results, as some PyTorch schedulers are expected to step only once per epoch.

For instance, in the example above, the learning rate would be multiplied by 0.1 every 7 batches, rather than every 7 epochs. As this particular scheduler is designed to be called once per epoch, this is not the desired behaviour! We can resolve this by expressing step_size in terms of the number of update steps, like this:

from functools import partial

from pytorch_accelerated import TrainerPlaceholderValues
from torch.optim import lr_scheduler

epochs_step_size = 7

create_scheduler_fn = partial(
    lr_scheduler.StepLR,
    step_size=TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH * epochs_step_size,
)

Here, we have used a TrainerPlaceholderValues placeholder to stand in for the number of update steps per epoch; the available placeholders are described below.
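The difference between stepping StepLR per batch and per epoch can be checked with its closed-form schedule, lr = base_lr * gamma ** (num_steps // step_size). The sketch below evaluates that formula in plain Python rather than instantiating a real scheduler. The value of 100 update steps per epoch is an arbitrary assumption chosen for illustration.

```python
def step_lr(base_lr, gamma, step_size, num_steps_taken):
    """Closed-form learning rate of torch.optim.lr_scheduler.StepLR after num_steps_taken steps."""
    return base_lr * gamma ** (num_steps_taken // step_size)


base_lr, gamma = 0.1, 0.1
update_steps_per_epoch = 100  # assumed value for illustration
steps_after_one_epoch = update_steps_per_epoch

# step_size counted in batches: the LR has already been decayed 14 times after one epoch
naive = step_lr(base_lr, gamma, step_size=7, num_steps_taken=steps_after_one_epoch)
# step_size counted in update steps: the LR stays unchanged until 7 epochs have passed
scaled = step_lr(base_lr, gamma, step_size=7 * update_steps_per_epoch,
                 num_steps_taken=steps_after_one_epoch)
print(naive, scaled)
```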

Using TrainerPlaceholderValues

class pytorch_accelerated.trainer.TrainerPlaceholderValues(value)

Some learning rate schedulers require information such as the total number of steps that will take place during a training run. This information is not accessible until the training dataloader has been created, which is done as part of the train() method. In these cases, a placeholder value can be used, as demonstrated below:

from functools import partial

from pytorch_accelerated import TrainerPlaceholderValues
from torch.optim.lr_scheduler import OneCycleLR

create_scheduler_fn = partial(
    OneCycleLR,
    max_lr=config.lr,  # config is assumed to hold your hyperparameters
    epochs=TrainerPlaceholderValues.NUM_EPOCHS,
    steps_per_epoch=TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH,
)

These placeholders will be replaced by the trainer with the correct values during training.

The available placeholders are:

  • NUM_EPOCHS

  • NUM_UPDATE_STEPS_PER_EPOCH

  • TRAIN_DATALOADER_LEN

  • EVAL_DATALOADER_LEN

Alternatively, the same outcome could be achieved by overriding the Trainer’s create_scheduler() method.
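To make the placeholder mechanism more concrete, here is a simplified, hypothetical sketch of how placeholder keyword arguments in a functools.partial factory could be swapped for real values once they are known. Placeholder, resolve_placeholders and dummy_scheduler are illustrative names only, and this is not PyTorch-accelerated's implementation. Unlike TrainerPlaceholderValues, this version does not support arithmetic such as multiplying a placeholder by an integer.

```python
from enum import Enum
from functools import partial


class Placeholder(Enum):
    """Hypothetical stand-in for TrainerPlaceholderValues (no arithmetic support)."""
    NUM_EPOCHS = "num_epochs"
    NUM_UPDATE_STEPS_PER_EPOCH = "num_update_steps_per_epoch"


def resolve_placeholders(scheduler_fn, run_values):
    """Return a copy of a functools.partial with placeholder keywords replaced by real values."""
    if not isinstance(scheduler_fn, partial):
        return scheduler_fn  # nothing to substitute
    resolved_kwargs = {
        key: run_values[value.value] if isinstance(value, Placeholder) else value
        for key, value in scheduler_fn.keywords.items()
    }
    return partial(scheduler_fn.func, *scheduler_fn.args, **resolved_kwargs)


def dummy_scheduler(optimizer, max_lr, epochs, steps_per_epoch):
    # Stands in for a scheduler class such as OneCycleLR
    return {"optimizer": optimizer, "max_lr": max_lr, "epochs": epochs,
            "steps_per_epoch": steps_per_epoch}


create_scheduler_fn = partial(
    dummy_scheduler,
    max_lr=0.01,
    epochs=Placeholder.NUM_EPOCHS,
    steps_per_epoch=Placeholder.NUM_UPDATE_STEPS_PER_EPOCH,
)
# Values that only become known once the dataloaders have been prepared
run_values = {"num_epochs": 5, "num_update_steps_per_epoch": 40}
scheduler = resolve_placeholders(create_scheduler_fn, run_values)("optimizer")
print(scheduler["epochs"], scheduler["steps_per_epoch"])  # 5 40
```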

Using PyTorch-accelerated schedulers

PyTorch-accelerated includes some implementations of schedulers, which have the same interface as PyTorch schedulers, as well as base classes to easily create custom schedules; these are discussed in more detail in Schedulers.

These scheduler implementations provide an alternative constructor, create_scheduler_fn(). Its result can be passed directly to train() as the create_scheduler_fn argument, as demonstrated below:

from pytorch_accelerated.schedulers import CosineLrScheduler

trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=CosineLrScheduler.create_scheduler_fn(
        num_warmup_epochs=5,
        warmup_starting_lr=1e-6,
        num_cooldown_epochs=5,
    ),
)
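To give a feel for what num_warmup_epochs, warmup_starting_lr and num_cooldown_epochs control, here is a generic cosine schedule with linear warmup and a constant cooldown, written in plain Python. This is a sketch of the general technique under stated assumptions: the function name, the min_lr floor and the per-epoch granularity are all hypothetical. The exact formula used by CosineLrScheduler may differ, so see Schedulers for its real behaviour.

```python
import math


def cosine_with_warmup(epoch, num_epochs, base_lr, num_warmup_epochs,
                       warmup_starting_lr, num_cooldown_epochs, min_lr=1e-6):
    """Hypothetical generic cosine schedule with linear warmup and a constant cooldown."""
    if epoch < num_warmup_epochs:
        # Linear warmup from warmup_starting_lr towards base_lr
        return warmup_starting_lr + (base_lr - warmup_starting_lr) * epoch / num_warmup_epochs
    num_decay_epochs = max(1, num_epochs - num_warmup_epochs - num_cooldown_epochs)
    if epoch >= num_warmup_epochs + num_decay_epochs:
        # Cooldown: hold the learning rate at its minimum
        return min_lr
    # Cosine decay from base_lr down to min_lr over the decay epochs
    progress = (epoch - num_warmup_epochs) / num_decay_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


schedule = [cosine_with_warmup(e, num_epochs=30, base_lr=1e-3, num_warmup_epochs=5,
                               warmup_starting_lr=1e-6, num_cooldown_epochs=5)
            for e in range(30)]
print(schedule[0], schedule[5], schedule[-1])
```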

Using timm schedulers

The schedulers included in timm have a different interface to the native PyTorch schedulers, so they do not work with the base Trainer by default.

PyTorch-accelerated includes an alternative trainer, TrainerWithTimmScheduler, which is compatible with timm schedulers. Schedulers should be passed to this trainer as a factory function, in the same way as described above.

Evaluating a model

Once a model has been trained, or loaded from a checkpoint, the Trainer can also be used for evaluation, which consists of running a single epoch, using the Trainer’s evaluation loop logic, on the given dataset.

Trainer.evaluate(dataset=None, per_device_batch_size=8, dataloader_kwargs: Optional[dict] = None, collate_fn=None)

Start an evaluation run.

Note

Starting an evaluation run will reset the Trainer’s run history.

Note

During distributed evaluation, suppose per_device_batch_size multiplied by the number of processes does not exactly divide the size of the dataset, and drop_last=True has not been passed as a dataloader kwarg. In that case, processes that run out of batches will repeat samples from the start of the dataset. This should be taken into consideration when calculating metrics.

Parameters
  • dataset – the dataset to use during evaluation

  • per_device_batch_size – the batch size to use per device

  • dataloader_kwargs – a dictionary of keyword arguments to pass to the dataloader constructor, for details see torch.utils.data.DataLoader

  • collate_fn – the collate function to be used by the dataloader

Utility Methods

Trainer.save_checkpoint(save_path, checkpoint_kwargs=None, save_optimizer=True, save_per_node=True)

Save the model, optimizer and specified args as a checkpoint file.

Parameters
  • save_path – the path at which to save the checkpoint; this should end in ‘.pt’

  • checkpoint_kwargs – additional objects to include in the checkpoint

  • save_optimizer – flag to indicate whether to include the optimizer in the checkpoint

  • save_per_node – flag to indicate whether to save the checkpoint once per machine; if False, the checkpoint will only be saved from world process zero. This is True by default.

Trainer.load_checkpoint(checkpoint_path, load_optimizer=True)

Load the model and optimizer from a checkpoint file.

Parameters
  • checkpoint_path – the path of the checkpoint file to load

  • load_optimizer – flag to indicate whether to load the optimizer if it is included in the checkpoint

Trainer.print(*args, **kwargs)

Use in place of print() to print only once per node.

Trainer.gather(tensor)

Gather the values in tensor across all processes and concatenate them on the first dimension. This can be useful to regroup the predictions from all processes when doing evaluation.

Note

This gather happens in all processes.

Parameters

tensor – (torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor) The tensors to gather across all processes.

Returns

The gathered tensor(s) (torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor). The first dimension of the result is num_processes multiplied by the first dimension of the input tensors.
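As noted for evaluate(), processes that run out of data may repeat samples. Since gather() concatenates results along the first dimension, accumulating gathered outputs step by step places those repeats at the end, and truncating to the dataset length removes them. The standard-library simulation below models even sharding with wrap-around, which corresponds to the drop_last=False case. It assumes that batch number b is sent to process b % num_processes. shard_batches is a hypothetical helper, and Python lists stand in for the tensors that Trainer.gather() would return.

```python
import math


def shard_batches(num_samples, batch_size, num_processes):
    """Hypothetical model of even sharding with wrap-around (the drop_last=False case).

    Returns one entry per step; each entry holds one batch of sample indices per process.
    """
    num_steps = math.ceil(num_samples / (batch_size * num_processes))
    total = num_steps * batch_size * num_processes
    # Processes that run out of data loop back to the first samples
    padded = [i % num_samples for i in range(total)]
    steps = []
    for step in range(num_steps):
        step_batches = []
        for rank in range(num_processes):
            # Assumed: batch number b goes to process b % num_processes
            start = (step * num_processes + rank) * batch_size
            step_batches.append(padded[start:start + batch_size])
        steps.append(step_batches)
    return steps


num_samples = 10
gathered = []
for step_batches in shard_batches(num_samples, batch_size=3, num_processes=2):
    # Like gather(): concatenate each process's batch along the first dimension
    gathered.extend(index for batch in step_batches for index in batch)

print(len(gathered))  # 12: two samples were evaluated twice
deduplicated = gathered[:num_samples]
print(deduplicated == list(range(num_samples)))  # True
```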

Customizing Trainer Behaviour

Whilst the Trainer should work out of the box in straightforward use cases, subclassing the trainer and overriding its methods is intended and encouraged - think of the base implementation as a set of ‘sensible defaults’!

Note

Methods which are prefixed with a verb such as create or calculate are expected to return a value; all other methods are used to set internal state (e.g. optimizer.step()).

Setup Methods

Trainer.create_train_dataloader(batch_size: int, train_dl_kwargs: Optional[dict] = None) -> Iterable

Create a dataloader to be used during training. This is initialised with the train_dataset and collate function which have been passed to train().

If no arguments are passed, the arguments returned by Trainer.get_default_train_dl_kwargs() are used.

Note

If batch_size is included in train_dl_kwargs, this takes precedence over the batch_size argument.

Parameters
  • batch_size – the batch size to use per device

  • train_dl_kwargs – a dictionary of keyword arguments to pass to the dataloader constructor, for details see torch.utils.data.DataLoader

Returns

an instance of DataLoader

Trainer.get_default_train_dl_kwargs(batch_size) -> dict

Return the default arguments that will be used by the training dataloader.

Parameters

batch_size – the batch size to use during training

Returns

a dictionary containing the default arguments for the training dataloader

Trainer.create_eval_dataloader(batch_size: int, eval_dl_kwargs: Optional[dict] = None) -> Iterable

Create a dataloader to be used during evaluation. This is initialised with the evaluation dataset and collate function which have been passed to train() or evaluate().

If no arguments are passed, the arguments returned by Trainer.get_default_eval_dl_kwargs() are used.

Note

If batch_size is included in eval_dl_kwargs, this takes precedence over the batch_size argument.

Parameters
  • batch_size – the batch size to use per device

  • eval_dl_kwargs – a dictionary of keyword arguments to pass to the dataloader constructor, for details see torch.utils.data.DataLoader

Returns

an instance of torch.utils.data.DataLoader

Trainer.get_default_eval_dl_kwargs(batch_size) -> dict

Return the default arguments that will be used by the evaluation dataloader.

Parameters

batch_size – the batch size to use during evaluation

Returns

a dictionary containing the default arguments for the evaluation dataloader

Trainer.create_scheduler()

Create a learning rate scheduler based on the create_scheduler_fn function which has been passed to train().

Returns

a learning rate scheduler instance

Training Run Methods

Trainer.training_run_start()

This method is called at the start of a training run.

Trainer.training_run_epoch_end()

This method is called during a training run after both training and evaluation epochs have been completed.

Trainer.training_run_end()

This method is called at the end of a training run.

Training epoch Methods

Trainer.train_epoch_start()

This method is called at the start of a training epoch.

The default behaviour of this method is to call self.model.train()

Trainer.calculate_train_batch_loss(batch) -> dict

Calculate the training loss and return it along with the batch size and model outputs. Any additional values returned will be available in the on_train_step_end() callback method.

Parameters

batch – the output of the train dataloader

Returns

A dictionary containing the training loss, model outputs and batch size. Can include any keys, but must include the keys ‘loss’, ‘model_outputs’ and ‘batch_size’

Trainer.backward_step(loss)

Use the accelerator to perform the backward pass on the calculated value of the loss returned by calculate_train_batch_loss(). If gradient accumulation is enabled, this loss has been scaled by 1 / accumulation steps.

Parameters

loss – The loss tensor returned by calculate_train_batch_loss().
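The 1 / accumulation steps scaling described above is what makes accumulated gradients match those of a single, larger batch. The plain-Python check below uses a one-parameter model with a mean squared error loss and an analytic gradient, all chosen purely for illustration. It shows that summing the gradients of each equally sized micro-batch's loss, divided by the number of accumulation steps, reproduces the gradient of the mean loss over the whole batch.

```python
import math


def grad_of_mean_squared_error(w, xs, ys):
    """Analytic gradient of mean((w * x - y) ** 2) with respect to the single weight w."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 3.9, 6.1, 8.0, 9.8, 12.2]
accumulation_steps = 3
micro_batch_size = len(xs) // accumulation_steps  # equally sized micro-batches

# Gradient from one large batch of all six samples
full_batch_grad = grad_of_mean_squared_error(w, xs, ys)

# Gradients summed over micro-batches, each micro-batch loss scaled by 1 / accumulation_steps
accumulated_grad = 0.0
for i in range(accumulation_steps):
    start = i * micro_batch_size
    micro_xs = xs[start:start + micro_batch_size]
    micro_ys = ys[start:start + micro_batch_size]
    # Scaling a loss by a constant scales its gradient by the same constant
    accumulated_grad += grad_of_mean_squared_error(w, micro_xs, micro_ys) / accumulation_steps

print(math.isclose(full_batch_grad, accumulated_grad))  # True
```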

Trainer.optimizer_step()

Performs a single optimization step and updates the parameters which have been passed to self.optimizer.

Trainer.scheduler_step()

Performs a single scheduler step if self.scheduler has been assigned.

Trainer.optimizer_zero_grad()

Sets the gradients of all optimized torch.Tensors to zero.

Trainer.train_epoch_end()

This method is called at the end of each training epoch.
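When gradient_accumulation_steps is greater than 1, the per-batch methods above are not necessarily all called on every batch. The sketch below lists, for a hypothetical epoch, which Trainer methods would be called for each batch (callback hooks omitted). It assumes that gradients are applied and cleared every gradient_accumulation_steps batches and on the final batch of the epoch, and that _clip_gradients() only runs when gradient_clip_value is set. epoch_call_sequence is an illustrative name; confirm the exact logic against the source of _run_train_epoch().

```python
def epoch_call_sequence(num_batches, gradient_accumulation_steps, clip_gradients=False):
    """Hypothetical: list the Trainer methods called for each batch of one training epoch."""
    calls = []
    for step in range(num_batches):
        # Every batch produces a loss and a backward pass, so gradients accumulate
        batch_calls = ["calculate_train_batch_loss", "backward_step"]
        is_last_batch = step == num_batches - 1
        if (step + 1) % gradient_accumulation_steps == 0 or is_last_batch:
            # Assumed: gradients are only applied, then cleared, at the end of each window
            if clip_gradients:
                batch_calls.append("_clip_gradients")
            batch_calls += ["optimizer_step", "scheduler_step", "optimizer_zero_grad"]
        calls.append(batch_calls)
    return calls


for step, batch_calls in enumerate(epoch_call_sequence(num_batches=5, gradient_accumulation_steps=2)):
    print(step, batch_calls)
```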

Evaluation epoch Methods

Trainer.eval_epoch_start()

This method is called at the start of an evaluation epoch.

The default behaviour of this method is to call self.model.eval()

Trainer.calculate_eval_batch_loss(batch) -> dict

Calculate the evaluation loss and return it along with the batch size and model outputs. Any additional values returned will be available in the on_eval_step_end() callback method.

Parameters

batch – the output of the eval dataloader

Returns

A dictionary containing the evaluation loss, model outputs and batch size. Can include any keys, but must include the keys ‘loss’, ‘model_outputs’ and ‘batch_size’

Trainer.eval_epoch_end()

This method is called at the end of an evaluation epoch.

Evaluation Run Methods

Trainer.evaluation_run_start()

This method is called at the start of an evaluation run.

Trainer.evaluation_run_end()

This method is called at the end of an evaluation run.

Internal Methods

Warning

In the spirit of Python, nothing is truly hidden within the Trainer. However, care must be taken: by overriding these methods, you are fundamentally changing how the Trainer works internally, and this may have unintended consequences. When modifying one or more internal methods, it is the responsibility of the user to ensure that the Trainer continues to work as intended!

Internal Setup

Trainer._create_accelerator()

Create an instance of accelerate.Accelerator which will be used to manage training.

Trainer._prepare_model_optimizer_and_dataloaders()

Uses the trainer’s instance of accelerate.Accelerator to wrap the model, optimizer and dataloaders in any wrappers necessary for training (e.g. torch.nn.parallel.DistributedDataParallel), and ensures the parameters are placed on the appropriate device.

By default, this will convert each dataloader to an instance of accelerate.data_loader.DataLoaderShard, so that all batches are the same size. The behaviour depends on the value of the drop_last attribute of the dataloaders. With drop_last=True, iteration will stop at the first batch that would be too small or not present on all processes. With drop_last=False, processes which run out of data will loop back to batches from the beginning.

Note

This may change the length of the dataloaders, so this should be called before the number of update steps per epoch is calculated, e.g. when initialising a learning rate scheduler.
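As a rough illustration of why the dataloaders must be prepared before the number of update steps per epoch is calculated, the arithmetic below compares the length of a single-process dataloader with the per-process length after sharding. It assumes drop_last=False and that batches are spread evenly across processes with wrap-around padding. prepared_length is a hypothetical helper, and the exact behaviour may depend on the accelerate version and its settings.

```python
import math


def prepared_length(num_samples, batch_size, num_processes):
    """Hypothetical per-process dataloader length after sharding, assuming drop_last=False."""
    batches_before = math.ceil(num_samples / batch_size)
    # Each process receives every num_processes-th batch; processes left short are padded
    return math.ceil(batches_before / num_processes)


num_samples, batch_size, num_processes = 1000, 8, 4
print(math.ceil(num_samples / batch_size))                      # 125 batches before preparation
print(prepared_length(num_samples, batch_size, num_processes))  # 32 batches per process after
```

Under these assumptions, a scheduler configured with the unprepared length would expect roughly four times as many steps as each process actually takes.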

Trainer._create_run_config(per_device_batch_size, num_epochs, gradient_accumulation_steps, max_num_train_steps, gradient_clip_value) -> TrainerRunConfig

Create an instance of TrainerRunConfig representing the current state of the trainer.

Parameters
  • per_device_batch_size – the batch size per device

  • num_epochs – the number of epochs in the current training run

  • gradient_accumulation_steps – the number of gradient accumulation steps which will be used during the training run

  • max_num_train_steps – If specified, the maximum number of steps to train for. If present, this will take precedence over num_epochs

  • gradient_clip_value – the value used to determine the threshold to clip the gradients of the model’s parameters

Training run behaviour

Trainer._run_training()

The method responsible for the orchestration of the high level steps which will be executed during a training run.

Training epoch behaviour

Trainer._run_train_epoch(train_dl)

The method responsible for the behaviour of each training epoch.

Parameters

train_dl – the dataloader to be used during training

Trainer._clip_gradients()

Clip the gradients of the model’s parameters that fall outside of the threshold specified in train().

By default, this clips the gradients using accelerate.Accelerator.clip_grad_value_()

Evaluation epoch behaviour

Trainer._run_eval_epoch(valid_dl, is_training: bool = True)

The method responsible for the behaviour of each evaluation epoch.

Parameters
  • valid_dl – the dataloader to be used during evaluation

  • is_training – signals whether the evaluation is being run as part of a training run

Should I subclass the Trainer or use a callback?

The behaviour of the Trainer can also be extended using Callbacks. All callback methods are prefixed with on_.

It is recommended that callbacks are used to contain ‘infrastructure’ code, which is not essential to the operation of the training loop, such as logging, but this decision is left to the judgement of the user based on the specific use case. If it seems overkill to subclass the Trainer for the modification you wish to make, it may be better to use a callback instead.

For more information on callbacks, see Callbacks.

Recording metrics

The Trainer contains an instance of RunHistory, which can be used to store and retrieve the values of any metrics to track during training. By default, the only metrics that are recorded by the Trainer are the losses observed during training and evaluation.

Note

If the callback PrintMetricsCallback is being used, any metrics recorded in the run history will be printed to the console automatically.

The API for RunHistory is detailed at RunHistory.

Here is an example of how we can subclass the Trainer and use the RunHistory to track metrics computed using TorchMetrics:

from torchmetrics import MetricCollection, Accuracy, Precision, Recall
from pytorch_accelerated import Trainer

class TrainerWithMetrics(Trainer):
    def __init__(self, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # this will be moved to the correct device automatically by the
        # MoveModulesToDeviceCallback callback, which is used by default
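        self.metrics = MetricCollection(
            {
                # the task argument is required by recent torchmetrics versions (0.11+)
                "accuracy": Accuracy(task="multiclass", num_classes=num_classes),
                "precision": Precision(task="multiclass", num_classes=num_classes),
                "recall": Recall(task="multiclass", num_classes=num_classes),
            }
        )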

    def calculate_eval_batch_loss(self, batch):
        batch_output = super().calculate_eval_batch_loss(batch)
        preds = batch_output["model_outputs"].argmax(dim=-1)

        # batch is assumed to be an (inputs, targets) pair
        self.metrics.update(preds, batch[1])

        return batch_output

    def eval_epoch_end(self):
        metrics = self.metrics.compute()
        self.run_history.update_metric("accuracy", metrics["accuracy"].cpu())
        self.run_history.update_metric("precision", metrics["precision"].cpu())
        self.run_history.update_metric("recall", metrics["recall"].cpu())

        self.metrics.reset()

Note

If subclassing the Trainer seems excessive for this use case, this could also be done using a callback, as demonstrated in Example: Tracking metrics using a callback.

What goes on inside the Trainer?

In pseudocode, the execution of a training run can be depicted as:

train_dl = create_train_dataloader()
eval_dl = create_eval_dataloader()
scheduler = create_scheduler()

training_run_start()
on_training_run_start()

for epoch in range(num_epochs):
    train_epoch_start()
    on_train_epoch_start()
    for batch in train_dl:
        on_train_step_start()
        batch_output = calculate_train_batch_loss(batch)
        on_train_step_end(batch, batch_output)
        backward_step(batch_output["loss"])
        optimizer_step()
        scheduler_step()
        optimizer_zero_grad()
    train_epoch_end()
    on_train_epoch_end()

    eval_epoch_start()
    on_eval_epoch_start()
    for batch in eval_dl:
        on_eval_step_start()
        batch_output = calculate_eval_batch_loss(batch)
        on_eval_step_end(batch, batch_output)
    eval_epoch_end()
    on_eval_epoch_end()

    training_run_epoch_end()
    on_training_run_epoch_end()

training_run_end()
on_training_run_end()

Similarly, the execution of an evaluation run can be depicted as:

eval_dl = create_eval_dataloader()

evaluation_run_start()
on_evaluation_run_start()

eval_epoch_start()
on_eval_epoch_start()
for batch in eval_dl:
    on_eval_step_start()
    batch_output = calculate_eval_batch_loss(batch)
    on_eval_step_end(batch, batch_output)
eval_epoch_end()
on_eval_epoch_end()

evaluation_run_end()
on_evaluation_run_end()

The best way to understand how the Trainer works internally is by examining the source code for the train() method; significant care has gone into making the internal methods as clean and clear as possible.