Source code for pytorch_accelerated.run_config

# Copyright © 2021 Chris Hughes
from dataclasses import dataclass, asdict
from numbers import Number
from typing import Union


@dataclass(frozen=True)
class TrainerRunConfig:
    """
    An immutable dataclass holding values representing the current state of the
    :class:`~pytorch_accelerated.trainer.Trainer`

    :param num_epochs: the number of epochs in the current training run
    :param train_per_device_batch_size: the per-device batch size used during training epochs
    :param train_dl_kwargs: the arguments that have been used to create the training dataloader
    :param eval_per_device_batch_size: the per-device batch size used during evaluation epochs
    :param eval_dl_kwargs: the arguments that have been used to create the evaluation dataloader
    :param gradient_accumulation_steps: the number of gradient accumulation steps which will be used during training
    :param gradient_clip_value: the threshold value used to clip the gradients of the model's parameters
    :param train_total_batch_size: the total batch size used during training
    :param eval_total_batch_size: the total batch size used during evaluation
    :param num_update_steps_per_epoch: the number of steps per training epoch where the model's parameters will be updated
    :param max_num_train_steps: the maximum number of steps to train for; if present, this takes precedence over ``num_epochs``
    :param is_local_process_zero: ``True`` if the current process is the main process on the current node, ``False`` otherwise
    :param is_world_process_zero: ``True`` if the current process is the main process across all nodes, ``False`` otherwise
    :param is_distributed: ``True`` if the trainer is set up to perform distributed training, ``False`` otherwise
    :param mixed_precision: a string describing the type of mixed precision the trainer is set up to use, or ``no`` if mixed precision is not enabled
    :param num_processes: the number of processes being used during training
    """

    num_epochs: int
    train_per_device_batch_size: int
    train_dl_kwargs: dict
    eval_per_device_batch_size: int
    eval_dl_kwargs: dict
    gradient_accumulation_steps: int
    gradient_clip_value: Union[Number, None]
    train_total_batch_size: int
    eval_total_batch_size: int
    num_update_steps_per_epoch: int
    max_num_train_steps: Union[int, None]
    is_local_process_zero: bool
    is_world_process_zero: bool
    is_distributed: bool
    mixed_precision: str
    num_processes: int

    def to_dict(self):
        # cannot pickle torch._C.Generators, so any custom sampler must be
        # removed from the dataloader kwargs before the dict is created,
        # then re-attached to the resulting dict afterwards
        custom_train_sampler = "sampler" in self.train_dl_kwargs
        custom_eval_sampler = "sampler" in self.eval_dl_kwargs

        if custom_train_sampler:
            train_sampler = self.train_dl_kwargs.pop("sampler")
        if custom_eval_sampler:
            eval_sampler = self.eval_dl_kwargs.pop("sampler")

        run_config_dict = asdict(self)

        if custom_train_sampler:
            run_config_dict["train_dl_kwargs"]["sampler"] = train_sampler
        if custom_eval_sampler:
            run_config_dict["eval_dl_kwargs"]["sampler"] = eval_sampler

        return run_config_dict
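

A minimal usage sketch (not part of the module above): it assumes illustrative values for every field and shows ``to_dict`` producing a plain dictionary. In normal use the :class:`~pytorch_accelerated.trainer.Trainer` builds this config internally rather than it being constructed by hand.

# Sketch only: all values below are illustrative assumptions, not defaults
# taken from the library.
example_config = TrainerRunConfig(
    num_epochs=3,
    train_per_device_batch_size=32,
    train_dl_kwargs={"shuffle": True},
    eval_per_device_batch_size=64,
    eval_dl_kwargs={"shuffle": False},
    gradient_accumulation_steps=1,
    gradient_clip_value=None,
    train_total_batch_size=32,
    eval_total_batch_size=64,
    num_update_steps_per_epoch=100,
    max_num_train_steps=None,
    is_local_process_zero=True,
    is_world_process_zero=True,
    is_distributed=False,
    mixed_precision="no",
    num_processes=1,
)

# Produces a plain dict of the run state; any "sampler" entries in the
# dataloader kwargs are re-attached after asdict() as shown in to_dict above.
config_dict = example_config.to_dict()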