Source code for pytorch_accelerated.finetuning

from collections import namedtuple, defaultdict
from typing import List

import torch

BN_MODULES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)

Layer = namedtuple("Layer", ["layer_group_idx", "module", "is_frozen"])
LayerGroup = namedtuple("LayerGroup", ["layer_group_idx", "module", "is_frozen"])


class ModelFreezer:
    """
    A class to freeze and unfreeze different parts of a model, to simplify the process of fine-tuning
    during transfer learning.

    This class uses the following abstractions:

    - `Layer`: A subclass of :class:`torch.nn.Module` with a depth of 1, i.e. the module is not nested.
    - `LayerGroup`: The modules which have been defined as attributes of a model. These may be Layers
      or nested modules.

    For example, let's consider the following model::

        from torch import nn

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.input = nn.Linear(100, 100)
                self.block_1 = nn.Sequential(
                    nn.Linear(100, 100),
                    nn.BatchNorm1d(100),
                    nn.Sequential(
                        nn.Linear(100, 100),
                        nn.BatchNorm1d(100),
                        nn.ReLU(),
                    ),
                )
                self.output = nn.Linear(100, 10)

            def forward(self, x):
                x = self.input(x)
                x = self.block_1(x)
                out = self.output(x)
                return out

    Here, the layer groups would be the modules [`input`, `block_1`, `output`], whereas the layers
    would be the ordered, flattened list of Linear, BatchNorm and ReLU modules.
    """

    def __init__(self, model, freeze_batch_norms=False):
        """
        Create a new ModelFreezer instance, which can be used to freeze and unfreeze all, or parts,
        of a model.

        When a model is passed to a ModelFreezer instance, all parameters will be unfrozen regardless
        of their previous state. Subsequent freezing/unfreezing should be done using this instance.

        :param model: The model to freeze/unfreeze. This should be a subclass of :class:`torch.nn.Module`
        :param freeze_batch_norms: Whether to freeze BatchNorm layers during freezing. By default,
            BatchNorm layers are left unfrozen.
        """
        self.model = model
        _set_requires_grad(model.parameters(), value=True)
        self._layer_groups, self._layers = _get_layer_groups_for_module(model)
        self.freeze_bn = freeze_batch_norms
        self.num_groups = len(self._layer_groups)
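
    # A minimal usage sketch (illustrative, not part of the library source), assuming the
    # `MyModel` example defined in the class docstring above:
    #
    #     model = MyModel()
    #     freezer = ModelFreezer(model)   # constructing the freezer unfreezes every parameter
    #     freezer.freeze()                # freeze every layer group except the final one (`output`)
    #     freezer.unfreeze()              # unfreeze the whole model again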

    def get_layer_groups(self) -> List[LayerGroup]:
        """
        Return all of the model's layer groups. A layer group is any module which has been defined
        as an attribute of the model.

        :return: a list of all layer groups in the model.
        """
        layer_groups = []

        for group_idx, layer_group_module in self._layer_groups.items():
            params = {
                not param.requires_grad
                for param in layer_group_module.parameters()
            }
            # a layer group counts as frozen if any of its parameters are frozen
            frozen = True in params

            layer_groups.append(
                LayerGroup(
                    (group_idx, group_idx - self.num_groups),
                    layer_group_module,
                    frozen,
                )
            )

        return layer_groups

    def get_layers(self) -> List[Layer]:
        """
        Return all of the model's layers. A Layer is any non-nested module which is included in
        the model.

        :return: a list of all layers in the model.
        """
        layers = []

        for group_idx, layer in self._layers:
            frozen_status = {
                param.requires_grad for param in layer.parameters(recurse=False)
            }
            if len(frozen_status) > 1:
                # parameters within a single layer are expected to share the same requires_grad state
                raise ValueError
            elif len(frozen_status) == 1:
                layers.append(
                    Layer(
                        (group_idx, group_idx - self.num_groups),
                        layer,
                        not list(frozen_status)[0],
                    )
                )
            else:
                # layer has no parameters
                layers.append(
                    Layer(
                        (group_idx, group_idx - self.num_groups),
                        layer,
                        not layer.training,
                    )
                )

        return layers
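
    # Illustrative sketch (not part of the library source) of inspecting the structure reported
    # by the two methods above, again assuming the `MyModel` example from the class docstring:
    #
    #     freezer = ModelFreezer(MyModel())
    #     groups = freezer.get_layer_groups()
    #     layers = freezer.get_layers()
    #     len(groups)                  # 3 -> input, block_1, output
    #     groups[0].layer_group_idx    # (0, -3): each group carries a positive and a negative index
    #     any(layer.is_frozen for layer in layers)   # False: nothing is frozen after construction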

    def get_trainable_parameters(self):
        """
        Return a list of all unfrozen model parameters, which will be updated during training.

        :return: a list of all trainable parameters
        """
        return [param for param in self.model.parameters() if param.requires_grad]
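
    # Typical hand-off to an optimizer (illustrative, not part of the library source): after
    # freezing, only the remaining trainable parameters are passed to the optimizer.
    #
    #     freezer = ModelFreezer(MyModel())
    #     freezer.freeze()
    #     optimizer = torch.optim.AdamW(freezer.get_trainable_parameters(), lr=1e-3)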

    def freeze(self, from_index=0, to_index=-2, set_modules_as_eval=False):
        """
        Freeze layer groups corresponding to the specified indexes, which are inclusive. By default,
        this freezes all layer groups except the final one.

        :param from_index: The index of the first layer group to freeze.
        :param to_index: The index of the final layer group to freeze.
        :param set_modules_as_eval: If True, frozen modules will also be placed in `eval` mode.
            This is False by default.
        """
        self.__freeze_unfreeze(
            from_index, to_index, freeze=True, toggle_train_eval=set_modules_as_eval
        )

    def unfreeze(self, from_index=-1, to_index=0, set_modules_as_training=True):
        """
        Unfreeze layer groups corresponding to the specified indexes, which are inclusive. By default,
        this unfreezes all layer groups.

        For each layer group, any parameters which have been unfrozen are returned, so that they can
        be added to an optimizer if needed.

        :param from_index: The index of the first layer group to unfreeze.
        :param to_index: The index of the final layer group to unfreeze.
        :param set_modules_as_training: If True, unfrozen modules will also be placed in `train` mode.
            This is True by default.
        :return: a dictionary containing the parameters which have been unfrozen for each layer group.
        """
        unfrozen_params = self.__freeze_unfreeze(
            from_index,
            to_index,
            freeze=False,
            toggle_train_eval=set_modules_as_training,
        )
        return unfrozen_params
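
    # Gradual fine-tuning sketch (illustrative, not part of the library source). Each value in the
    # dictionary returned by `unfreeze` has the form {"params": [...]}, which is the format expected
    # by `Optimizer.add_param_group`, so newly unfrozen parameters can be registered with an existing
    # optimizer, here with an assumed lower learning rate for the earlier layer group:
    #
    #     freezer = ModelFreezer(MyModel())
    #     freezer.freeze()                        # train only the final layer group first
    #     optimizer = torch.optim.AdamW(freezer.get_trainable_parameters(), lr=1e-3)
    #     # ... train for a few epochs, then unfreeze the penultimate layer group ...
    #     new_params = freezer.unfreeze(from_index=-2, to_index=-2)
    #     for param_group in new_params.values():
    #         param_group["lr"] = 1e-4
    #         optimizer.add_param_group(param_group)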

    def __freeze_unfreeze(
        self,
        from_layer_group_index,
        to_layer_group_index,
        freeze=True,
        toggle_train_eval=True,
    ):
        modified_parameters = defaultdict(list)
        set_grad_value = not freeze
        layers = self.get_layers()

        from_layer_group_index, to_layer_group_index = self._convert_idxs(
            from_layer_group_index, to_layer_group_index
        )

        for layer in layers:
            if layer.layer_group_idx[0] < from_layer_group_index:
                continue
            elif layer.layer_group_idx[0] > to_layer_group_index:
                break
            else:
                if layer.is_frozen == freeze:
                    # layer already in correct state
                    continue
                else:
                    is_batch_norm = _module_is_batch_norm(layer.module)
                    if is_batch_norm and not self.freeze_bn:
                        continue
                    else:
                        params = _change_layer_state(
                            layer, toggle_train_eval, set_grad_value
                        )
                        if params:
                            modified_parameters[layer.layer_group_idx[0]].extend(params)

        return {
            layer_group_idx: {"params": params}
            for layer_group_idx, params in modified_parameters.items()
        }

    def _convert_idxs(self, from_idx, to_idx):
        from_idx = _convert_idx(from_idx, self.num_groups)
        to_idx = _convert_idx(to_idx, self.num_groups)

        if from_idx > to_idx:
            from_idx, to_idx = to_idx, from_idx

        return from_idx, to_idx

def _change_layer_state(layer: Layer, toggle_train_eval: bool, set_grad_value: bool):
    params = list(layer.module.parameters())
    if params:
        _set_requires_grad(params, value=set_grad_value)
    if toggle_train_eval:
        layer.module.train(mode=set_grad_value)

    return params


def _module_is_batch_norm(module):
    return isinstance(module, BN_MODULES)


def _convert_idx(idx, num_groups):
    if idx < 0:
        idx = idx + num_groups
    return idx


def _set_requires_grad(parameters, value=True):
    for param in parameters:
        param.requires_grad = value


def _get_layer_groups_for_module(module):
    layers = []
    layer_groups = dict()

    for layer_group, group in enumerate(module.children()):
        _recursive_get_layers(group, layers, layer_group)
        layer_groups[layer_group] = group

    return layer_groups, layers


def _recursive_get_layers(module, result, layer_group=0):
    children = list(module.children())
    if not children:
        # is leaf
        result.append((layer_group, module))
    else:
        # is nested
        for child in children:
            _recursive_get_layers(child, result, layer_group)
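

# A self-contained demo (illustrative, not part of the library source), guarded so that it only
# runs when this module is executed directly. It rebuilds the example model from the class
# docstring and exercises the layer-group discovery and freezing behaviour described above.
if __name__ == "__main__":
    from torch import nn

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.input = nn.Linear(100, 100)
            self.block_1 = nn.Sequential(
                nn.Linear(100, 100),
                nn.BatchNorm1d(100),
                nn.Sequential(
                    nn.Linear(100, 100),
                    nn.BatchNorm1d(100),
                    nn.ReLU(),
                ),
            )
            self.output = nn.Linear(100, 10)

        def forward(self, x):
            return self.output(self.block_1(self.input(x)))

    freezer = ModelFreezer(MyModel())

    # 3 layer groups (input, block_1, output) and 7 flattened layers
    print(len(freezer.get_layer_groups()), len(freezer.get_layers()))

    # By default, freezing leaves the final layer group (and BatchNorm layers) trainable
    freezer.freeze()
    print([group.is_frozen for group in freezer.get_layer_groups()])  # [True, True, False]

    # Unfreezing returns the newly trainable parameters, keyed by layer group index
    unfrozen = freezer.unfreeze()
    print(sorted(unfrozen.keys()))  # [0, 1]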