File size: 1,245 Bytes
5ea3aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
module to freeze/unfreeze parameters by name
"""
import logging
import re

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.utils.freeze")


def freeze_parameters_except(model, regex_patterns):
    """
    Freezes all layers of the given model except for the layers that match given regex patterns.
    Periods in the patterns are treated as literal periods, not as wildcard characters.

    Parameters:
    - model (nn.Module): The PyTorch model to be modified.
    - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.

    Returns:
    None; the model is modified in place.
    """
    # Escape periods and compile the regex patterns
    compiled_patterns = [
        re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
    ]

    # First, freeze all parameters in the model
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze layers that match the regex patterns
    for name, param in model.named_parameters():
        if any(pattern.match(name) for pattern in compiled_patterns):
            if is_main_process():
                LOG.debug(f"unfreezing {name}")
            param.requires_grad = True