Schedulers

PyTorch-accelerated provides some scheduler implementations which can be used in any PyTorch training loop. However, unlike PyTorch’s native schedulers, which can be called at different points in the training loop, all PyTorch-accelerated schedulers expect to be called after each optimizer update.

Implemented Schedulers

class pytorch_accelerated.schedulers.cosine_scheduler.CosineLrScheduler(optimizer: torch.optim.Optimizer, total_num_epochs: int, num_update_steps_per_epoch: int, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0)

Bases: StatefulSchedulerBase

A stateful cosine annealing learning rate scheduler, as described in SGDR: Stochastic Gradient Descent with Warm Restarts, but without restarts.

This scheduler differs from PyTorch’s CosineAnnealingLR in that it provides options to add warmup and cooldown epochs. Additionally, the annealing rate can be modified by adjusting the k_decay parameter, which changes the rate of change of the learning rate via its k-th order derivative, as described in k-decay: A New Method For Learning Rate Schedule.

If warmup epochs are specified, the learning rate will increase in constant increments from the warmup_starting_lr provided until the learning rate specified in the parameter group is reached.

If cooldown epochs are specified, the learning rate will be fixed at the minimum lr value given. This behaviour will continue if the scheduler is called after the training cycle has completed.
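The warmup, annealing and cooldown behaviour described above can be sketched as a pure function of the number of completed updates. This is a minimal sketch, not the library's implementation: it works in optimizer steps rather than epochs (steps = epochs × num_update_steps_per_epoch), the function name cosine_lr is hypothetical, and it assumes the k-decay form used by timm's CosineLRScheduler, where the cosine argument is π·t^k / T^k.

```python
import math


def cosine_lr(
    step: int,
    max_lr: float,
    total_steps: int,
    warmup_steps: int = 0,
    cooldown_steps: int = 0,
    lr_min: float = 1e-6,
    warmup_starting_lr: float = 1e-6,
    k_decay: float = 1.0,
) -> float:
    """Hypothetical sketch: the lr after `step` optimizer updates."""
    if step < warmup_steps:
        # warmup: constant increments from warmup_starting_lr towards max_lr
        increment = (max_lr - warmup_starting_lr) / warmup_steps
        return warmup_starting_lr + step * increment

    annealing_steps = total_steps - warmup_steps - cooldown_steps
    t = step - warmup_steps
    if t >= annealing_steps:
        # cooldown, and any calls after training has finished: hold at lr_min
        return lr_min

    # assumed k-decay form: larger k_decay keeps the lr high for longer
    progress = t**k_decay / annealing_steps**k_decay
    return lr_min + 0.5 * (max_lr - lr_min) * (1 + math.cos(math.pi * progress))
```

With k_decay=1.0 this reduces to plain cosine annealing between max_lr and lr_min over the steps left after warmup and cooldown.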

__init__(optimizer: torch.optim.Optimizer, total_num_epochs: int, num_update_steps_per_epoch: int, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0)

Create a new CosineLrScheduler object which can be used to modify the learning rate in an optimizer’s parameter groups.

Parameters:
  • optimizer – a PyTorch optimizer containing one or more parameter groups

  • total_num_epochs – the total number of training epochs, inclusive of any warmup and cooldown epochs

  • num_update_steps_per_epoch – the number of optimizer updates that take place per epoch

  • k_decay – adjusts the rate of annealing. Higher values will maintain a higher lr for longer

  • lr_min – the minimum value that the learning rate should decay to for all parameter groups. This will be held fixed during cooldown epochs

  • min_lr_ratio – this can be used to represent the minimum lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over lr_min

  • num_warmup_epochs – the number of epochs to gradually increase the lr until it reaches the maximum value

  • warmup_starting_lr – the starting lr that will be used for all parameter groups at the beginning of training if num_warmup_epochs is greater than 0

  • warmup_starting_lr_ratio – this can be used to represent the warmup starting lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over warmup_starting_lr

  • num_cooldown_epochs – the number of epochs to hold the lr at its minimum value
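The precedence rules for the ratio parameters can be summarised with a small helper. This is a hypothetical sketch, not part of pytorch_accelerated: resolve_group_bounds is an illustrative name, and each group's maximum lr is taken to be the lr originally set in that parameter group.

```python
from typing import List, Optional, Tuple


def resolve_group_bounds(
    group_max_lrs: List[float],
    lr_min: float = 1e-6,
    min_lr_ratio: Optional[float] = None,
    warmup_starting_lr: float = 1e-6,
    warmup_starting_lr_ratio: Optional[float] = None,
) -> Tuple[List[float], List[float]]:
    """Return (minimum lrs, warmup starting lrs), one entry per parameter group."""
    if min_lr_ratio is not None:
        # a ratio takes precedence over the shared absolute lr_min
        min_lrs = [max_lr * min_lr_ratio for max_lr in group_max_lrs]
    else:
        min_lrs = [lr_min] * len(group_max_lrs)

    if warmup_starting_lr_ratio is not None:
        # a ratio takes precedence over the shared absolute warmup_starting_lr
        warmup_lrs = [max_lr * warmup_starting_lr_ratio for max_lr in group_max_lrs]
    else:
        warmup_lrs = [warmup_starting_lr] * len(group_max_lrs)
    return min_lrs, warmup_lrs
```

Using ratios keeps groups with very different learning rates (for example, a backbone and a head) on proportionally similar schedules.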

classmethod create_scheduler_fn(total_num_epochs: int = TrainerPlaceholderValues.NUM_EPOCHS, num_update_steps_per_epoch: int = TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0) → Callable

An alternative constructor which returns a function that accepts an optimizer and creates an instance of CosineLrScheduler. This is primarily intended to be used with the Trainer as illustrated below:

from pytorch_accelerated.schedulers.cosine_scheduler import CosineLrScheduler

trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=CosineLrScheduler.create_scheduler_fn(num_warmup_epochs=5),
)

By default, the total_num_epochs and num_update_steps_per_epoch arguments will be set by the Trainer with the correct values at runtime.

Parameters:
  • total_num_epochs – the total number of training epochs, inclusive of any warmup and cooldown epochs

  • num_update_steps_per_epoch – the number of optimizer updates that take place per epoch

  • k_decay – adjusts the rate of annealing. Higher values will maintain a higher lr for longer

  • lr_min – the minimum value that the learning rate should decay to for all parameter groups. This will be held fixed during cooldown epochs

  • min_lr_ratio – this can be used to represent the minimum lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over lr_min

  • num_warmup_epochs – the number of epochs to gradually increase the lr until it reaches the maximum value

  • warmup_starting_lr – the starting lr that will be used for all parameter groups at the beginning of training if num_warmup_epochs is greater than 0

  • warmup_starting_lr_ratio – this can be used to represent the warmup starting lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over warmup_starting_lr

  • num_cooldown_epochs – the number of epochs to hold the lr at its minimum value

Returns:

a function which accepts an optimizer as an argument and returns an instance of CosineLrScheduler
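The deferred-construction pattern behind create_scheduler_fn can be illustrated without the Trainer. In this hypothetical sketch, Placeholder stands in for TrainerPlaceholderValues, ToyScheduler stands in for CosineLrScheduler, and resolve_placeholders mimics what a trainer might do once the run's values are known; none of these names are pytorch_accelerated APIs, and the library's actual substitution mechanism may differ.

```python
from enum import Enum
from functools import partial


class Placeholder(Enum):
    # hypothetical stand-in for TrainerPlaceholderValues
    NUM_EPOCHS = "num_epochs"
    NUM_UPDATE_STEPS_PER_EPOCH = "num_update_steps_per_epoch"


class ToyScheduler:
    # hypothetical stand-in for CosineLrScheduler
    def __init__(self, optimizer, total_num_epochs, num_update_steps_per_epoch, num_warmup_epochs=0):
        self.optimizer = optimizer
        self.total_num_epochs = total_num_epochs
        self.num_update_steps_per_epoch = num_update_steps_per_epoch
        self.num_warmup_epochs = num_warmup_epochs

    @classmethod
    def create_scheduler_fn(
        cls,
        total_num_epochs=Placeholder.NUM_EPOCHS,
        num_update_steps_per_epoch=Placeholder.NUM_UPDATE_STEPS_PER_EPOCH,
        num_warmup_epochs=0,
    ):
        # defer construction: the optimizer is supplied later
        return partial(
            cls,
            total_num_epochs=total_num_epochs,
            num_update_steps_per_epoch=num_update_steps_per_epoch,
            num_warmup_epochs=num_warmup_epochs,
        )


def resolve_placeholders(scheduler_fn, run_values):
    """Swap any placeholder arguments for the values known at runtime."""
    resolved = {
        name: run_values[value.value] if isinstance(value, Placeholder) else value
        for name, value in scheduler_fn.keywords.items()
    }
    return partial(scheduler_fn.func, **resolved)
```

Arguments passed explicitly to create_scheduler_fn are left untouched; only placeholders are replaced.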

get_updated_values(num_updates: int)

Calculate the learning rate for a particular step given the number of previous updates.

If warmup epochs are specified, the learning rate will increase in constant increments from the warmup_starting_lr provided until the learning rate specified in the parameter group is reached.

If cooldown epochs are specified, the learning rate will be fixed at the minimum lr value given. This behaviour will continue if the scheduler is called after the training cycle has completed.

Between any warmup or cooldown epochs, the cosine annealing strategy will be used.

Parameters:

num_updates – the number of previous updates

Returns:

the learning rates with which to update each parameter group

load_state_dict(state_dict: dict)

Updates the attributes of the given scheduler from the given state dict.

Parameters:

state_dict – the state dict to be loaded

state_dict()

Get the state dict for the scheduler, containing all attributes except the optimizer, which should be saved separately.

Returns:

the scheduler’s state dict

class pytorch_accelerated.schedulers.wsd_scheduler.WSDLrScheduler(optimizer: torch.optim.Optimizer, num_epochs: int = None, num_update_steps_per_epoch: int = None, total_steps: int = None, num_warmup_steps: int = 0, decay_phase_ratio: float = 0.1, lr_min: float = 1e-06, warmup_starting_lr: float = 1e-06, use_inverse_sqrt_decay: bool = True, num_checkpoints: int = 1, is_continuation_from_checkpoint: bool = False)

Bases: StatefulSchedulerBase

Implements the Warmup-Stable-Decay (WSD) Simplified learning rate schedule as described in Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective.

The schedule has three phases:
  1. Warmup: Linear warmup from warmup_starting_lr to base learning rate

  2. Stable: Maintains constant high learning rate

  3. Decay: Rapidly decays learning rate before each checkpoint

This scheduler is designed to create intermediate model checkpoints during training. Before each checkpoint, the learning rate is decayed so that the saved model performs better.

Use multiple checkpoints (typically 2-3) if:
  • Training on large datasets (>100B tokens) where intermediate models are useful for development/testing

  • You want to evaluate model performance vs training data size (e.g., does your model need full training?)

  • You might need to continue training later but want flexibility about when to stop training

The scheduler uses geometric progression to space checkpoints evenly on a log scale:
  • First checkpoint is placed at 25% of total steps

  • Each subsequent checkpoint is ~2x steps from previous checkpoint

Examples:
  • 2 checkpoints for 100K steps: [50K, 100K]

  • 3 checkpoints for 200K steps: [50K, 100K, 200K]

  • 4 checkpoints for 200K steps: [25K, 50K, 100K, 200K]
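The example lists above follow a halving rule: the last checkpoint is at the total number of steps and each earlier checkpoint is half of the one after it (for three checkpoints, this puts the first at 25% of total steps). The sketch below reproduces those examples under that assumption; checkpoint_steps is a hypothetical name, and its rounding may differ from what WSDLrScheduler.get_checkpoint_steps() returns.

```python
from typing import List


def checkpoint_steps(total_steps: int, num_checkpoints: int) -> List[int]:
    """Hypothetical sketch: checkpoints spaced by a factor of 2, ending at total_steps."""
    if num_checkpoints < 1:
        raise ValueError("num_checkpoints must be at least 1")
    # the i-th checkpoint sits at total_steps / 2 ** (number of checkpoints after it)
    return [
        total_steps // 2 ** (num_checkpoints - 1 - i) for i in range(num_checkpoints)
    ]


print(checkpoint_steps(100_000, 2))  # [50000, 100000]
print(checkpoint_steps(200_000, 3))  # [50000, 100000, 200000]
print(checkpoint_steps(200_000, 4))  # [25000, 50000, 100000, 200000]
```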

For each checkpoint:
  • The stable phase continues until only the final decay_phase_ratio fraction of the steps remains

  • The learning rate then decays to lr_min * base_lr using the selected decay formula

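Two decay formulas are provided. In both, t is the fraction of the decay phase completed (from 0 to 1) and lr is the multiplier applied to the base learning rate, falling from 1 to lr_min: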

  1. Inverse Proportional Decay (paper’s formula):
    lr = 1 / (t * (1/lr_min - 1) + 1)
    • Derived from theoretical analysis on quadratic functions

    • Steeper initial decay, more gradual approach to lr_min

    • Optimal for quadratic loss landscapes

  2. Sqrt Decay:
    lr = lr_min + (1 - lr_min) * (1 - sqrt(t))
    • Similar to traditional cosine decay patterns

    • More gradual initial decay, consistent decay rate

    • May be more robust across different architectures
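The two formulas can be compared directly, since both return a multiplier on the base learning rate. The function names below are hypothetical; the sketch only evaluates the formulas as written above.

```python
import math


def inverse_proportional_decay(t: float, lr_min: float) -> float:
    # multiplier falls from 1 at t=0 to lr_min at t=1
    return 1.0 / (t * (1.0 / lr_min - 1.0) + 1.0)


def sqrt_decay(t: float, lr_min: float) -> float:
    # multiplier falls from 1 at t=0 to lr_min at t=1
    return lr_min + (1.0 - lr_min) * (1.0 - math.sqrt(t))


base_lr = 3e-4
lr_min = 0.1
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    inverse_lr = base_lr * inverse_proportional_decay(t, lr_min)
    sqrt_lr = base_lr * sqrt_decay(t, lr_min)
    print(f"t={t:.2f}  inverse={inverse_lr:.2e}  sqrt={sqrt_lr:.2e}")
```

With lr_min = 0.1, halfway through the decay the inverse proportional multiplier is 1 / 5.5 ≈ 0.18, whereas the sqrt multiplier is about 0.36, matching the description of the inverse formula as the steeper of the two early on.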

Continuation Behavior:
  • Training can be continued from a pre-decay (WSD) or post-decay (WSD-S) checkpoint

  • When continuing, scheduler starts a fresh stable phase with new total_steps

  • Decay phase ratio applies to new training length

  • No warmup is applied during continuation

  • State must be loaded via load_state_dict for continuation to work

Example:
Initial run (1000 steps, 0.1 decay ratio):
  • Steps 0-50: Optional warmup

  • Steps 50-900: Stable high learning rate

  • Steps 900-1000: Decay to lr_min

Continuation (500 new steps, 0.1 decay ratio):
  • Steps 0-450: Stable high learning rate

  • Steps 450-500: Decay to lr_min
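The phase boundaries in this example can be reproduced with a short function for a single checkpoint period. This is a hypothetical sketch, not WSDLrScheduler itself: wsd_lr is an illustrative name, it uses the sqrt decay formula, and it assumes decay_phase_ratio is applied to the length of the single training run, as in the example.

```python
import math


def wsd_lr(
    step: int,
    base_lr: float,
    total_steps: int,
    decay_phase_ratio: float = 0.1,
    num_warmup_steps: int = 0,
    lr_min: float = 1e-6,
    warmup_starting_lr: float = 1e-6,
) -> float:
    """Hypothetical single-period WSD schedule using the sqrt decay formula."""
    decay_steps = round(total_steps * decay_phase_ratio)
    pre_decay_step = total_steps - decay_steps

    if step < num_warmup_steps:
        # warmup: linear ramp from warmup_starting_lr to base_lr
        return warmup_starting_lr + (base_lr - warmup_starting_lr) * step / num_warmup_steps
    if step < pre_decay_step:
        # stable: hold the base learning rate
        return base_lr
    # decay: t goes from 0 to 1 over the last decay_steps steps
    t = min((step - pre_decay_step) / max(decay_steps, 1), 1.0)
    return base_lr * (lr_min + (1.0 - lr_min) * (1.0 - math.sqrt(t)))
```

A continuation segment corresponds to calling the same function with the new total_steps and num_warmup_steps=0, since no warmup is applied when continuing.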

Note

This scheduler is designed to be used with the WSDCheckpointCallback class, which handles saving and loading checkpoints.

__init__(optimizer: torch.optim.Optimizer, num_epochs: int = None, num_update_steps_per_epoch: int = None, total_steps: int = None, num_warmup_steps: int = 0, decay_phase_ratio: float = 0.1, lr_min: float = 1e-06, warmup_starting_lr: float = 1e-06, use_inverse_sqrt_decay: bool = True, num_checkpoints: int = 1, is_continuation_from_checkpoint: bool = False)

Create a new WSDLrScheduler object which can be used to modify the learning rate in an optimizer’s parameter groups.

Parameters:
  • optimizer – PyTorch optimizer

  • num_epochs – Total number of training epochs

  • num_update_steps_per_epoch – The number of update steps per epoch per process

  • num_warmup_steps – Number of warmup steps. If None is passed, this will be set to 10% of the total steps

  • total_steps – Total number of training steps per process

  • decay_phase_ratio – Fraction of steps to use for decay before each checkpoint

  • lr_min – The minimum learning rate as a fraction of the base learning rate. For example, 0.1 means decay to 10% of the base learning rate.

  • warmup_starting_lr – Starting learning rate for warmup

  • use_inverse_sqrt_decay – Whether to use a more gradual sqrt decay

  • num_checkpoints – Number of checkpoints to use

  • is_continuation_from_checkpoint – If True, indicates this is a continuation run from a previous checkpoint. The scheduler will start a fresh stable phase with new total_steps.

Note

For continuation of training:
  • State must be loaded via load_state_dict before training

  • New training segment starts fresh stable phase

  • Decay ratio applies to new total_steps

  • No warmup is applied

classmethod create_scheduler_fn(total_num_epochs: int = TrainerPlaceholderValues.NUM_EPOCHS, num_update_steps_per_epoch: int = TrainerPlaceholderValues.PER_PROCESS_NUM_UPDATE_STEPS_PER_EPOCH, num_warmup_epochs: int = None, decay_phase_ratio: float = 0.1, lr_min: float = 1e-06, warmup_starting_lr: float = 1e-06, use_inverse_sqrt_decay: bool = True, num_checkpoints: int = 1, is_continuation_from_checkpoint: bool = False) → Callable

An alternative constructor which returns a function that accepts an optimizer and creates an instance of WSDLrScheduler. This is primarily intended to be used with the Trainer as illustrated below:

from pytorch_accelerated import Trainer
from pytorch_accelerated.schedulers.wsd_scheduler import WSDLrScheduler

trainer = Trainer(
    ...,
    callbacks=[
        WSDCheckpointCallback(
            save_dir="checkpoints",
            initial_checkpoint="checkpoint_45000_pre_decay.pt",
        )
    ],
)

trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=WSDLrScheduler.create_scheduler_fn(is_continuation_from_checkpoint=True),
)

By default, the total_num_epochs and num_update_steps_per_epoch arguments will be set by the Trainer with the correct values at runtime. If the number of warmup epochs is not set, the warmup will be set to 10% of the total steps.

get_checkpoint_steps() → List[int]

Return the list of steps at which checkpoints occur. Useful for training loop coordination.

get_current_phase_info() → dict

Get phase information for the current step.

Returns:

dict: Phase information for the current position, containing period_start, period_end, decay_steps, and pre_decay_step

get_current_step() → int

Get the current step count of the scheduler.

Returns:

int: The current number of optimizer updates completed

get_decay_info() → List[dict]

Get information about decay phases for all checkpoint periods.

Returns:
List[dict]: List of dicts containing for each period:
  • period_start: Start of period

  • period_end: End of period (checkpoint)

  • decay_steps: Number of steps in decay phase

  • pre_decay_step: Step before decay phase starts
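The per-period fields listed above can be illustrated from a list of checkpoint steps. This is a hypothetical sketch: decay_info is an illustrative name, and it assumes that the first period starts at step 0, that decay_phase_ratio is applied to each period's own length, and that pre_decay_step is the step at which decay begins; the library may compute these values differently.

```python
from typing import Dict, List


def decay_info(checkpoints: List[int], decay_phase_ratio: float = 0.1) -> List[Dict[str, int]]:
    """Hypothetical sketch of per-period decay information."""
    info = []
    period_start = 0
    for period_end in checkpoints:
        # assumption: the decay phase is a fraction of this period's length
        decay_steps = round((period_end - period_start) * decay_phase_ratio)
        info.append(
            {
                "period_start": period_start,
                "period_end": period_end,
                "decay_steps": decay_steps,
                "pre_decay_step": period_end - decay_steps,
            }
        )
        period_start = period_end
    return info
```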

get_phase_info(num_updates: int) → dict

Return full information about the current phase given the training step.

get_updated_values(num_updates: int)

Calculate updated values for the scheduled parameter.

If a single value is returned, all parameter groups will be updated with this value.

To update each parameter group with a different value, an iterable collection, containing an updated value for each parameter group, should be returned.

If None is returned, the parameter groups will not be updated.

Parameters:

num_updates – the number of optimizer updates

Returns:

the updated values of the scheduled parameter. This should be either a single value, or an iterable collection containing a value for each parameter group.

load_state_dict(state_dict: dict)

Updates the attributes of the given scheduler from the given state dict.

Parameters:

state_dict – the state dict to be loaded

state_dict()

Get the state dict for the scheduler, containing all attributes except the optimizer, which should be saved separately.

Returns:

the scheduler’s state dict

Base Schedulers

PyTorch-accelerated provides base classes for two types of schedulers.

Stateful Schedulers

Stateful schedulers maintain an internal count corresponding to how many times the scheduler’s step() method has been called. As these schedulers have the same interface as the native PyTorch schedulers, they are supported by the Trainer by default.

class pytorch_accelerated.schedulers.scheduler_base.StatefulSchedulerBase(optimizer, param_group_field: str = 'lr')

A stateful parameter scheduler base class that can be used to update any field within an optimizer’s parameter groups. The most common use case for this is learning rate scheduling.

Unlike PyTorch’s schedulers, which can be called at different points in the training loop depending on the implementation, this class is intended to be consistently called at the end of each optimizer update.

This class is responsible for maintaining the number of updates, incrementing an internal count each time that the scheduler step is calculated.

The usage of this class is illustrated below:

for epoch in range(num_epochs):
    for batch in train_dataloader:
        xb, yb = batch
        predictions = model(xb)
        loss = loss_func(predictions, yb)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # step the scheduler once after every optimizer update
        scheduler.step()

__init__(optimizer, param_group_field: str = 'lr')

Create a new instance of a stateful parameter scheduler.

Parameters:
  • optimizer – a PyTorch optimizer

  • param_group_field – the field in the optimizer’s parameter groups corresponding to the parameter to be scheduled

abstractmethod load_state_dict(scheduler_state_dict: dict)

Updates the attributes of the given scheduler from the given state dict.

Parameters:

scheduler_state_dict – the state dict to be loaded

abstractmethod state_dict()

Get the state dict for the scheduler, containing all attributes except the optimizer, which should be saved separately.

Returns:

the scheduler’s state dict

step()

Calculate the updated value of the scheduled parameter and update the optimizer’s parameter groups.

Stateless Schedulers

These schedulers maintain no internal state about the current training run, and therefore require the current number of updates to be explicitly provided when called. Using a stateless scheduler with the Trainer requires subclassing the Trainer and overriding its scheduler_step() method.

class pytorch_accelerated.schedulers.scheduler_base.SchedulerBase(optimizer: torch.optim.Optimizer, param_group_field: str = 'lr')

A parameter scheduler base class that can be used to update any field within an optimizer’s parameter groups. The most common use case for this is learning rate scheduling.

Unlike PyTorch’s schedulers, which can be called at different points in the training loop depending on the implementation, this class is intended to be consistently called at the end of each optimizer update.

As this class is stateless by default, it expects that the number of updates is explicitly provided, as illustrated below:

for epoch in range(num_epochs):
    # the number of updates completed before this epoch started
    num_updates = epoch * num_update_steps_per_epoch
    for batch in train_dataloader:
        xb, yb = batch
        predictions = model(xb)
        loss = loss_func(predictions, yb)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        num_updates += 1
        scheduler.step_update(num_updates)

__init__(optimizer: torch.optim.Optimizer, param_group_field: str = 'lr')

Create a new instance of a parameter scheduler.

Parameters:
  • optimizer – a PyTorch optimizer

  • param_group_field – the field in the optimizer’s parameter groups corresponding to the parameter to be scheduled

abstractmethod get_updated_values(num_updates: int) → None | Number | Iterable[Number]

Calculate updated values for the scheduled parameter.

If a single value is returned, all parameter groups will be updated with this value.

To update each parameter group with a different value, an iterable collection, containing an updated value for each parameter group, should be returned.

If None is returned, the parameter groups will not be updated.

Parameters:

num_updates – the number of optimizer updates

Returns:

the updated values of the scheduled parameter. This should be either a single value, or an iterable collection containing a value for each parameter group.
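The three kinds of return value described above could be applied to the parameter groups as follows. This is a hypothetical sketch of the documented semantics, not SchedulerBase's own code; apply_updated_values is an illustrative name, and plain dicts stand in for an optimizer's param_groups.

```python
from numbers import Number
from typing import Iterable, List, Optional, Union


def apply_updated_values(
    param_groups: List[dict],
    values: Optional[Union[Number, Iterable[Number]]],
    param_group_field: str = "lr",
) -> None:
    if values is None:
        return  # None: leave every group untouched
    if isinstance(values, Number):
        values = [values] * len(param_groups)  # one value for all groups
    else:
        values = list(values)  # one value per group
        if len(values) != len(param_groups):
            raise ValueError("expected one value per parameter group")
    for group, value in zip(param_groups, values):
        group[param_group_field] = value
```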

load_state_dict(state_dict)

Updates the attributes of the given scheduler from the given state dict.

Parameters:

state_dict – the state dict to be loaded

state_dict()

Get the state dict for the scheduler, containing all attributes except the optimizer, which should be saved separately.

Returns:

the scheduler’s state dict

step_update(num_updates: int)

Calculate the updated value of the scheduled parameter and update the optimizer’s parameter groups.

Parameters:

num_updates – the number of optimizer updates

Creating New Schedulers

Whilst schedulers are usually used to schedule learning rates, the scheduler base classes in PyTorch-accelerated can be used to schedule any parameter in an optimizer’s parameter group.

To create a new scheduler, in most cases, all that is required is to subclass one of the base classes and override the get_updated_values() method.

Example: Creating a simple milestone lr scheduler

Here is an example of how we can implement a scheduler to adjust the learning rate for each parameter group by a factor gamma each time an epoch milestone is reached:

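from pytorch_accelerated.schedulers import StatefulSchedulerBase

class MilestoneLrScheduler(StatefulSchedulerBase):
    def __init__(
        self, optimizer, gamma=0.5, epoch_milestones=(2, 4, 5), num_steps_per_epoch=100
    ):
        super().__init__(optimizer, param_group_field="lr")
        # convert epoch milestones into optimizer-update milestones
        self.milestones = {
            num_steps_per_epoch * milestone for milestone in epoch_milestones
        }
        self.gamma = gamma

    def get_updated_values(self, num_updates: int):
        if num_updates in self.milestones:
            lr_values = [
                group[self.param_group_field] for group in self.optimizer.param_groups
            ]
            updated_lrs = [lr * self.gamma for lr in lr_values]
            return updated_lrs
        # implicitly returning None leaves the parameter groups unchanged

    # state_dict and load_state_dict are declared abstract in StatefulSchedulerBase
    def state_dict(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)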

Example: Scheduling weight decay

Here is an example of how we can define a scheduler to incrementally increase the amount of weight decay by a factor gamma every n steps:

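from pytorch_accelerated.schedulers import StatefulSchedulerBase

class StepWdScheduler(StatefulSchedulerBase):
    def __init__(self, optimizer, n=1000, gamma=1.1):
        super().__init__(optimizer, param_group_field="weight_decay")
        self.n = n
        self.gamma = gamma

    def get_updated_values(self, num_updates: int):
        # every n updates, scale each group's weight decay by gamma
        if num_updates % self.n == 0 and num_updates > 0:
            wd_values = [
                group[self.param_group_field] for group in self.optimizer.param_groups
            ]
            updated_wd_values = [wd * self.gamma for wd in wd_values]
            return updated_wd_values

    # state_dict and load_state_dict are declared abstract in StatefulSchedulerBase
    def state_dict(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)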