Utils
- class pytorch_accelerated.utils.ModelEma(*args: Any, **kwargs: Any)[source]
Maintains a moving average of everything in the model state_dict (parameters and buffers), based on the ideas from https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
This class maintains a copy of the model that we are training. However, rather than overwriting the parameters of this copy after every update step, we set them to a linear combination of their existing values and the updated values from the model being trained.
Note
This class is sensitive to where it is initialised. During distributed training, it should be applied before the conversion to SyncBatchNorm takes place and before the torch.nn.parallel.DistributedDataParallel wrapper is used!
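The linear combination described above can be illustrated with a minimal sketch. This is not the library's implementation: the function name ema_update, the decay value, and the use of plain floats in place of tensors are all assumptions made for illustration.

```python
def ema_update(averaged, current, decay=0.999):
    """Blend each averaged value with the matching current value.

    `decay` weights the existing average; (1 - decay) weights the new value.
    The name, signature and default are illustrative assumptions.
    """
    return {
        name: decay * averaged[name] + (1.0 - decay) * current[name]
        for name in averaged
    }


# Toy "state dict": plain floats stand in for parameter tensors and buffers.
averaged = {"weight": 1.0, "bias": 0.0}
current = {"weight": 2.0, "bias": 1.0}

# With decay=0.9 the average moves one tenth of the way towards the new values.
averaged = ema_update(averaged, current, decay=0.9)
```

A decay close to 1 makes the averaged copy change slowly, which smooths out step-to-step noise in the trained weights.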
- class pytorch_accelerated.utils.DataLoaderSlice(dl, slice_size)[source]
A class which can be used to slice a DataLoader to only return a certain number of batches.
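The idea of slicing an iterable of batches can be sketched with itertools.islice. The class below, IterableSlice, is a hypothetical stand-in written for this example. It is not the source of DataLoaderSlice, and a plain list of lists is used in place of a real DataLoader.

```python
from itertools import islice


class IterableSlice:
    """Hypothetical stand-in: yields at most `slice_size` items per pass."""

    def __init__(self, iterable, slice_size):
        self.iterable = iterable
        self.slice_size = slice_size

    def __iter__(self):
        # A fresh iterator is created on every pass, so the slice is reusable.
        return islice(iter(self.iterable), self.slice_size)

    def __len__(self):
        return min(self.slice_size, len(self.iterable))


# A list of "batches" stands in for a DataLoader.
batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
limited = IterableSlice(batches, slice_size=2)

first_pass = list(limited)
second_pass = list(limited)
```

Creating the iterator inside `__iter__` rather than in `__init__` means the same limit applies on every epoch, not only on the first one.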
- class pytorch_accelerated.utils.LimitBatches(num_batches: int)[source]
A context manager which can be used to limit the batches used within a Trainer. Any Trainer initialised within this context manager will contain the LimitBatchesCallback callback. To remove this behaviour, a new trainer must be created or this callback must be explicitly removed.
This will be automatically applied by the trainer if the environment variable PT_ACC_LIMIT_BATCHES is set.
Process Management
- pytorch_accelerated.utils.local_process_zero_only(func)[source]
A decorator which can be used to ensure that the decorated function is only executed on the local main process during distributed training.
- Parameters:
func – the function to be decorated
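As a rough illustration of how such a decorator can work, the sketch below skips the wrapped function unless the process is local rank zero. It is not the library's implementation: the name run_on_local_rank_zero is hypothetical, and reading the rank from the LOCAL_RANK environment variable (as set by launchers such as torchrun) is an assumption made for this example.

```python
import functools
import os


def run_on_local_rank_zero(func):
    """Hypothetical decorator: call `func` only when LOCAL_RANK is 0."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Assumption: the launcher exports LOCAL_RANK; default to 0 when unset.
        if int(os.environ.get("LOCAL_RANK", "0")) == 0:
            return func(*args, **kwargs)
        # On every other local process the call is skipped.
        return None

    return wrapper


@run_on_local_rank_zero
def describe_run(name):
    return f"run: {name}"
```

Because skipped calls return None, this pattern suits side effects such as logging or saving files better than functions whose return value is needed on every process.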