Nanobit committed on
Commit
fb12895
1 Parent(s): 9fc29e0

Feat: Add warmup_ratio (#893)

Browse files

* Feat: Add warmup_ratio

* fix: update readme with more details on conflict

README.md CHANGED
@@ -675,7 +675,8 @@ gradient_accumulation_steps: 1
675
  micro_batch_size: 2
676
  eval_batch_size:
677
  num_epochs: 4
678
- warmup_steps: 100
 
679
  learning_rate: 0.00003
680
  lr_quadratic_warmup:
681
  logging_steps:
 
675
  micro_batch_size: 2
676
  eval_batch_size:
677
  num_epochs: 4
678
+ warmup_steps: 100 # cannot use with warmup_ratio
679
+ warmup_ratio: 0.05 # cannot use with warmup_steps
680
  learning_rate: 0.00003
681
  lr_quadratic_warmup:
682
  logging_steps:
src/axolotl/core/trainer_builder.py CHANGED
@@ -461,11 +461,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
461
  return AxolotlTrainer
462
 
463
  def build(self, total_num_steps):
464
- warmup_steps = (
465
- self.cfg.warmup_steps
466
- if self.cfg.warmup_steps is not None
467
- else min(int(0.03 * total_num_steps), 100)
468
- )
 
 
 
469
  logging_steps = (
470
  self.cfg.logging_steps
471
  if self.cfg.logging_steps is not None
 
461
  return AxolotlTrainer
462
 
463
  def build(self, total_num_steps):
464
+ warmup_steps = None
465
+ if self.cfg.warmup_steps is not None:
466
+ warmup_steps = self.cfg.warmup_steps
467
+ elif self.cfg.warmup_ratio is not None:
468
+ warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
469
+ else:
470
+ warmup_steps = min(int(0.03 * total_num_steps), 100)
471
+
472
  logging_steps = (
473
  self.cfg.logging_steps
474
  if self.cfg.logging_steps is not None
src/axolotl/utils/config.py CHANGED
@@ -372,6 +372,9 @@ def validate_config(cfg):
372
  if cfg.rope_scaling:
373
  LOG.warning("`rope_scaling` should now be be a key under `model_config`")
374
 
 
 
 
375
  # TODO
376
  # MPT 7b
377
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
372
  if cfg.rope_scaling:
373
  LOG.warning("`rope_scaling` should now be be a key under `model_config`")
374
 
375
+ if cfg.warmup_steps and cfg.warmup_ratio:
376
+ raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
377
+
378
  # TODO
379
  # MPT 7b
380
  # https://github.com/facebookresearch/bitsandbytes/issues/25
tests/test_validation.py CHANGED
@@ -649,3 +649,33 @@ class ValidationTest(unittest.TestCase):
649
  )
650
 
651
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
  )
650
 
651
  validate_config(cfg)
652
+
653
+ def test_warmup_step_no_conflict(self):
654
+ cfg = DictDefault(
655
+ {
656
+ "warmup_steps": 10,
657
+ "warmup_ratio": 0.1,
658
+ }
659
+ )
660
+
661
+ with pytest.raises(
662
+ ValueError,
663
+ match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
664
+ ):
665
+ validate_config(cfg)
666
+
667
+ cfg = DictDefault(
668
+ {
669
+ "warmup_steps": 10,
670
+ }
671
+ )
672
+
673
+ validate_config(cfg)
674
+
675
+ cfg = DictDefault(
676
+ {
677
+ "warmup_ratio": 0.1,
678
+ }
679
+ )
680
+
681
+ validate_config(cfg)