user735 Karl-Johan Alm committed
Commit: 58ec8b1
Parent: 476a205

feature: loss watchdog for terminating training runs that are failing (#899)

README.md CHANGED
@@ -694,6 +694,9 @@ max_steps:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
 
+loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
+loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
+
 # Save model as safetensors (require safetensors package)
 save_safetensors:
 
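For picking a threshold, the README's rule of thumb is roughly twice the loss observed at the start of training. A hypothetical helper (not part of this commit) makes that rule concrete in Python:

# Hypothetical helper, not from this commit: turns the README's rule of thumb
# ("~2 times the loss at the start of training") into a concrete value.
def suggest_watchdog_threshold(initial_loss: float, factor: float = 2.0) -> float:
    """Suggest a loss_watchdog_threshold from the first logged training loss."""
    return factor * initial_loss


print(suggest_watchdog_threshold(2.5))  # -> 5.0, matching the Mistral example below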
examples/mistral/qlora.yml CHANGED
@@ -62,6 +62,9 @@ logging_steps: 1
 xformers_attention:
 flash_attention: true
 
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
 warmup_steps: 10
 eval_steps: 0.05
 eval_table_size:
src/axolotl/core/trainer_builder.py CHANGED
@@ -25,6 +25,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
+    LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
@@ -430,6 +431,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
 
+        if self.cfg.loss_watchdog_threshold is not None:
+            callbacks.append(LossWatchDogCallback(self.cfg))
+
         return callbacks
 
     def get_post_trainer_create_callbacks(self, trainer):
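The list returned by get_callbacks is ultimately handed to the Hugging Face Trainer. A minimal sketch of the same gating pattern outside the builder, with a stand-in cfg object (SimpleNamespace here is an assumption for illustration, not axolotl's config type):

# Sketch only: reproduces the gating above with a stand-in cfg object.
from types import SimpleNamespace

from axolotl.utils.callbacks import LossWatchDogCallback  # added in this commit

cfg = SimpleNamespace(loss_watchdog_threshold=5.0, loss_watchdog_patience=3)

callbacks = []
if cfg.loss_watchdog_threshold is not None:
    callbacks.append(LossWatchDogCallback(cfg))
# `callbacks` is then passed along to transformers.Trainer(..., callbacks=callbacks)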
src/axolotl/utils/callbacks.py CHANGED
@@ -124,6 +124,36 @@ class GPUStatsCallback(
         return control
 
 
+class LossWatchDogCallback(TrainerCallback):
+    """Callback to track loss and stop training if loss is too high"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.logged = False
+        self.violations = 0
+        self.threshold = cfg.loss_watchdog_threshold
+        self.patience = cfg.loss_watchdog_patience or 3
+
+    def on_step_end(
+        self,
+        _args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ):
+        if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
+            if state.log_history[-1]["loss"] > self.threshold:
+                self.violations += 1
+                if self.violations >= self.patience:
+                    LOG.warning(
+                        "Loss is too high, stopping training (loss_watchdog_threshold)"
+                    )
+                    control.should_training_stop = True
+            else:
+                self.violations = 0
+        return control
+
+
 def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
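A quick way to see the watchdog's behavior (a sketch, not from the commit) is to drive on_step_end with a fabricated TrainerState: after `patience` consecutive losses above the threshold, the callback flips should_training_stop.

# Sketch only: feed the callback three consecutive high losses and check that
# it requests a training stop on the third one (patience = 3).
from types import SimpleNamespace

from transformers import TrainerControl, TrainerState

from axolotl.utils.callbacks import LossWatchDogCallback

cfg = SimpleNamespace(loss_watchdog_threshold=5.0, loss_watchdog_patience=3)
watchdog = LossWatchDogCallback(cfg)

state = TrainerState()    # log_history starts out empty
control = TrainerControl()

for loss in (6.1, 7.3, 8.0):                     # three high-loss steps in a row
    state.log_history.append({"loss": loss})
    watchdog.on_step_end(None, state, control)   # args/kwargs are unused here

assert control.should_training_stop

Note that a single step back under the threshold resets the violation counter (the else branch above), so only sustained divergence aborts the run.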