Schedulers
PyTorch-accelerated provides some scheduler implementations which can be used in any PyTorch training loop. However, unlike PyTorch’s native schedulers, which can be called at different points in the training loop, all PyTorch-accelerated schedulers expect to be called after each optimizer update.
Implemented Schedulers
- class pytorch_accelerated.schedulers.cosine_scheduler.CosineLrScheduler(optimizer: torch.optim.Optimizer, total_num_epochs: int, num_update_steps_per_epoch: int, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0)[source]
Bases: StatefulSchedulerBase
A stateful cosine annealing learning rate scheduler, as described in this paper, but without restarts.
This scheduler differs from PyTorch’s CosineAnnealingLR as it provides options to add warmup and cooldown epochs. Additionally, the annealing rate can be modified by adjusting the k-decay parameter, for which the rate of change of the learning rate is changed by its k-th order derivative, as described here.
If warmup epochs are specified, the learning rate will increase in constant increments from the warmup_starting_lr provided until the learning rate specified in the parameter group is reached.
If cooldown epochs are specified, the learning rate will be fixed at the minimum lr value given. This behaviour will continue if the scheduler is called after the training cycle has completed.
- __init__(optimizer: torch.optim.Optimizer, total_num_epochs: int, num_update_steps_per_epoch: int, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0)[source]
Create a new CosineLrScheduler object which can be used to modify the learning rate in an optimizer’s parameter groups.
- Parameters:
optimizer – a PyTorch optimizer containing one or more parameter groups
total_num_epochs – the total number of training epochs, inclusive of any warmup and cooldown epochs
num_update_steps_per_epoch – the number of optimizer updates that take place per epoch
k_decay – adjusts the rate of annealing. Higher values will maintain a higher lr for longer
lr_min – the minimum value that the learning rate should decay to for all parameter groups. This will be held fixed during cooldown epochs
min_lr_ratio – this can be used to represent the minimum lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over lr_min
num_warmup_epochs – the number of epochs to gradually increase the lr until it reaches the maximum value
warmup_starting_lr – the starting lr that will be used for all parameter groups at the beginning of training if num_warmup_epochs is greater than 0
warmup_starting_lr_ratio – this can be used to represent the warmup starting lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over warmup_starting_lr
num_cooldown_epochs – the number of epochs to hold the lr at its minimum value
- classmethod create_scheduler_fn(total_num_epochs: int = TrainerPlaceholderValues.NUM_EPOCHS, num_update_steps_per_epoch: int = TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH, k_decay=1.0, lr_min: float = 1e-06, min_lr_ratio=None, num_warmup_epochs: int = 0, warmup_starting_lr=1e-06, warmup_starting_lr_ratio=None, num_cooldown_epochs=0) -> Callable[source]
An alternative constructor which returns a function that accepts an optimizer and creates an instance of CosineLrScheduler. This is primarily intended to be used with the Trainer as illustrated below:
trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=CosineLrScheduler.create_scheduler_fn(num_warmup_epochs=5),
)
By default, the total_num_epochs and num_update_steps_per_epoch arguments will be set by the Trainer with the correct values at runtime.
- Parameters:
total_num_epochs – the total number of training epochs, inclusive of any warmup and cooldown epochs
num_update_steps_per_epoch – the number of optimizer updates that take place per epoch
k_decay – adjusts the rate of annealing. Higher values will maintain a higher lr for longer
lr_min – the minimum value that the learning rate should decay to for all parameter groups. This will be held fixed during cooldown epochs
min_lr_ratio – this can be used to represent the minimum lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over lr_min
num_warmup_epochs – the number of epochs to gradually increase the lr until it reaches the maximum value
warmup_starting_lr – the starting lr that will be used for all parameter groups at the beginning of training if num_warmup_epochs is greater than 0
warmup_starting_lr_ratio – this can be used to represent the warmup starting lr for each parameter group as a ratio of its maximum lr. If set, this will take precedence over warmup_starting_lr
num_cooldown_epochs – the number of epochs to hold the lr at its minimum value
- Returns:
a function which accepts an optimizer as an argument and returns an instance of CosineLrScheduler
- get_updated_values(num_updates: int)[source]
Calculate the learning rate for a particular step given the number of previous updates.
If warmup epochs are specified, the learning rate will increase in constant increments from the warmup_starting_lr provided until the learning rate specified in the parameter group is reached.
If cooldown epochs are specified, the learning rate will be fixed at the minimum lr value given. This behaviour will continue if the scheduler is called after the training cycle has completed.
Between any warmup or cooldown epochs, the cosine annealing strategy will be used.
- Parameters:
num_updates – the number of previous updates
- Returns:
the learning rates with which to update each parameter group
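The three regimes described above (linear warmup, cosine annealing with k-decay, and a fixed minimum during cooldown) can be sketched in plain Python. This is a minimal illustration for a single parameter group, not the library's implementation: the function name `cosine_lr_sketch`, its step-based arguments, and the exact k-decay form `cos(pi * t**k / T**k)` are assumptions made for this example.

```python
import math


def cosine_lr_sketch(
    num_updates,
    base_lr,
    total_steps,
    warmup_steps=0,
    cooldown_steps=0,
    lr_min=1e-6,
    warmup_starting_lr=1e-6,
    k_decay=1.0,
):
    """Hypothetical helper: learning rate after `num_updates` optimizer updates."""
    # Warmup: constant increments from warmup_starting_lr up to base_lr.
    if num_updates < warmup_steps:
        increment = (base_lr - warmup_starting_lr) / warmup_steps
        return warmup_starting_lr + num_updates * increment

    annealing_steps = total_steps - warmup_steps - cooldown_steps
    t = num_updates - warmup_steps

    # Cooldown (and any call after training has finished): hold at lr_min.
    if t >= annealing_steps:
        return lr_min

    # Cosine annealing; k_decay > 1 keeps the learning rate higher for longer.
    cosine = math.cos(math.pi * t**k_decay / annealing_steps**k_decay)
    return lr_min + 0.5 * (base_lr - lr_min) * (1 + cosine)


if __name__ == "__main__":
    for step in (0, 5, 10, 50, 90, 99, 150):
        lr = cosine_lr_sketch(
            step, base_lr=0.1, total_steps=100, warmup_steps=10, cooldown_steps=10
        )
        print(step, lr)
```

Working in optimizer updates rather than epochs keeps the sketch short; the epoch-based arguments of the real class would be multiplied by num_update_steps_per_epoch to obtain the same quantities.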
- class pytorch_accelerated.schedulers.wsd_scheduler.WSDLrScheduler(optimizer: torch.optim.Optimizer, num_epochs: int = None, num_update_steps_per_epoch: int = None, total_steps: int = None, num_warmup_steps: int = 0, decay_phase_ratio: float = 0.1, lr_min: float = 1e-06, warmup_starting_lr: float = 1e-06, use_inverse_sqrt_decay: bool = True, num_checkpoints: int = 1, is_continuation_from_checkpoint: bool = False)[source]
Bases: StatefulSchedulerBase
Implements the Warmup-Stable-Decay (WSD) Simplified learning rate schedule as described in Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective.
- The schedule has three phases:
Warmup: Linear warmup from warmup_starting_lr to base learning rate
Stable: Maintains constant high learning rate
Decay: Rapidly decays learning rate before each checkpoint
This scheduler is designed to create intermediate model checkpoints during training. Each checkpoint involves decaying the learning rate to get better model performance.
- Use multiple checkpoints (typically 2-3) if:
Training on large datasets (>100B tokens) where intermediate models are useful for development/testing
You want to evaluate model performance vs training data size (e.g., does your model need full training?)
You might need to continue training later but want flexibility about when to stop training
- The scheduler uses geometric progression to space checkpoints evenly on a log scale:
First checkpoint is placed at 25% of total steps
Each subsequent checkpoint is ~2x steps from previous checkpoint
- Examples:
2 checkpoints for 100K steps: [50K, 100K]
3 checkpoints for 200K steps: [50K, 100K, 200K]
4 checkpoints for 200K steps: [25K, 50K, 100K, 200K]
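One rule that reproduces all three examples above is to place the last checkpoint at the final step and halve the step count for each earlier checkpoint. The sketch below shows only that rule; the function name `checkpoint_steps_sketch` is hypothetical and the library may compute its checkpoints differently, so get_checkpoint_steps() remains the source of truth for a real scheduler.

```python
def checkpoint_steps_sketch(total_steps, num_checkpoints):
    """Hypothetical helper: checkpoints spaced by a factor of 2, ending at total_steps."""
    # Work backwards from the final step, halving each time.
    steps = [total_steps // (2**i) for i in range(num_checkpoints)]
    return sorted(steps)


if __name__ == "__main__":
    print(checkpoint_steps_sketch(100_000, 2))
    print(checkpoint_steps_sketch(200_000, 3))
    print(checkpoint_steps_sketch(200_000, 4))
```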
- For each checkpoint:
The stable phase continues until decay_phase_ratio portion of steps remain
Then learning rate decays to lr_min * base_lr using selected decay formula
Two decay formulas are provided:
- Inverse Proportional Decay (paper’s formula):
- lr = 1 / (t * (1/lr_min - 1) + 1)
Derived from theoretical analysis on quadratic functions
Steeper initial decay, more gradual approach to lr_min
Optimal for quadratic loss landscapes
- Sqrt Decay:
- lr = lr_min + (1 - lr_min) * (1 - sqrt(t))
Similar to traditional cosine decay patterns
More gradual initial decay, consistent decay rate
May be more robust across different architectures
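The two formulas above can be compared directly with a short sketch. It assumes that t is the fraction of the decay phase completed (0 at the start of decay, 1 at the checkpoint) and that the result is a multiplier applied to the base learning rate; both assumptions are inferred from the formulas rather than stated in the text, and the function names are hypothetical.

```python
import math


def inverse_proportional_decay(t, lr_min):
    """Multiplier from the inverse proportional formula; t is decay progress in [0, 1]."""
    return 1.0 / (t * (1.0 / lr_min - 1.0) + 1.0)


def sqrt_decay(t, lr_min):
    """Multiplier from the sqrt formula; t is decay progress in [0, 1]."""
    return lr_min + (1.0 - lr_min) * (1.0 - math.sqrt(t))


if __name__ == "__main__":
    base_lr = 3e-4  # illustrative value only
    for t in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(
            t,
            base_lr * inverse_proportional_decay(t, lr_min=0.1),
            base_lr * sqrt_decay(t, lr_min=0.1),
        )
```

Both multipliers start at 1 when t = 0 and reach lr_min when t = 1, so either formula ends the decay phase at lr_min * base_lr.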
- Continuation Behavior:
Training can be continued from a pre-decay (WSD) or post-decay (WSD-S) checkpoint
When continuing, scheduler starts a fresh stable phase with new total_steps
Decay phase ratio applies to new training length
No warmup is applied during continuation
State must be loaded via load_state_dict for continuation to work
- Example:
- Initial run (1000 steps, 0.1 decay ratio):
Steps 0-50: Optional warmup
Steps 50-900: Stable high learning rate
Steps 900-1000: Decay to lr_min
- Continuation (500 new steps, 0.1 decay ratio):
Steps 0-450: Stable high learning rate
Steps 450-500: Decay to lr_min
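The step ranges in this example follow from the decay ratio alone, as the sketch below illustrates for a single checkpoint period. The helper name `phase_boundaries_sketch` and its return format are assumptions made for illustration; for a real scheduler, get_phase_info() and get_decay_info() report the actual boundaries.

```python
def phase_boundaries_sketch(total_steps, decay_phase_ratio, num_warmup_steps=0):
    """Hypothetical helper: (start, end) step ranges for each phase of one period."""
    # The decay phase occupies the final decay_phase_ratio share of the steps.
    decay_steps = round(total_steps * decay_phase_ratio)
    decay_start = total_steps - decay_steps
    return {
        "warmup": (0, num_warmup_steps),
        "stable": (num_warmup_steps, decay_start),
        "decay": (decay_start, total_steps),
    }


if __name__ == "__main__":
    # Initial run: 1000 steps, 50 warmup steps, 0.1 decay ratio.
    print(phase_boundaries_sketch(1000, 0.1, num_warmup_steps=50))
    # Continuation: 500 new steps, no warmup.
    print(phase_boundaries_sketch(500, 0.1))
```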
Note
This scheduler is designed to be used with the WSDCheckpointCallback class, which handles saving and loading checkpoints.
- __init__(optimizer: torch.optim.Optimizer, num_epochs: int = None, num_update_steps_per_epoch: int = None, total_steps: int = None, num_warmup_steps: int = 0, decay_phase_ratio: float = 0.1, lr_min: float = 1e-06, warmup_starting_lr: float = 1e-06, use_inverse_sqrt_decay: bool = True, num_checkpoints: int = 1, is_continuation_from_checkpoint: bool = False)[source]
Create a new WSDLrScheduler object which can be used to modify the learning rate in an optimizer’s parameter groups.
- Parameters:
optimizer – PyTorch optimizer
num_epochs – Total number of training epochs
num_update_steps_per_epoch – The number of update steps per epoch per process
num_warmup_steps – Number of warmup steps. If None is passed, this will be set to 10% of the total steps
total_steps – Total number of training steps per process
decay_phase_ratio – Fraction of steps to use for decay before each checkpoint
lr_min – The minimum learning rate as a fraction of the base learning rate. For example, 0.1 means decay to 10% of the base learning rate.
warmup_starting_lr – Starting learning rate for warmup
use_inverse_sqrt_decay – Whether to use a more gradual sqrt decay
num_checkpoints – Number of checkpoints to use
is_continuation_from_checkpoint – If True, indicates this is a continuation run from a previous checkpoint. The scheduler will start a fresh stable phase with new total_steps.
Note
- For continuation of training:
State must be loaded via load_state_dict before training
New training segment starts fresh stable phase
Decay ratio applies to new total_steps
No warmup is applied
- classmethod create_scheduler_fn(total_num_epochs: int = TrainerPlaceholderValues.NUM_EPOCHS, num_update_steps_per_epoch: int = TrainerPlaceholderValues.PER_PROCESS_NUM_UPDATE_STEPS_PER_EPOCH, num_warmup_epochs: int = None, decay_phase_ratio: float = 0.1, lr_min: float = 1e-06, warmup_starting_lr: float = 1e-06, use_inverse_sqrt_decay: bool = True, num_checkpoints: int = 1, is_continuation_from_checkpoint: bool = False) -> Callable[source]
An alternative constructor which returns a function that accepts an optimizer and creates an instance of WSDLrScheduler. This is primarily intended to be used with the Trainer as illustrated below:
trainer = Trainer(
    ...,
    callbacks=[
        WSDCheckpointCallback(
            save_dir="checkpoints",
            initial_checkpoint="checkpoint_45000_pre_decay.pt",
        )
    ],
)
trainer.train(
    train_dataset=train_dataset,
    num_epochs=num_epochs,
    per_device_batch_size=batch_size,
    create_scheduler_fn=WSDLrScheduler.create_scheduler_fn(is_continuation_from_checkpoint=True),
)
By default, the total_num_epochs and num_update_steps_per_epoch arguments will be set by the Trainer with the correct values at runtime. If the number of warmup epochs is not set, this will be set to 10% of the total steps.
- get_checkpoint_steps() -> List[int][source]
Return the list of steps at which checkpoints occur. Useful for training loop coordination.
- get_current_phase_info() -> dict[source]
Get phase information for the current step.
- Returns:
- dict: Phase information containing period_start, period_end,
decay_steps, and pre_decay_step for current position
- get_current_step() -> int[source]
Get the current step count of the scheduler.
- Returns:
int: The current number of optimizer updates completed
- get_decay_info() -> List[dict][source]
Get information about decay phases for all checkpoint periods.
- Returns:
- List[dict]: List of dicts containing for each period:
period_start: Start of period
period_end: End of period (checkpoint)
decay_steps: Number of steps in decay phase
pre_decay_step: Step before decay phase starts
- get_phase_info(num_updates: int) -> dict[source]
Return full information about the current phase given the training step.
- get_updated_values(num_updates: int)[source]
Calculate updated values for the scheduled parameter.
If a single value is returned, all parameter groups will be updated with this value.
To update each parameter group with a different value, an iterable collection, containing an updated value for each parameter group, should be returned.
If None is returned, the parameter groups will not be updated.
- Parameters:
num_updates – the number of optimizer updates
- Returns:
the updated values of the scheduled parameter. This should be either a single value, or an iterable collection containing a value for each parameter group.
Base Schedulers
PyTorch-accelerated provides base classes for two types of schedulers.
Stateful Schedulers
Stateful schedulers maintain an internal count corresponding to how many times the scheduler’s step() method has been called.
As these schedulers have the same interface as the native PyTorch schedulers, they are supported by the Trainer by default.
- class pytorch_accelerated.schedulers.scheduler_base.StatefulSchedulerBase(optimizer, param_group_field: str = 'lr')[source]
A stateful parameter scheduler base class that can be used to update any field within an optimizer’s parameter groups. The most common use case for this is learning rate scheduling.
Unlike PyTorch’s schedulers, which can be called at different points in the training loop depending on the implementation, this class is intended to be consistently called at the end of each optimizer update.
This class is responsible for maintaining the number of updates, incrementing an internal count each time that the scheduler step is calculated.
The usage of this class is illustrated below:
for current_epoch in range(num_epochs):
    for batch in train_dataloader:
        xb, yb = batch
        predictions = model(xb)
        loss = loss_func(predictions, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # step the scheduler once after every optimizer update
        scheduler.step()
- __init__(optimizer, param_group_field: str = 'lr')[source]
Create a new instance of a stateful parameter scheduler.
- Parameters:
optimizer – a PyTorch optimizer
param_group_field – the field in the optimizer’s parameter groups corresponding to the parameter to be scheduled
- abstractmethod load_state_dict(scheduler_state_dict: dict)[source]
Updates the attributes of the given scheduler from the given state dict.
- Parameters:
scheduler_state_dict – the state dict to be loaded
Stateless Schedulers
These schedulers maintain no internal state about the current training run, and therefore require that the current number of updates is explicitly provided when called. Using a stateless scheduler with the Trainer requires subclassing the Trainer and overriding the scheduler_step() method.
- class pytorch_accelerated.schedulers.scheduler_base.SchedulerBase(optimizer: torch.optim.Optimizer, param_group_field: str = 'lr')[source]
A parameter scheduler base class that can be used to update any field within an optimizer’s parameter groups. The most common use case for this is learning rate scheduling.
Unlike PyTorch’s schedulers, which can be called at different points in the training loop depending on the implementation, this class is intended to be consistently called at the end of each optimizer update.
As this class is stateless by default, it expects that the number of updates is explicitly provided, as illustrated below:
for current_epoch in range(num_epochs):
    # number of optimizer updates completed before this epoch
    num_updates = current_epoch * num_update_steps_per_epoch
    for batch in train_dataloader:
        xb, yb = batch
        predictions = model(xb)
        loss = loss_func(predictions, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_updates += 1
        # the update count must be passed explicitly to a stateless scheduler
        scheduler.step_update(num_updates)
- __init__(optimizer: torch.optim.Optimizer, param_group_field: str = 'lr')[source]
Create a new instance of a parameter scheduler.
- Parameters:
optimizer – a PyTorch optimizer
param_group_field – the field in the optimizer’s parameter groups corresponding to the parameter to be scheduled
- abstractmethod get_updated_values(num_updates: int) -> None | Number | Iterable[Number][source]
Calculate updated values for the scheduled parameter.
If a single value is returned, all parameter groups will be updated with this value.
To update each parameter group with a different value, an iterable collection, containing an updated value for each parameter group, should be returned.
If None is returned, the parameter groups will not be updated.
- Parameters:
num_updates – the number of optimizer updates
- Returns:
the updated values of the scheduled parameter. This should be either a single value, or an iterable collection containing a value for each parameter group.
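The three return conventions described above (a single value, one value per parameter group, or None) can be illustrated with a small stand-alone sketch. It uses plain dictionaries in place of an optimizer's parameter groups, and the helper name `apply_updated_values` is hypothetical; it shows the documented contract, not the base class's actual code.

```python
from collections.abc import Iterable


def apply_updated_values(param_groups, field, values):
    """Hypothetical helper: write scheduled values into each parameter group."""
    # None means "leave every parameter group unchanged".
    if values is None:
        return
    # A single value is broadcast to all parameter groups.
    if not isinstance(values, Iterable):
        values = [values] * len(param_groups)
    # Otherwise each group receives its own value, in order.
    for group, value in zip(param_groups, values):
        group[field] = value


if __name__ == "__main__":
    groups = [{"lr": 0.1}, {"lr": 0.2}]
    apply_updated_values(groups, "lr", 0.05)          # same value for all groups
    apply_updated_values(groups, "lr", [0.01, 0.02])  # one value per group
    apply_updated_values(groups, "lr", None)          # no change
    print(groups)
```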
- load_state_dict(state_dict)[source]
Updates the attributes of the given scheduler from the given state dict.
- Parameters:
state_dict – the state dict to be loaded
Creating New Schedulers
Whilst schedulers are usually used to schedule learning rates, the scheduler base classes in PyTorch-accelerated can be used to schedule any parameter in an optimizer’s parameter group.
To create a new scheduler, in most cases, all that is required is to subclass one of the base classes and override the get_updated_values() method.
Example: Creating a simple milestone lr scheduler
Here is an example of how we can implement a scheduler to adjust the learning rate for each parameter group
by a factor gamma each time an epoch milestone is reached:
from pytorch_accelerated.schedulers import StatefulSchedulerBase
class MilestoneLrScheduler(StatefulSchedulerBase):
    def __init__(
        self, optimizer, gamma=0.5, epoch_milestones=(2, 4, 5), num_steps_per_epoch=100
    ):
        super().__init__(optimizer, param_group_field="lr")
        # convert epoch milestones into optimizer update counts
        self.milestones = set(
            (num_steps_per_epoch * milestone for milestone in epoch_milestones)
        )
        self.gamma = gamma
    def get_updated_values(self, num_updates: int):
        if num_updates in self.milestones:
            lr_values = [
                group[self.param_group_field] for group in self.optimizer.param_groups
            ]
            updated_lrs = [lr * self.gamma for lr in lr_values]
            return updated_lrs
        # otherwise None is returned implicitly and the parameter groups are not updated
Example: Scheduling weight decay
Here is an example of how we can define a scheduler to incrementally increase the amount of weight decay
by a factor gamma every n steps:
from pytorch_accelerated.schedulers import StatefulSchedulerBase
class StepWdScheduler(StatefulSchedulerBase):
    def __init__(self, optimizer, n=1000, gamma=1.1):
        super().__init__(optimizer, param_group_field="weight_decay")
        self.n = n
        self.gamma = gamma
    def get_updated_values(self, num_updates: int):
        # scale the weight decay every n updates, skipping the very first call
        if num_updates % self.n == 0 and num_updates > 0:
            wd_values = [
                group[self.param_group_field] for group in self.optimizer.param_groups
            ]
            updated_wd_values = [wd * self.gamma for wd in wd_values]
            return updated_wd_values
        # otherwise None is returned implicitly and the parameter groups are not updated