winglian commited on
Commit
2844eb2
1 Parent(s): e85d2eb

run eval on the first step to get a baseline (#617)

Browse files

* run eval on the first step to get a baseline

* wandb kleeps getting moved around by pre-commit ...

src/axolotl/utils/callbacks.py CHANGED
@@ -66,6 +66,29 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
66
  return control
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class SaveBetterTransformerModelCallback(
70
  TrainerCallback
71
  ): # pylint: disable=too-few-public-methods
 
66
  return control
67
 
68
 
69
+ class EvalFirstStepCallback(
70
+ TrainerCallback
71
+ ): # pylint: disable=too-few-public-methods disable=unused-argument
72
+ """
73
+ Callback to trigger evals on the first step
74
+ """
75
+
76
+ def on_step_end(
77
+ self,
78
+ args: TrainingArguments,
79
+ state: TrainerState,
80
+ control: TrainerControl,
81
+ **kwargs,
82
+ ):
83
+ if (
84
+ args.evaluation_strategy == IntervalStrategy.STEPS
85
+ and args.eval_steps < 1.0
86
+ and state.global_step == 1
87
+ ):
88
+ control.should_evaluate = True
89
+ return control
90
+
91
+
92
  class SaveBetterTransformerModelCallback(
93
  TrainerCallback
94
  ): # pylint: disable=too-few-public-methods
src/axolotl/utils/trainer.py CHANGED
@@ -28,6 +28,7 @@ from transformers.trainer_pt_utils import SequentialDistributedSampler
28
 
29
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
30
  from axolotl.utils.callbacks import (
 
31
  GPUStatsCallback,
32
  SaveBetterTransformerModelCallback,
33
  SavePeftModelCallback,
@@ -704,6 +705,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
704
 
705
  callbacks = []
706
  callbacks.append(GPUStatsCallback(cfg))
 
707
 
708
  if cfg.relora_steps:
709
  callbacks.append(ReLoRACallback(cfg))
 
28
 
29
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
30
  from axolotl.utils.callbacks import (
31
+ EvalFirstStepCallback,
32
  GPUStatsCallback,
33
  SaveBetterTransformerModelCallback,
34
  SavePeftModelCallback,
 
705
 
706
  callbacks = []
707
  callbacks.append(GPUStatsCallback(cfg))
708
+ callbacks.append(EvalFirstStepCallback)
709
 
710
  if cfg.relora_steps:
711
  callbacks.append(ReLoRACallback(cfg))