"""
module to freeze/unfreeze parameters by name
"""
import logging
import re

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.utils.freeze")


def freeze_parameters_except(model, regex_patterns):
    """
    Freezes all parameters of the given model except those whose names match
    one of the given regex patterns.

    Every period in a pattern is escaped before compilation, so periods are
    always literal; as a consequence the `.` wildcard (e.g. `.*`) cannot be
    used in a pattern. Patterns are applied with `re.match`, i.e. they are
    anchored at the start of the parameter name but not at its end.

    Parameters:
    - model (nn.Module): The PyTorch model to be modified.
    - regex_patterns (list of str): List of regex patterns to match parameter
      names to keep unfrozen. A single bare string is accepted and treated as
      one pattern.

    Returns:
    None; the model is modified in place.
    """
    # A bare string would otherwise be iterated character by character and
    # turned into one single-letter pattern per character.
    if isinstance(regex_patterns, str):
        regex_patterns = [regex_patterns]

    # Escape periods and compile the regex patterns
    compiled_patterns = [
        re.compile(pattern.replace(".", r"\.")) for pattern in regex_patterns
    ]

    # First, freeze all parameters in the model
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze parameters whose name matches any of the patterns
    for name, param in model.named_parameters():
        if any(pattern.match(name) for pattern in compiled_patterns):
            if is_main_process():
                # Lazy %-style args: the message is only built when DEBUG
                # logging is actually enabled.
                LOG.debug("unfreezing %s", name)
            param.requires_grad = True