winglian commited on
Commit
687d889
2 Parent(s): 13ac4d8 c4cf567

Merge pull request #271 from OpenAccess-AI-Collective/quadratic-warmup

Browse files
src/axolotl/utils/schedulers.py CHANGED
@@ -1,6 +1,9 @@
1
  """Module for custom LRScheduler class"""
 
 
2
 
3
- from torch.optim.lr_scheduler import LRScheduler
 
4
 
5
 
6
  class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
42
  lrs = [self.max_lr for base_lr in self.base_lrs]
43
 
44
  return lrs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Module for custom LRScheduler class"""
2
+ import math
3
+ from functools import partial
4
 
5
+ from torch.optim import Optimizer
6
+ from torch.optim.lr_scheduler import LambdaLR, LRScheduler
7
 
8
 
9
  class InterpolatingLogScheduler(LRScheduler):
 
45
  lrs = [self.max_lr for base_lr in self.base_lrs]
46
 
47
  return lrs
48
+
49
+
50
+ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
51
+ current_step: int,
52
+ *,
53
+ num_warmup_steps: int,
54
+ num_training_steps: int,
55
+ num_cycles: float
56
+ ):
57
+ if current_step < num_warmup_steps:
58
+ return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
59
+ progress = float(current_step - num_warmup_steps) / float(
60
+ max(1, num_training_steps - num_warmup_steps)
61
+ )
62
+ return max(
63
+ 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
64
+ )
65
+
66
+
67
+ def get_cosine_schedule_with_quadratic_warmup(
68
+ optimizer: Optimizer,
69
+ num_warmup_steps: int,
70
+ num_training_steps: int,
71
+ num_cycles: float = 0.5,
72
+ last_epoch: int = -1,
73
+ ):
74
+ """
75
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
76
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
77
+ initial lr set in the optimizer.
78
+
79
+ Args:
80
+ optimizer ([`~torch.optim.Optimizer`]):
81
+ The optimizer for which to schedule the learning rate.
82
+ num_warmup_steps (`int`):
83
+ The number of steps for the warmup phase.
84
+ num_training_steps (`int`):
85
+ The total number of training steps.
86
+ num_cycles (`float`, *optional*, defaults to 0.5):
87
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
88
+ following a half-cosine).
89
+ last_epoch (`int`, *optional*, defaults to -1):
90
+ The index of the last epoch when resuming training.
91
+
92
+ Return:
93
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
94
+ """
95
+
96
+ lr_lambda = partial(
97
+ _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
98
+ num_warmup_steps=num_warmup_steps,
99
+ num_training_steps=num_training_steps,
100
+ num_cycles=num_cycles,
101
+ )
102
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
src/axolotl/utils/trainer.py CHANGED
@@ -5,6 +5,7 @@ import logging
5
  import math
6
  import os
7
  import sys
 
8
  from pathlib import Path
9
  from typing import Optional
10
 
@@ -13,17 +14,67 @@ import torch.cuda
13
  import transformers
14
  from torch import nn
15
  from torch.optim.lr_scheduler import OneCycleLR
16
- from transformers import EarlyStoppingCallback, Trainer
17
  from transformers.trainer_pt_utils import get_parameter_names
18
 
19
  from axolotl.utils.callbacks import (
20
  SaveBetterTransformerModelCallback,
21
  SavePeftModelCallback,
22
  )
23
- from axolotl.utils.schedulers import InterpolatingLogScheduler
 
 
 
24
 
25
 
26
- class OneCycleLRSchedulerTrainer(Trainer):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
  Trainer subclass that uses the OneCycleLR scheduler
29
  """
@@ -103,6 +154,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
103
  if cfg.fsdp_config:
104
  training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
105
 
 
 
 
106
  # deepspeed
107
  if (
108
  os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -128,7 +182,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
128
  training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
129
  training_arguments_kwargs["push_to_hub"] = True
130
 
131
- training_args = transformers.TrainingArguments(
132
  per_device_train_batch_size=cfg.micro_batch_size,
133
  per_device_eval_batch_size=cfg.eval_batch_size
134
  if cfg.eval_batch_size is not None
@@ -278,7 +332,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
278
  trainer_cls = (
279
  OneCycleLRSchedulerTrainer
280
  if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
281
- else transformers.Trainer
282
  )
283
  trainer = trainer_cls(
284
  model=model,
 
5
  import math
6
  import os
7
  import sys
8
+ from dataclasses import field
9
  from pathlib import Path
10
  from typing import Optional
11
 
 
14
  import transformers
15
  from torch import nn
16
  from torch.optim.lr_scheduler import OneCycleLR
17
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
18
  from transformers.trainer_pt_utils import get_parameter_names
19
 
20
  from axolotl.utils.callbacks import (
21
  SaveBetterTransformerModelCallback,
22
  SavePeftModelCallback,
23
  )
24
+ from axolotl.utils.schedulers import (
25
+ InterpolatingLogScheduler,
26
+ get_cosine_schedule_with_quadratic_warmup,
27
+ )
28
 
29
 
30
+ class AxolotlTrainingArguments(TrainingArguments):
31
+ """
32
+ Extend the base TrainingArguments for axolotl helpers
33
+ """
34
+
35
+ lr_quadratic_warmup: bool = field(
36
+ default=False,
37
+ metadata={"help": "Use quadratic warmup for cosine scheduling."},
38
+ )
39
+
40
+
41
+ class AxolotlTrainer(Trainer):
42
+ """
43
+ Extend the base Trainer for axolotl helpers
44
+ """
45
+
46
+ args = None # type: AxolotlTrainingArguments
47
+
48
+ def create_scheduler(
49
+ self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
50
+ ):
51
+ """
52
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
53
+ passed as an argument.
54
+
55
+ Args:
56
+ num_training_steps (int): The number of training steps to do.
57
+ optimizer (torch.optim.Optimizer): The training optimizer
58
+ """
59
+
60
+ # fmt: off
61
+ if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
62
+ # fmt: on
63
+ if (
64
+ self.args.lr_scheduler_type == "cosine"
65
+ and self.args.lr_quadratic_warmup is True
66
+ ):
67
+ self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
68
+ optimizer,
69
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
70
+ num_training_steps=num_training_steps,
71
+ )
72
+ else:
73
+ return super().create_scheduler(num_training_steps, optimizer)
74
+ return self.lr_scheduler
75
+
76
+
77
+ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
78
  """
79
  Trainer subclass that uses the OneCycleLR scheduler
80
  """
 
154
  if cfg.fsdp_config:
155
  training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
156
 
157
+ if cfg.lr_quadratic_warmup is not None:
158
+ training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
159
+
160
  # deepspeed
161
  if (
162
  os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
 
182
  training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
183
  training_arguments_kwargs["push_to_hub"] = True
184
 
185
+ training_args = AxolotlTrainingArguments(
186
  per_device_train_batch_size=cfg.micro_batch_size,
187
  per_device_eval_batch_size=cfg.eval_batch_size
188
  if cfg.eval_batch_size is not None
 
332
  trainer_cls = (
333
  OneCycleLRSchedulerTrainer
334
  if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
335
+ else AxolotlTrainer
336
  )
337
  trainer = trainer_cls(
338
  model=model,