winglian committed
Commit e30f1e3
Parent: 3437149

Early stopping metric (#537)


* set early stopping metric to check

* tweak how load_best_model_at_end gets set for early stopping

* add validation for early stopping patience

* remove negation

* save results to metrics in callback

* move early stopping callback after the benchmark evals

* broadcast metrics so early stopping works

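Taken together, these changes key off a handful of user-config options. A hedged sketch of how they fit (key names come from the diffs below; the values here are purely illustrative, and axolotl itself loads them from YAML into an attribute-style cfg object rather than a plain dict):

# Illustrative config values for early stopping (not verbatim axolotl config)
cfg = {
    "early_stopping_patience": 3,    # stop after 3 evals with no improvement
    "eval_steps": 50,                # must evenly divide save_steps
    "save_steps": 500,
    "val_set_size": 0.05,            # early stopping needs an eval split
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,      # lower eval_loss is better
}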
src/axolotl/utils/callbacks.py CHANGED
@@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.distributed import (
     barrier,
+    broadcast_dict,
     gather_scalar_from_all_ranks,
     get_world_size,
     is_distributed,
@@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 lambda: len(data_loader), get_world_size()
             )
 
+            results = {}
             if is_distributed() and not is_main_process():
                 dist.gather_object(local_bench_names, dst=0)
             else:
@@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 )["accuracy"]
                 trainer.log(results)
 
+            results = broadcast_dict(results)
+            for key, val in results.items():
+                metrics[key] = val
+
     return BenchEvalCallback
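Why the broadcast matters: transformers runs on_evaluate on every rank, and EarlyStoppingCallback compares metrics[metric_for_best_model] there, but the benchmark results were only computed on rank 0. A minimal sketch of the pattern the callback now follows (simplified, hypothetical helper name; not the verbatim callback code):

from axolotl.utils.distributed import broadcast_dict

def merge_bench_results(metrics: dict, results: dict) -> None:
    # outside a distributed run, broadcast_dict returns results unchanged
    results = broadcast_dict(results)
    # every rank now writes identical values into the Trainer's metrics dict
    for key, val in results.items():
        metrics[key] = val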
src/axolotl/utils/config.py CHANGED
@@ -220,6 +220,15 @@ def validate_config(cfg):
             "sample_packing not compatible with xformers_attention. Use flash_attention"
         )
 
+    if cfg.early_stopping_patience:
+        if not cfg.save_steps or not cfg.eval_steps:
+            raise ValueError(
+                "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
+            )
+        if cfg.save_steps % cfg.eval_steps != 0:
+            raise ValueError(
+                "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
+            )
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
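The divisibility rule mirrors a transformers constraint: with load_best_model_at_end enabled, which early stopping depends on, save_steps must be a round multiple of eval_steps so that every checkpoint lines up with an evaluation. A self-contained restatement of the new check (a sketch using SimpleNamespace as a stand-in for axolotl's attribute-style cfg):

from types import SimpleNamespace

def check_early_stopping(cfg):
    # mirrors the validate_config addition above
    if cfg.early_stopping_patience:
        if not cfg.save_steps or not cfg.eval_steps:
            raise ValueError("early_stopping_patience requires save_steps and eval_steps")
        if cfg.save_steps % cfg.eval_steps != 0:
            raise ValueError("eval_steps should evenly divide save_steps")

check_early_stopping(SimpleNamespace(early_stopping_patience=3, save_steps=500, eval_steps=50))  # ok
# save_steps=500, eval_steps=33 would raise, since 500 % 33 != 0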
src/axolotl/utils/distributed.py CHANGED
@@ -2,6 +2,7 @@
 utility helpers for distributed checks
 """
 import os
+import pickle  # nosec
 from contextlib import contextmanager
 
 import torch
@@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
                 gathered_values.append(float(tensor.item()))
         return gathered_values
     return None
+
+
+def broadcast_dict(vals: dict):
+    if not is_distributed():
+        return vals
+
+    if is_main_process():
+        data_byte = pickle.dumps(vals)
+        data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
+        data_size = torch.IntTensor([len(data_byte)]).to("cuda")
+    else:
+        data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
+        data_size = torch.IntTensor([0]).to("cuda")
+
+    dist.broadcast(data_size, 0)
+    if not is_main_process():
+        # resize the receive buffer to fit the pickled payload
+        data_tensor = data_tensor.new_empty([data_size.item()])
+
+    dist.broadcast(data_tensor, 0)
+
+    if not is_main_process():
+        data_list = data_tensor.cpu().tolist()
+        data_byte = bytes(data_list[: data_size.item()])
+        vals = pickle.loads(data_byte)  # nosec
+
+    return vals
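A hedged usage sketch for broadcast_dict. It assumes an NCCL process group and one GPU per rank, matching the hard-coded "cuda" device above; launch with, for example, torchrun --nproc_per_node=2 demo.py:

import os

import torch
import torch.distributed as dist

from axolotl.utils.distributed import broadcast_dict, is_main_process

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# only rank 0 starts with real results; after the call every rank has a copy
metrics = {"mmlu_bench_accuracy": 0.42} if is_main_process() else {}
metrics = broadcast_dict(metrics)
print(f"rank {dist.get_rank()}: {metrics}")

Pickling into a ByteTensor is one way to ship an arbitrary dict between ranks; recent torch also provides dist.broadcast_object_list, which wraps the same serialize-then-broadcast idea.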
src/axolotl/utils/trainer.py CHANGED
@@ -576,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
     if cfg.bench_dataset:
         training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
+    if cfg.metric_for_best_model:
+        training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
+    if cfg.greater_is_better:
+        training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
 
     # DDP Config
     if cfg.ddp_timeout:
@@ -601,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         output_dir=cfg.output_dir,
         save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
         load_best_model_at_end=(
-            cfg.load_best_model_at_end is not False
+            (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
             and cfg.val_set_size > 0
             and cfg.save_steps
             and cfg.save_steps % cfg.eval_steps == 0
-            and cfg.load_in_8bit is not True
         )
         or False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
@@ -637,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.relora_steps:
         callbacks.append(ReLoRACallback(cfg))
 
-    # TODO on_save callback to sync checkpoints to GCP/AWS in background
-    if cfg.early_stopping_patience:
-        early_stop_cb = EarlyStoppingCallback(
-            cfg.early_stopping_patience,
-        )
-        callbacks.append(early_stop_cb)
-
     if cfg.local_rank == 0 and cfg.adapter in [
         "lora",
         "qlora",
@@ -710,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.do_bench_eval:
         trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
 
+    # TODO on_save callback to sync checkpoints to GCP/AWS in background
+    if cfg.early_stopping_patience:
+        early_stop_cb = EarlyStoppingCallback(
+            cfg.early_stopping_patience,
+        )
+        trainer.add_callback(early_stop_cb)
+
     return trainer
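For comparison, the equivalent wiring in plain transformers (illustrative values; axolotl derives all of these from cfg inside setup_trainer, as in the diff above). Per the transformers docs, EarlyStoppingCallback depends on load_best_model_at_end to track the best metric, which is why the patch folds early_stopping_patience into that flag:

from transformers import EarlyStoppingCallback, TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=500,               # a round multiple of eval_steps
    load_best_model_at_end=True,  # EarlyStoppingCallback depends on this
    metric_for_best_model="eval_loss",
    greater_is_better=False,      # lower eval_loss is better
)
early_stop = EarlyStoppingCallback(early_stopping_patience=3)
# the callback is then registered on the trainer, as setup_trainer now does
# after the benchmark-eval callback:
#   trainer.add_callback(early_stop)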