Trainer
- class pytorch_accelerated.trainer.Trainer(model, loss_func, optimizer, callbacks=(MoveModulesToDeviceCallback, TerminateOnNaNCallback, PrintProgressCallback, ProgressBarCallback, LogMetricsCallback), run_history=None)[source]
The Trainer is designed to encapsulate an entire training loop for a specific task, bringing together the model, loss function and optimizer, and providing a specification of the behaviour to execute for each step of the training process.
The trainer has been implemented such that it provides (overridable) implementations of the parts of training that rarely change after they have been defined – such as creating a data loader, or how a batch of data is fed to the model – whilst remaining decoupled from components that are likely to change, such as the model, dataset, loss function and optimizer.
- __init__(model, loss_func, optimizer, callbacks=(MoveModulesToDeviceCallback, TerminateOnNaNCallback, PrintProgressCallback, ProgressBarCallback, LogMetricsCallback), run_history=None)[source]
Create a new trainer object which can be used to train the given model using the provided loss function and optimizer.
- Parameters:
model – a subclass of nn.Module to be trained
loss_func – the loss function to use when training the model
optimizer – the optimizer to update the model’s parameters
callbacks – a list of callbacks to use during training runs. If a list of callbacks is not provided, the default selection will be used.
run_history – an instance of a RunHistory subclass to track training runs. If this is not provided, a new one will be created.
The callbacks used by default are:
MoveModulesToDeviceCallback
TerminateOnNaNCallback
PrintProgressCallback
ProgressBarCallback
LogMetricsCallback
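For example, a minimal sketch of constructing a trainer with the default callbacks (the model, loss function and optimizer below are arbitrary stand-ins):

from torch import nn, optim

from pytorch_accelerated import Trainer

# any nn.Module, loss function and optimizer can be used here
model = nn.Linear(10, 2)

trainer = Trainer(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    optimizer=optim.AdamW(model.parameters(), lr=1e-3),
)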
- class pytorch_accelerated.trainer.TrainerWithTimmScheduler(*args, **kwargs)[source]
Subclass of the Trainer that works with timm schedulers instead of standard PyTorch learning rate schedulers.
Training a model
The main entrypoint for the Trainer is the train() method, which is used to launch a training run.
- Trainer.train(train_dataset, num_epochs, eval_dataset=None, per_device_batch_size=8, max_num_train_steps=None, gradient_accumulation_steps=1, gradient_clip_value=None, create_scheduler_fn=None, train_dataloader_kwargs: dict | None = None, eval_dataloader_kwargs: dict | None = None, reset_run_history=True, collate_fn=None)[source]
Start a training run. If an evaluation dataset is provided, this routine will include both training and evaluation epochs.
Note
As the optimizer needs to be internally prepared prior to training, in order to use a learning rate scheduler, a factory function must be provided to create_scheduler_fn. This must be a function which accepts the optimizer as a single parameter and returns an instance of a learning rate scheduler. Passing an instance of a learning rate scheduler will not work here.
- Parameters:
train_dataset – the dataset to use during training epochs
num_epochs – the number of epochs to train for
eval_dataset – the dataset to use during evaluation epochs; if this is not provided, evaluation is skipped
per_device_batch_size – the batch size to use per device
max_num_train_steps – the maximum number of steps to train for; if provided, this will override num_epochs
gradient_accumulation_steps – accumulate gradients to the specified number of steps to simulate a bigger batch size; by default, this is set to 1
gradient_clip_value – if specified, the gradients of the model’s parameters will be clipped to the range [-gradient_clip_value, gradient_clip_value]
create_scheduler_fn – a function which accepts an optimizer as an argument and returns a learning rate scheduler
train_dataloader_kwargs – a dictionary of keyword arguments to pass to the training dataloader constructor; for details see torch.utils.data.DataLoader
eval_dataloader_kwargs – a dictionary of keyword arguments to pass to the evaluation dataloader constructor; for details see torch.utils.data.DataLoader
reset_run_history – reset any run history saved by the trainer from previous training runs
collate_fn – the collate function to be used by the training and evaluation dataloaders
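For example, a minimal sketch of launching a training run (train_dataset and eval_dataset are stand-ins for any map-style PyTorch datasets compatible with the default collate function):

# assumes `trainer` has been created as shown earlier
trainer.train(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    num_epochs=10,
    per_device_batch_size=32,
    gradient_accumulation_steps=2,  # simulates a per-device batch size of 64
)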
Using learning rate schedulers
As PyTorch schedulers are not consistently called in the same way, to enable maximum flexibility, PyTorch-accelerated’s Trainer expects that a given scheduler should be called after each optimizer update by default.
Note that, as the optimizer and dataloaders need to be internally prepared prior to training, in order to use a learning rate scheduler, a factory function must be provided to train() as the create_scheduler_fn argument. This must be a function which accepts the optimizer as a single parameter and returns an instance of a learning rate scheduler.
Note
Passing an instance of a PyTorch learning rate scheduler as the create_scheduler_fn argument to train() will not work as intended.
A simple way to create a scheduler factory function is by using functools.partial(), like so:
from functools import partial
from torch.optim import lr_scheduler

create_scheduler_fn = partial(lr_scheduler.StepLR, step_size=7, gamma=0.1)
Note
The Trainer calls a step on the provided scheduler after every batch. This can lead to unexpected results, as some PyTorch schedulers are expected to step only once per epoch.
For instance, in the example above, the learning rate would be multiplied by 0.1 at every batch. As this particular scheduler is designed to be called once per epoch, this is not the desired behaviour!
We can resolve this by expressing the step_size in terms of the number of updates, like this:
from functools import partial
from torch.optim import lr_scheduler
from pytorch_accelerated import TrainerPlaceholderValues

epochs_step_size = 7

create_scheduler_fn = partial(
    lr_scheduler.StepLR,
    step_size=TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH * epochs_step_size,
)
Here, to determine the number of update steps per epoch, we have used a TrainerPlaceholderValues placeholder; these placeholders are described below.
Using TrainerPlaceholderValues
- class pytorch_accelerated.trainer.TrainerPlaceholderValues(value)[source]
Some learning rate schedulers require information such as the total number of steps that will take place during a training run. As this information is not accessible prior to creating the training dataloader - which will be done as part of the train() method - a placeholder value can be used in such cases, as demonstrated below:
from functools import partial
from pytorch_accelerated import TrainerPlaceholderValues
from torch.optim.lr_scheduler import OneCycleLR

create_scheduler_fn = partial(
    OneCycleLR,
    max_lr=config.lr,
    epochs=TrainerPlaceholderValues.NUM_EPOCHS,
    steps_per_epoch=TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH,
)
These placeholders will be replaced by the trainer with the correct values during training.
The available placeholders are:
NUM_EPOCHS
NUM_UPDATE_STEPS_PER_EPOCH
TRAIN_DATALOADER_LEN
EVAL_DATALOADER_LEN
Alternatively, the same outcome could be achieved by overriding the Trainer’s create_scheduler() method, as sketched below.
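For illustration, a minimal sketch of such an override; the self.optimizer and self.run_config.num_update_steps_per_epoch attributes used here are assumptions based on the surrounding documentation rather than verified API:

from torch.optim import lr_scheduler

from pytorch_accelerated import Trainer


class TrainerWithStepLR(Trainer):
    def create_scheduler(self):
        # hypothetical sketch: assumes the prepared optimizer and a run
        # config exposing num_update_steps_per_epoch are available on self
        return lr_scheduler.StepLR(
            self.optimizer,
            step_size=7 * self.run_config.num_update_steps_per_epoch,
            gamma=0.1,
        )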
Using PyTorch-accelerated schedulers
PyTorch-accelerated includes some implementations of schedulers, which have the same interface as PyTorch schedulers, as well as base classes to easily create custom schedules; these are discussed in more detail in Schedulers.
These scheduler implementations have an alternative constructor, which can be passed directly to train() as the create_scheduler_fn argument, as demonstrated below:
from pytorch_accelerated.schedulers import CosineLrScheduler

trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=CosineLrScheduler.create_scheduler_fn(
        num_warmup_epochs=5,
        warmup_starting_lr=1e-6,
        num_cooldown_epochs=5,
    ),
)
Using timm schedulers
The schedulers included in timm have a different interface to the native PyTorch schedulers,
so do not work with the base Trainer
by default.
PyTorch-accelerated includes an alternative trainer TrainerWithTimmScheduler
, which is compatible with
timm schedulers; schedulers should be passed to this trainer as a factory function the same way as described above.
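For instance, a minimal sketch using timm’s CosineLRScheduler (this assumes timm is installed; model, loss_func, optimizer, train_dataset, num_epochs and batch_size are stand-ins):

from functools import partial

from timm.scheduler import CosineLRScheduler

from pytorch_accelerated.trainer import TrainerWithTimmScheduler

trainer = TrainerWithTimmScheduler(model, loss_func, optimizer)

trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    # t_initial is the schedule length; here, the total number of epochs
    create_scheduler_fn=partial(CosineLRScheduler, t_initial=num_epochs),
)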
Evaluating a model
Once a model has been trained, or loaded from a checkpoint, the Trainer can also be used for evaluation, which consists of running a single epoch on the given dataset, using the Trainer’s evaluation loop logic.
- Trainer.evaluate(dataset=None, per_device_batch_size=8, dataloader_kwargs: dict | None = None, collate_fn=None)[source]
Start an evaluation run.
Note
Starting an evaluation run will reset the Trainer’s run history.
Note
During distributed evaluation, if per_device_batch_size * the number of processes used does not exactly divide the dataset, and drop_last=False has not been passed as a dataloader kwarg, the dataloader will repeat from the start in processes that run out of batches. This should be taken into consideration when calculating metrics.
- Parameters:
dataset – the dataset to use during evaluation
per_device_batch_size – the batch size to use per device
dataloader_kwargs – a dictionary of keyword arguments to pass to the dataloader constructor; for details see torch.utils.data.DataLoader
collate_fn – the collate function to be used by the dataloader
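For example, a minimal sketch of an evaluation run (test_dataset is a stand-in for any compatible dataset):

# runs a single evaluation epoch using the trainer's evaluation loop logic
trainer.evaluate(
    dataset=test_dataset,
    per_device_batch_size=64,
)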
Utility Methods
- Trainer.save_checkpoint(save_path, checkpoint_kwargs=None, save_optimizer=True, save_per_node=True)[source]
Save the model, optimizer and specified args as a checkpoint file.
- Parameters:
save_path – the path where to save the checkpoint; this should end in ‘.pt’
checkpoint_kwargs – additional objects to include in the checkpoint
save_optimizer – flag to indicate whether to include the optimizer in the checkpoint
save_per_node – flag to indicate whether to save the checkpoint once per machine; if False, the checkpoint will only be saved from world process zero. This is True by default.
- Trainer.load_checkpoint(checkpoint_path, load_optimizer=True)[source]
Load the model and optimizer from a checkpoint file.
- Parameters:
checkpoint_path – the path of the checkpoint file to load
load_optimizer – flag to indicate whether to load the optimizer if it is included in the checkpoint
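For example, a minimal sketch of a save / load round trip (the file path is illustrative):

# save the model and optimizer state at the end of training
trainer.save_checkpoint("checkpoints/model.pt")

# later, restore the state into a trainer constructed with the same model
trainer.load_checkpoint("checkpoints/model.pt")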
- Trainer.gather(tensor, padding_value=None)[source]
Gather the values in tensor across all processes and concatenate them on the first dimension. This can be useful to regroup the predictions from all processes when doing evaluation.
If a padding value is provided, padding will be applied along the first dimension where necessary, to ensure that tensors in all processes have the same shape.
Note
The given value of padding_value should ideally not appear in the expected range of values that the tensor may contain.
- Parameters:
tensor – (torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor) the tensors to gather across all processes
padding_value – if provided, the value with which to pad tensors to ensure that all processes have the same shape
- Returns:
The gathered tensor(s) (torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor). The first dimension of the result is num_processes multiplied by the first dimension of the input tensors.
Note
This gather happens in all processes.
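For example, a minimal sketch of gathering per-batch predictions across processes during evaluation (a hypothetical subclass; taking labels from batch[1] and using argmax both assume a classification setup):

from pytorch_accelerated import Trainer


class TrainerThatGathersPredictions(Trainer):
    def calculate_eval_batch_loss(self, batch):
        batch_output = super().calculate_eval_batch_loss(batch)
        preds = batch_output["model_outputs"].argmax(dim=-1)
        # gather predictions from all processes so that metrics can be
        # computed over the full evaluation set
        batch_output["gathered_predictions"] = self.gather(preds)
        return batch_output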
Customizing Trainer Behaviour
Whilst the Trainer should work out of the box in straightforward use cases, subclassing the trainer and overriding its methods is intended and encouraged - think of the base implementation as a set of ‘sensible defaults’!
Note
Methods which are prefixed with a verb such as create or calculate expect a value to be returned; all other methods are used to set internal state (e.g. optimizer.step())
Setup Methods
- Trainer.create_train_dataloader(batch_size: int, train_dl_kwargs: dict | None = None) Iterable [source]
Create a dataloader to be used during training. This is initialised with the train_dataset and collate function which have been passed to the Trainer.
If no arguments are passed, the arguments returned by Trainer.get_default_train_dl_kwargs() are used.
Note
If batch_size is included in train_dl_kwargs, this takes precedence over the batch_size argument.
- Parameters:
batch_size – the batch size to use per device
train_dl_kwargs – a dictionary of keyword arguments to pass to the dataloader constructor; for details see torch.utils.data.DataLoader
- Returns:
an instance of torch.utils.data.DataLoader
- Trainer.get_default_train_dl_kwargs(batch_size) dict [source]
Return the default arguments that will be used by the training dataloader.
- Parameters:
batch_size – the batch size to use during training
- Returns:
a dictionary containing the default arguments for the training dataloader
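For example, a minimal sketch of overriding this method to increase the number of dataloader workers (num_workers is a standard torch.utils.data.DataLoader argument; the other keys in the returned dictionary are left at their defaults):

from pytorch_accelerated import Trainer


class TrainerWithMoreWorkers(Trainer):
    def get_default_train_dl_kwargs(self, batch_size) -> dict:
        # start from the default arguments and override selected keys
        default_kwargs = super().get_default_train_dl_kwargs(batch_size)
        default_kwargs["num_workers"] = 8
        return default_kwargs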
- Trainer.create_eval_dataloader(batch_size: int, eval_dl_kwargs: dict | None = None) Iterable [source]
Create a dataloader to be used during evaluation. This is initialised with the eval_dataset and collate function which have been passed to the Trainer.
If no arguments are passed, the arguments returned by Trainer.get_default_eval_dl_kwargs() are used.
Note
If batch_size is included in eval_dl_kwargs, this takes precedence over the batch_size argument.
- Parameters:
batch_size – the batch size to use per device
eval_dl_kwargs – a dictionary of keyword arguments to pass to the dataloader constructor; for details see torch.utils.data.DataLoader
- Returns:
an instance of torch.utils.data.DataLoader
Training Run Methods
- Trainer.training_run_epoch_end()[source]
This method is called during a training run after both training and evaluation epochs have been completed.
Training epoch Methods
- Trainer.train_epoch_start()[source]
This method is called at the start of a training epoch.
The default behaviour of this method is to call self.model.train()
- Trainer.calculate_train_batch_loss(batch) dict [source]
Calculates the training loss and returns this along with the batch size and model outputs. Any additional values returned will be available in the on_train_step_end() callback method; an override sketch follows this entry.
- Parameters:
batch – the output of the train dataloader
- Returns:
A dictionary containing the training loss, model outputs and batch size. This can include any keys, but must include the keys loss, model_outputs and batch_size
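For example, a minimal sketch of overriding this method to surface extra values to callbacks (the predictions key is illustrative, and argmax assumes a classification model):

from pytorch_accelerated import Trainer


class TrainerWithPredictions(Trainer):
    def calculate_train_batch_loss(self, batch) -> dict:
        batch_output = super().calculate_train_batch_loss(batch)
        # any additional keys are passed through to on_train_step_end
        batch_output["predictions"] = batch_output["model_outputs"].argmax(dim=-1)
        return batch_output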
- Trainer.backward_step(loss)[source]
Use the accelerator to perform the backward pass on the calculated value of the loss returned by calculate_train_batch_loss(). If gradient accumulation is enabled, this loss has been scaled by 1 / accumulation steps.
- Parameters:
loss – the loss tensor returned by calculate_train_batch_loss()
- Trainer.optimizer_step()[source]
Performs a single optimization step and updates the parameters which have been passed to self.optimizer.
Evaluation epoch Methods
- Trainer.eval_epoch_start()[source]
This method is called at the start of an evaluation epoch.
The default behaviour of this method is to call self.model.eval()
- Trainer.calculate_eval_batch_loss(batch) dict [source]
Calculates the evaluation loss and returns this along with the batch size and model outputs. Any additional values returned will be available in the on_eval_step_end() callback.
- Parameters:
batch – the output of the eval dataloader
- Returns:
A dictionary containing the evaluation loss, model outputs and batch size. This can include any keys, but must include the keys loss, model_outputs and batch_size
Evaluation Run Methods
Internal Methods
Warning
In the spirit of Python, nothing is truly hidden within the Trainer. However, care must be taken as, by overriding these methods, you are fundamentally changing how the Trainer works internally, and this may have unintended consequences. When modifying one or more internal methods, it is the responsibility of the user to ensure that the Trainer continues to work as intended!
Internal Setup
- Trainer._create_accelerator()[source]
Create an instance of accelerate.Accelerator which will be used to manage training.
- Trainer._prepare_model_optimizer_and_dataloaders()[source]
Uses the trainer’s instance of accelerate.Accelerator to wrap the model, optimizer and dataloaders in any wrappers necessary for training (e.g. torch.nn.parallel.DistributedDataParallel) and ensures the parameters are placed on the appropriate device.
By default, this will convert each dataloader to an instance of accelerate.data_loader.DataLoaderShard. Depending on the value of the drop_last attribute of the dataloaders, iteration will either stop at the first batch that would be too small / not present on all processes, or loop back to batches from the beginning on processes which run out of data, so that all batches are the same size.
Note
This may change the length of the dataloaders, so it should be called before the number of update steps per epoch is calculated, i.e. to initialise a learning rate scheduler
- Trainer._create_run_config(per_device_batch_size, num_epochs, gradient_accumulation_steps, max_num_train_steps, gradient_clip_value) TrainerRunConfig [source]
Create an instance of TrainerRunConfig representing the current state of the trainer.
- Parameters:
per_device_batch_size – the batch size per device
num_epochs – the number of epochs in the current training run
gradient_accumulation_steps – the number of gradient accumulation steps which will be used during the training run
max_num_train_steps – if specified, the maximum number of steps to train for; if present, this will take precedence over num_epochs
gradient_clip_value – the value used to determine the threshold to clip the gradients of the model’s parameters
Training run behaviour
Training epoch behaviour
- Trainer._run_train_epoch(train_dl)[source]
The method responsible for the behaviour of each training epoch.
- Parameters:
train_dl – the dataloader to be used during training
- Trainer._clip_gradients()[source]
Clip the gradients of the model’s parameters that fall outside of the threshold specified in train().
By default, this clips the gradients using accelerate.Accelerator.clip_grad_value_()
Evaluation epoch behaviour
Should I subclass the Trainer or use a callback?
The behaviour of the Trainer can also be extended using Callbacks. All callback methods are prefixed with on_.
It is recommended that callbacks are used to contain ‘infrastructure’ code, which is not essential to the operation of the training loop, such as logging, but this decision is left to the judgement of the user based on the specific use case. If it seems overkill to subclass the Trainer for the modification you wish to make, it may be better to use a callback instead.
For more information on callbacks, see Callbacks.
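As a brief sketch of the callback style (the on_train_epoch_end hook signature and the run_history.current_epoch attribute are assumptions; see Callbacks for the definitive API):

from pytorch_accelerated.callbacks import TrainerCallback


class EpochAnnouncerCallback(TrainerCallback):
    # an 'infrastructure' concern: reporting, not training logic
    def on_train_epoch_end(self, trainer, **kwargs):
        print(f"Finished epoch {trainer.run_history.current_epoch}")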
Recording metrics
The Trainer contains an instance of RunHistory, which can be used to store and retrieve the values of any metrics to track during training. By default, the only metrics recorded by the Trainer are the losses observed during training and evaluation.
Note
If the PrintMetricsCallback callback is being used, any metrics recorded in the run history will be printed to the console automatically.
The API for RunHistory is detailed at RunHistory.
Here is an example of how we can subclass the Trainer and use the RunHistory to track metrics computed using TorchMetrics:
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

from pytorch_accelerated import Trainer


class TrainerWithMetrics(Trainer):
    def __init__(self, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # this will be moved to the correct device automatically by the
        # MoveModulesToDeviceCallback callback, which is used by default
        self.metrics = MetricCollection(
            {
                "accuracy": Accuracy(num_classes=num_classes),
                "precision": Precision(num_classes=num_classes),
                "recall": Recall(num_classes=num_classes),
            }
        )

    def calculate_eval_batch_loss(self, batch):
        batch_output = super().calculate_eval_batch_loss(batch)
        preds = batch_output["model_outputs"].argmax(dim=-1)
        self.metrics.update(preds, batch[1])
        return batch_output

    def eval_epoch_end(self):
        metrics = self.metrics.compute()
        self.run_history.update_metric("accuracy", metrics["accuracy"].cpu())
        self.run_history.update_metric("precision", metrics["precision"].cpu())
        self.run_history.update_metric("recall", metrics["recall"].cpu())
        self.metrics.reset()
Note
If you feel that subclassing the Trainer seems too excessive for this use case, this could also be done using a callback, as demonstrated in Example: Tracking metrics using a callback.
What goes on inside the Trainer?
In pseudocode, the execution of a training run can be depicted as:
train_dl = create_train_dataloader()
eval_dl = create_eval_dataloader()
scheduler = create_scheduler()

training_run_start()
on_training_run_start()

for epoch in num_epochs:
    train_epoch_start()
    on_train_epoch_start()
    for batch in train_dl:
        on_train_step_start()
        batch_output = calculate_train_batch_loss(batch)
        on_train_step_end(batch, batch_output)
        backward_step(batch_output["loss"])
        optimizer_step()
        scheduler_step()
        optimizer_zero_grad()
    train_epoch_end()
    on_train_epoch_end()

    eval_epoch_start()
    on_eval_epoch_start()
    for batch in eval_dl:
        on_eval_step_start()
        batch_output = calculate_eval_batch_loss(batch)
        on_eval_step_end(batch, batch_output)
    eval_epoch_end()
    on_eval_epoch_end()

    training_run_epoch_end()
    on_training_run_epoch_end()

training_run_end()
on_training_run_end()
Similarly, the execution of an evaluation run can be depicted as:
eval_dl = create_eval_dataloader()

evaluation_run_start()
on_evaluation_run_start()

eval_epoch_start()
on_eval_epoch_start()
for batch in eval_dl:
    on_eval_step_start()
    batch_output = calculate_eval_batch_loss(batch)
    on_eval_step_end(batch, batch_output)
eval_epoch_end()
on_eval_epoch_end()

evaluation_run_end()
on_evaluation_run_end()
The best way to understand how the Trainer works internally is by examining the source code for the train() method; significant care has gone into making the internal methods as clean and clear as possible.