winglian committed
Commit 1a82082
1 Parent(s): 1210dc8

fix bettertransformers save, force it to skip after saving correctly in callback

src/axolotl/utils/callbacks.py CHANGED
@@ -9,7 +9,7 @@ from transformers import (
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy


 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -36,21 +36,33 @@ class SaveBetterTransformerModelCallback(  # pylint: disable=too-few-public-
 class SaveBetterTransformerModelCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods
-    """Callback to save the BatterTransformer wrapped model"""
+    """Callback to save the BetterTransformer wrapped model"""

-    def on_save(
+    def on_step_end(
         self,
         args: TrainingArguments,
         state: TrainerState,
         control: TrainerControl,
         **kwargs,
     ):
-        checkpoint_folder = os.path.join(
-            args.output_dir,
-            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
-        )
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
+
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )

-        model = BetterTransformer.reverse(kwargs["model"])
-        model.save_pretrained(checkpoint_folder)
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)

+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
         return control
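
The relocated hook mirrors the Trainer's own step-interval save check, writes the checkpoint itself from the un-wrapped model, and then clears control.should_save so the Trainer never tries to serialize the BetterTransformer-wrapped module. A minimal sketch of exercising the callback directly; the tiny sshleifer/tiny-gpt2 model and the stub Trainer objects are illustrative assumptions, not part of this commit:

import os

from optimum.bettertransformer import BetterTransformer
from transformers import (
    AutoModelForCausalLM,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from axolotl.utils.callbacks import SaveBetterTransformerModelCallback

# Wrap a small model with BetterTransformer; tiny GPT-2 is used purely for illustration.
wrapped_model = BetterTransformer.transform(
    AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
)

args = TrainingArguments(output_dir="./out", save_strategy="steps", save_steps=10)
state = TrainerState(global_step=10)  # a step that lands on the save interval
control = TrainerControl()

callback = SaveBetterTransformerModelCallback()
control = callback.on_step_end(args, state, control, model=wrapped_model)

# The callback reversed the wrapping, wrote ./out/checkpoint-10, and cleared
# should_save so the Trainer skips its own save of the wrapped model.
assert control.should_save is False
assert os.path.isdir(os.path.join("./out", "checkpoint-10"))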
src/axolotl/utils/trainer.py CHANGED
@@ -232,6 +232,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     callbacks.append(SavePeftModelCallback)

     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)

     data_collator_kwargs = {
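
For reference, transformers.Trainer accepts either TrainerCallback classes or instances in its callbacks argument and instantiates classes itself, which is why the class object can be appended directly above. A hedged sketch of how the collected list is typically handed over (illustrative only; model and the datasets follow the setup_trainer signature in the hunk header, training_args and the remaining arguments are assumed):

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=callbacks,  # classes such as SaveBetterTransformerModelCallback are instantiated by the Trainer
)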
src/axolotl/utils/validation.py CHANGED
@@ -66,9 +66,10 @@ def validate_config(cfg):
     )
     if cfg.fp16 or cfg.bf16:
         raise ValueError("AMP is not supported with BetterTransformer")
-    if cfg.float16 is not True:
+    if cfg.float16 is not True and cfg.bloat16 is not True:
         logging.warning(
-            "You should probably set float16 to true to load the model in float16 for BetterTransformers"
+            "You should probably set bfloat16 or float16 to true to "
+            "load the model in float16 for BetterTransformers"
         )
     if int(torch.__version__.split(".")[0]) < 2:
         logging.warning("torch>=2.0.0 required")