Utils
- class pytorch_accelerated.utils.ModelEma(model, decay=0.9999)
Maintains a moving average of everything in the model state_dict (parameters and buffers), based on the ideas from https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
This class maintains a copy of the model being trained. Rather than overwriting the copy's parameters with the model's values after every update step, it sets each parameter to a linear combination of its existing value and the model's updated value.
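The linear combination described above can be illustrated with a minimal sketch. It assumes the conventional exponential moving average rule, in which `decay` weights the existing averaged value and `1 - decay` weights the new one, as in the TensorFlow class linked above. The function name `ema_update` and the plain-float dictionaries standing in for a `state_dict` are hypothetical and are not part of pytorch_accelerated.

```python
def ema_update(ema_state, model_state, decay=0.9999):
    """Blend each averaged value toward the current model value, in place.

    Both arguments are hypothetical stand-ins for a state_dict: plain
    dictionaries mapping names to floats rather than tensors.
    """
    for name, current in model_state.items():
        # Assumed rule: new_average = decay * old_average + (1 - decay) * current
        ema_state[name] = decay * ema_state[name] + (1.0 - decay) * current
    return ema_state


# Averaged copy and the freshly updated model values.
ema_state = {"weight": 1.0, "bias": 0.0}
model_state = {"weight": 2.0, "bias": 1.0}

# A low decay is used here only so that the change is easy to see.
ema_update(ema_state, model_state, decay=0.9)
```

With `decay=0.9`, each averaged value moves one tenth of the way toward the model's value, so `weight` goes from 1.0 to roughly 1.1. With the default of 0.9999 the averaged copy changes far more slowly, which is what smooths out step-to-step noise in the trained parameters.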
Note
This class is sensitive to where it is initialised. During distributed training, it should be applied before the conversion to SyncBatchNorm takes place and before the torch.nn.parallel.DistributedDataParallel wrapper is used!