Utils

class pytorch_accelerated.utils.ModelEma(*args: Any, **kwargs: Any)

Maintains a moving average of everything in the model state_dict (parameters and buffers), based on the ideas from https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.

This class maintains a copy of the model being trained. Rather than overwriting the copy's parameters with the new values after every update step, it sets each parameter to a linear combination of its existing value and the updated value.
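The linear combination described above can be sketched in plain Python. This is only an illustration of the arithmetic, not the class's actual implementation: the helper name `ema_update` and the `decay` weight are assumptions, with the formula following the TensorFlow reference linked above (`shadow = decay * shadow + (1 - decay) * value`).

```python
def ema_update(ema_values, new_values, decay=0.9999):
    """Blend each averaged value with its freshly updated counterpart.

    `decay` is a hypothetical name for the weight given to the existing
    average; the remaining (1 - decay) goes to the updated value.
    """
    return [
        decay * ema + (1.0 - decay) * new
        for ema, new in zip(ema_values, new_values)
    ]


# Toy "parameters": the averaged copy and the values after an update step.
averaged = [1.0, 2.0]
updated = [2.0, 4.0]

# With decay=0.5 each result is the midpoint of the two inputs.
averaged = ema_update(averaged, updated, decay=0.5)
```

A decay close to 1 makes the averaged copy change slowly, so it smooths out step-to-step noise in the trained parameters.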

Note

It is important to note that this class is sensitive to where it is initialised. During distributed training, it should be applied before the conversion to SyncBatchNorm takes place and before the torch.nn.parallel.DistributedDataParallel wrapper is used!

class pytorch_accelerated.utils.DataLoaderSlice(dl, slice_size)

A class which can be used to slice a DataLoader to only return a certain number of batches.
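To make the idea concrete, here is a minimal sketch of a wrapper that stops iteration after a fixed number of batches. The class name `SlicedLoaderSketch` is hypothetical and the code is not the library's implementation; it assumes only that the wrapped object is iterable and supports `len()`, and it uses a plain list in place of a real DataLoader.

```python
from itertools import islice


class SlicedLoaderSketch:
    """Hypothetical stand-in: yields at most `slice_size` batches from `dl`."""

    def __init__(self, dl, slice_size):
        self.dl = dl
        self.slice_size = slice_size

    def __iter__(self):
        # islice stops after slice_size items, or earlier if dl is exhausted.
        return islice(self.dl, self.slice_size)

    def __len__(self):
        return min(self.slice_size, len(self.dl))


# A list of "batches" stands in for a DataLoader here.
batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
limited = SlicedLoaderSketch(batches, slice_size=2)
first_pass = list(limited)
```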

class pytorch_accelerated.utils.LimitBatches(num_batches: int)

A context manager which can be used to limit the batches used within a Trainer. Any Trainer initialised within this context manager will contain the LimitBatchesCallback callback. To remove this behaviour, a new trainer must be created or this callback must be explicitly removed.

This will be automatically applied by the trainer if the environment variable PT_ACC_LIMIT_BATCHES is set.
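One way such a context manager could communicate with a trainer is through the environment variable itself. The sketch below rests on that assumption: the text above only states that the trainer reacts to PT_ACC_LIMIT_BATCHES, not how LimitBatches is implemented, and `limit_batches_sketch` is a hypothetical name.

```python
import os
from contextlib import contextmanager

ENV_VAR = "PT_ACC_LIMIT_BATCHES"


@contextmanager
def limit_batches_sketch(num_batches):
    """Hypothetical helper: set the variable on entry, restore it on exit."""
    previous = os.environ.get(ENV_VAR)
    os.environ[ENV_VAR] = str(num_batches)
    try:
        yield
    finally:
        # Restore whatever was there before, even if the body raised.
        if previous is None:
            os.environ.pop(ENV_VAR, None)
        else:
            os.environ[ENV_VAR] = previous


with limit_batches_sketch(2):
    # Anything created here could read the limit from the environment.
    value_inside = os.environ.get(ENV_VAR)
```

Under this assumption, a trainer created inside the block would pick up the limit at initialisation, which matches the note above that the behaviour persists until a new trainer is created or the callback is removed.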

Process Management

pytorch_accelerated.utils.local_process_zero_only(func)

A decorator which can be used to ensure that the decorated function is only executed on the local main process during distributed training.

Parameters:

func – the function to be decorated

pytorch_accelerated.utils.local_process_zero_first(func)

A decorator which can be used to ensure that the decorated function is executed on the local main process first during distributed training.

Parameters:

func – the function to be decorated
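The "main process first" ordering is commonly built from a barrier: other processes wait, the main process runs the function and then joins the barrier, which releases the rest. The sketch below illustrates that pattern with two threads standing in for processes. It is not the library's implementation; `make_zero_first`, `is_main` and `barrier` are hypothetical names introduced for the illustration.

```python
import functools
import threading


def make_zero_first(is_main, barrier):
    """Hypothetical factory: build a decorator from two injected callables.

    is_main() reports whether the caller is the main worker;
    barrier() blocks until every worker has reached it.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not is_main():
                barrier()  # wait here until the main worker has finished
            result = func(*args, **kwargs)
            if is_main():
                barrier()  # release the workers that were waiting
            return result
        return wrapper
    return decorator


def simulate_two_workers():
    """Run the pattern with two threads and return the execution order."""
    barrier = threading.Barrier(2)
    local = threading.local()
    order = []

    @make_zero_first(lambda: local.rank == 0, barrier.wait)
    def prepare():
        order.append(local.rank)

    def worker(rank):
        local.rank = rank
        prepare()

    # Start the non-main worker first to show that it still runs second.
    threads = [threading.Thread(target=worker, args=(r,)) for r in (1, 0)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order
```

This ordering is useful when the first call does one-off work, such as downloading or caching a dataset, that the remaining processes can then reuse.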

pytorch_accelerated.utils.world_process_zero_only(func)

A decorator which can be used to ensure that the decorated function is only executed on the global main process during distributed training.

Parameters:

func – the function to be decorated
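The difference between the local and global "only" decorators can be pictured with a rank check. The sketch below assumes the process rank is read from the `LOCAL_RANK` and `RANK` environment variables, which launchers such as torchrun set; the library may determine the rank differently, and all names ending in `_sketch` are hypothetical.

```python
import functools
import os


def run_only_on_rank_zero(env_var):
    """Hypothetical factory: build a decorator gated on the rank in env_var."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A missing variable is treated as rank 0 (single-process run).
            if int(os.environ.get(env_var, "0")) == 0:
                return func(*args, **kwargs)
            return None  # every other process skips the call
        return wrapper
    return decorator


# Local rank: position of the process on its own machine.
local_zero_only_sketch = run_only_on_rank_zero("LOCAL_RANK")
# Global rank: position of the process across all machines.
world_zero_only_sketch = run_only_on_rank_zero("RANK")


@local_zero_only_sketch
def describe_save(path):
    return f"saved to {path}"
```

With several machines, a local-rank gate runs the function once per machine, while a global-rank gate runs it once for the whole job.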