Utils

class pytorch_accelerated.utils.ModelEma(model, decay=0.9999)

Maintains a moving average of everything in the model state_dict (parameters and buffers), based on the ideas from https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.

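This class maintains a copy of the model being trained. Rather than copying the latest parameter values into this copy after every update step, each value in the copy is set to a linear combination of its existing value and the newly updated value, weighted by decay: ema_value = decay * ema_value + (1 - decay) * updated_value. With the default decay of 0.9999, each update step therefore moves the averaged values only 0.01% of the way towards the current ones.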
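To make the update rule concrete, here is a minimal sketch of an EMA wrapper. It is not the pytorch_accelerated source: the class name SimpleModelEma, its update(model) method and its module attribute are hypothetical names chosen for illustration. It assumes the update is applied to every floating-point entry of the state_dict, and, as a design choice of this sketch, simply copies integer buffers such as BatchNorm's num_batches_tracked.

```python
from copy import deepcopy

import torch
import torch.nn as nn


class SimpleModelEma(nn.Module):
    """Hypothetical EMA wrapper for illustration; not the library's implementation."""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        super().__init__()
        # Independent copy of the model whose weights hold the running average
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        # Walk both state_dicts in lockstep, covering parameters *and* buffers.
        # state_dict() tensors share storage with the module, so in-place ops update it.
        for ema_v, model_v in zip(
            self.module.state_dict().values(), model.state_dict().values()
        ):
            if ema_v.dtype.is_floating_point:
                # ema = decay * ema + (1 - decay) * current
                ema_v.mul_(self.decay).add_(model_v, alpha=1.0 - self.decay)
            else:
                # Assumption of this sketch: integer buffers are copied, not averaged
                ema_v.copy_(model_v)


# Usage: one simulated training step that shifts every weight by +1
torch.manual_seed(0)
model = nn.Linear(2, 1)
ema = SimpleModelEma(model, decay=0.9)
w0 = model.weight.detach().clone()

with torch.no_grad():
    model.weight.add_(1.0)  # stand-in for an optimiser step

ema.update(model)
# The averaged weight moved only (1 - 0.9) = 0.1 of the way: w0 + 0.1
```

Because the averaged copy lives in its own module, it can be evaluated or saved independently of the model that the optimiser is updating.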

Note

This class is sensitive to where it is initialised. During distributed training, it should be created before the model is converted to use SyncBatchNorm and before the model is wrapped in torch.nn.parallel.DistributedDataParallel!
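The ordering described in the note can be sketched as follows. To keep the example runnable without pytorch_accelerated installed, a plain deepcopy stands in for constructing ModelEma(model, decay=...); build_model is a hypothetical helper, and the DistributedDataParallel wrap only happens if a process group has already been initialised. The point is that the averaged copy is taken from the plain model, not from the SyncBatchNorm-converted or DDP-wrapped one.

```python
from copy import deepcopy

import torch.distributed as dist
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for a real network that contains BatchNorm
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())


model = build_model()

# 1. Create the EMA copy first, while the model is still a plain nn.Module.
#    In real code this would be ModelEma(model, decay=...); deepcopy stands in here.
ema_model = deepcopy(model)

# 2. Only then replace BatchNorm layers with SyncBatchNorm ...
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# 3. ... and wrap with DDP, which requires an initialised process group
if dist.is_available() and dist.is_initialized():
    model = nn.parallel.DistributedDataParallel(model)
```

Reversing steps 1 and 2 would give an averaged copy that already contains SyncBatchNorm layers, which is the situation the note warns against.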