Schedulers

PyTorch-accelerated provides some scheduler implementations that can be used in any PyTorch training loop. However, unlike PyTorch’s native schedulers, which can be called at different points in the training loop, all PyTorch-accelerated schedulers expect to be called after each optimizer update.

Implemented Schedulers

class pytorch_accelerated.schedulers.cosine_scheduler.CosineLrScheduler(optimizer: Optimizer, total_num_epochs: int, num_update_steps_per_epoch: int, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0)[source]

Bases: StatefulSchedulerBase

A stateful Cosine annealing learning rate scheduler, as described in this paper, but without restarts.

This scheduler differs from PyTorch’s CosineAnnealingLR in that it provides options to add warmup and cooldown epochs. Additionally, the annealing rate can be modified by adjusting the k-decay parameter, whereby the rate of change of the learning rate is changed by its k-th order derivative, as described here.

If warmup epochs are specified, the learning rate will increase in constant increments from the warmup_starting_lr provided until the learning rate specified in the parameter group is reached.

If cooldown epochs are specified, the learning rate will be fixed at the minimum lr value given. This behaviour will continue if the scheduler is called after the training cycle has completed.

__init__(optimizer: Optimizer, total_num_epochs: int, num_update_steps_per_epoch: int, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0)[source]

Create a new CosineLrScheduler object which can be used to modify the learning rate in an optimizer’s parameter groups.

Parameters:
  • optimizer – a PyTorch optimizer containing one or more parameter groups

  • total_num_epochs – the total number of training epochs, inclusive of any warmup and cooldown epochs

  • num_update_steps_per_epoch – the number of optimizer updates that take place per epoch

  • k_decay – adjusts the rate of annealing. Higher values will maintain a higher lr for longer

  • lr_min – the minimum value that the learning rate should decay to for all parameter groups. This will be held fixed during cooldown epochs

  • min_lr_ratio – this can be used to represent the minimum lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over lr_min

  • num_warmup_epochs – the number of epochs to gradually increase the lr until it reaches the maximum value

  • warmup_starting_lr – the starting lr that will be used for all parameter groups at the beginning of training if num_warmup_epochs is greater than 0

  • warmup_starting_lr_ratio – this can be used to represent the warmup starting lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over warmup_starting_lr

  • num_cooldown_epochs – the number of epochs to hold the lr at its minimum value
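
For reference, below is a minimal sketch of constructing the scheduler directly for use in a custom training loop; the model, optimizer, and the concrete numbers used are placeholder assumptions chosen for illustration.

from torch import nn, optim
from pytorch_accelerated.schedulers.cosine_scheduler import CosineLrScheduler

# placeholder model and optimizer, purely for illustration
model = nn.Linear(10, 2)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

scheduler = CosineLrScheduler(
    optimizer,
    total_num_epochs=100,  # inclusive of the warmup and cooldown epochs below
    num_update_steps_per_epoch=500,  # typically len(train_dataloader)
    num_warmup_epochs=5,
    warmup_starting_lr=1e-6,
    num_cooldown_epochs=5,
    lr_min=1e-6,
)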

classmethod create_scheduler_fn(total_num_epochs: int = TrainerPlaceholderValues.NUM_EPOCHS, num_update_steps_per_epoch: int = TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0) → Callable[source]

An alternative constructor which returns a function that accepts an optimizer and creates an instance of CosineLrScheduler. This is primarily intended to be used with the Trainer as illustrated below:

trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=CosineLrScheduler.create_scheduler_fn(num_warmup_epochs=5),
)

By default, the total_num_epochs and num_update_steps_per_epoch arguments will be set by the Trainer with the correct values at runtime.

Parameters:
  • total_num_epochs – the total number of training epochs, inclusive of any warmup and cooldown epochs

  • num_update_steps_per_epoch – the number of optimizer updates that take place per epoch

  • k_decay – adjusts the rate of annealing. Higher values will maintain a higher lr for longer

  • lr_min – the minimum value that the learning rate should decay to for all parameter groups. This will be held fixed during cooldown epochs

  • min_lr_ratio – this can be used to represent the minimum lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over lr_min

  • num_warmup_epochs – the number of epochs to gradually increase the lr until it reaches the maximum value

  • warmup_starting_lr – the starting lr that will be used for all parameter groups at the beginning of training if num_warmup_epochs is greater than 0

  • warmup_starting_lr_ratio – this can be used to represent the warmup starting lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over warmup_starting_lr

  • num_cooldown_epochs – the number of epochs to hold the lr at its minimum value

Returns:

a function which accepts an optimizer as an argument and returns an instance of CosineLrScheduler
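
Outside of the Trainer, the returned function can also be called directly with an optimizer, although concrete values should then be passed in place of the placeholder defaults; a hedged sketch, assuming an existing optimizer and illustrative numbers:

scheduler_fn = CosineLrScheduler.create_scheduler_fn(
    total_num_epochs=100,
    num_update_steps_per_epoch=500,
    num_warmup_epochs=5,
)
scheduler = scheduler_fn(optimizer)  # returns a CosineLrScheduler instance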

get_updated_values(num_updates: int)[source]

Calculate the learning rate for a particular step given the number of previous updates.

If warmup epochs are specified, the learning rate will increase in constant increments from the warmup_starting_lr provided until the learning rate specified in the parameter group is reached.

If cooldown epochs are specified, the learning rate will be fixed at the minimum lr value given. This behaviour will continue if the scheduler is called after the training cycle has completed.

Between any warmup or cooldown epochs, the cosine annealing strategy will be used.

Parameters:

num_updates – the number of previous updates

Returns:

the learning rates with which to update each parameter group
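
Because this method only computes values from the number of previous updates, it can also be used to inspect the shape of a schedule without running any training; a brief sketch, assuming scheduler is the CosineLrScheduler configured with 100 epochs of 500 updates each from the earlier example:

# one set of per-parameter-group lrs for every optimizer update in the run
total_updates = 100 * 500  # total_num_epochs * num_update_steps_per_epoch
schedule = [scheduler.get_updated_values(step) for step in range(total_updates)]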

Base Schedulers

PyTorch-accelerated provides base classes for two types of schedulers.

Stateful Schedulers

Stateful schedulers maintain an internal count corresponding to how many times the scheduler’s step() method has been called. As these schedulers have the same interface as native PyTorch schedulers, they are supported by the Trainer by default.

class pytorch_accelerated.schedulers.scheduler_base.StatefulSchedulerBase(optimizer, param_group_field: str = 'lr')[source]

A stateful parameter scheduler base class that can be used to update any field within an optimizer’s parameter groups. The most common use case for this is learning rate scheduling.

Unlike PyTorch’s schedulers, which can be called at different points in the training loop depending on the implementation, this class is intended to be consistently called at the end of each optimizer update.

This class is responsible for maintaining the number of updates, incrementing an internal count each time that the scheduler step is calculated.

The usage of this class is illustrated below:

for epoch in range(num_epochs):
    for batch in train_dataloader:
        xb, yb = batch
        predictions = model(xb)
        loss = loss_func(predictions, yb)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # called once after every optimizer update
        scheduler.step()

__init__(optimizer, param_group_field: str = 'lr')[source]

Create a new instance of a stateful parameter scheduler.

Parameters:
  • optimizer – a PyTorch optimizer

  • param_group_field – the field in the optimizer’s parameter groups corresponding to the parameter to be scheduled

step()[source]

Calculate the updated value of the scheduled parameter and update the optimizer’s parameter groups.

Stateless Schedulers

These schedulers maintain no internal state about the current training run and therefore require that the current number of updates is explicitly provided when they are called. To use a stateless scheduler with the Trainer, you would need to subclass the Trainer and override its scheduler_step() method.
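
A minimal sketch of such a subclass is shown below. The scheduler_step() method name comes from the description above; the self.scheduler attribute and the way the update count is tracked are assumptions, so the Trainer source should be consulted for the exact hooks available.

from pytorch_accelerated import Trainer

class StatelessSchedulerTrainer(Trainer):
    # hypothetical subclass that tracks its own update count so that a
    # stateless scheduler can be stepped with an explicit num_updates value
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_updates = 0

    def scheduler_step(self):
        self._num_updates += 1
        if self.scheduler is not None:  # assumed attribute name
            self.scheduler.step_update(self._num_updates)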

class pytorch_accelerated.schedulers.scheduler_base.SchedulerBase(optimizer: Optimizer, param_group_field: str = 'lr')[source]

A parameter scheduler base class that can be used to update any field within an optimizer’s parameter groups. The most common use case for this is learning rate scheduling.

Unlike PyTorch’s schedulers, which can be called at different points in the training loop depending on the implementation, this class is intended to be consistently called at the end of each optimizer update.

As this class is stateless by default, it expects that the number of updates is explicitly provided, as illustrated below:

for current_epoch in range(num_epochs):
    num_updates = current_epoch * num_update_steps_per_epoch
    for batch in train_dataloader:
        xb, yb = batch
        predictions = model(xb)
        loss = loss_func(predictions, yb)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        num_updates += 1
        scheduler.step_update(num_updates)

__init__(optimizer: Optimizer, param_group_field: str = 'lr')[source]

Create a new instance of a parameter scheduler.

Parameters:
  • optimizer – a PyTorch optimizer

  • param_group_field – the field in the optimizer’s parameter groups corresponding to the parameter to be scheduled

abstract get_updated_values(num_updates: int) → None | Number | Iterable[Number][source]

Calculate updated values for the scheduled parameter.

If a single value is returned, all parameter groups will be updated with this value.

To update each parameter group with a different value, an iterable collection, containing an updated value for each parameter group, should be returned.

If None is returned, the parameter groups will not be updated.

Parameters:

num_updates – the number of optimizer updates

Returns:

the updated values of the scheduled parameter. This should be either a single value, or an iterable collection containing a value for each parameter group.
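
As an illustration of the single-value case, here is a hedged sketch of a stateless scheduler that decays the learning rate linearly and returns one value, which is then applied to every parameter group; the class name and decay rule are purely illustrative.

from pytorch_accelerated.schedulers.scheduler_base import SchedulerBase

class LinearDecayLrScheduler(SchedulerBase):
    # hypothetical example: decays the lr linearly towards zero over a fixed
    # number of updates, returning a single value for all parameter groups
    def __init__(self, optimizer, base_lr, total_num_updates):
        super().__init__(optimizer, param_group_field="lr")
        self.base_lr = base_lr
        self.total_num_updates = total_num_updates

    def get_updated_values(self, num_updates: int):
        remaining = max(0.0, 1 - num_updates / self.total_num_updates)
        return self.base_lr * remaining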

load_state_dict(state_dict)[source]

Update the scheduler’s attributes from the given state dict.

Parameters:

state_dict – the state dict to be loaded

state_dict()[source]

Get the state dict for the scheduler, containing all attributes except the optimizer, which should be saved separately.

Returns:

the scheduler’s state dict
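
Together with load_state_dict(), this makes it possible to checkpoint and restore a scheduler alongside the model and optimizer; a brief sketch using torch.save, with an illustrative file name:

import torch

# save the scheduler state
torch.save(scheduler.state_dict(), "scheduler_state.pt")

# later: restore into a freshly constructed scheduler with the same arguments
scheduler.load_state_dict(torch.load("scheduler_state.pt"))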

step_update(num_updates: int)[source]

Calculate the updated value of the scheduled parameter and update the optimizer’s parameter groups.

Parameters:

num_updates – the number of optimizer updates

Creating New Schedulers

Whilst schedulers are usually used to schedule learning rates, the scheduler base classes in PyTorch-accelerated can be used to schedule any parameter in an optimizer’s parameter group.

To create a new scheduler, in most cases, all that is required is to subclass one of the base classes and override the get_updated_values() method.

Example: Creating a simple milestone lr scheduler

Here is an example of how we can implement a scheduler to adjust the learning rate for each parameter group by a factor gamma each time an epoch milestone is reached:

from pytorch_accelerated.schedulers import StatefulSchedulerBase

class MilestoneLrScheduler(StatefulSchedulerBase):
    def __init__(
        self, optimizer, gamma=0.5, epoch_milestones=(2, 4, 5), num_steps_per_epoch=100
    ):
        super().__init__(optimizer, param_group_field="lr")
        # convert the epoch milestones into the corresponding optimizer update counts
        self.milestones = {
            num_steps_per_epoch * milestone for milestone in epoch_milestones
        }
        self.gamma = gamma

    def get_updated_values(self, num_updates: int):
        # scale the current lr of each parameter group by gamma at each milestone;
        # returning None on all other steps leaves the lrs unchanged
        if num_updates in self.milestones:
            lr_values = [
                group[self.param_group_field] for group in self.optimizer.param_groups
            ]
            updated_lrs = [lr * self.gamma for lr in lr_values]
            return updated_lrs
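
Since the Trainer only requires a callable that accepts an optimizer and returns a scheduler, a scheduler like this can presumably be passed to it via functools.partial; a sketch, reusing the argument names from the trainer.train example above:

from functools import partial

trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=partial(
        MilestoneLrScheduler,
        gamma=0.5,
        epoch_milestones=(2, 4, 5),
        num_steps_per_epoch=num_update_steps_per_epoch,  # must match the real dataloader length
    ),
)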

Example: Scheduling weight decay

Here is an example of how we can define a scheduler to incrementally increase the amount of weight decay by a factor gamma every n steps:

from pytorch_accelerated.schedulers import StatefulSchedulerBase

class StepWdScheduler(StatefulSchedulerBase):
    def __init__(self, optimizer, n=1000, gamma=1.1):
        super().__init__(optimizer, param_group_field="weight_decay")
        self.n = n
        self.gamma = gamma

    def get_updated_values(self, num_updates: int):
        if num_updates % self.n == 0 and num_updates > 0:
            wd_values = [
                group[self.param_group_field] for group in self.optimizer.param_groups
            ]
            updated_wd_values = [wd * self.gamma for wd in wd_values]
            return updated_wd_values
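
As with any stateful scheduler, this is stepped once after each optimizer update; a brief usage sketch with a placeholder model and optimizer:

from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
wd_scheduler = StepWdScheduler(optimizer, n=1000, gamma=1.1)

for _ in range(3000):
    # ... forward pass, loss.backward() and optimizer.step() would go here ...
    wd_scheduler.step()

# the weight_decay value in each parameter group has now been scaled by gamma
# a handful of times
print(optimizer.param_groups[0]["weight_decay"])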