winglian committed
Commit 7657632
1 Parent(s): 548787d

add eval benchmark callback (#441)


* add mmlu callback

* use hf dataset for mmlu evals

* default to mmlu-zs

* make sure to define all the explicit positional args

* include metrics in callback

* another callback fix for collator max len attribute

* fix mmlu evals

* sample benchmarks, ensure we drop long samples

* fix the data file

* fix elif and add better messaging

* more fixes

* rename mmlu to bench

* more fixes

* dataset handling and aggregate across benchmark

* better handling when no subjects

* benchmark callback has its own dataloader and collator

* fixes

* updated dataset

* more fixes

* missing transformers import

* improve support for customized dataset for bench evals

* gather benchmarks from all ranks

* fix for gather across multiple gpus

requirements.txt CHANGED
@@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git
 bitsandbytes>=0.41.1
 accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
 addict
+evaluate
 fire
 PyYAML>=6.0
 datasets
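
The new `evaluate` dependency supplies the accuracy metric that the benchmark callback loads via `evaluate.load("accuracy")`. As a quick orientation, here is a minimal sketch of that API (the inputs are made-up answer indices, not real benchmark output):

import evaluate

accuracy = evaluate.load("accuracy")
# four reference answer indices vs. four predictions, three of them correct
result = accuracy.compute(references=[0, 1, 2, 3], predictions=[0, 1, 2, 0])
print(result["accuracy"])  # 0.75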
src/axolotl/utils/callbacks.py CHANGED
@@ -1,9 +1,19 @@
 """Callbacks for Trainer class"""
 
+from __future__ import annotations
+
 import logging
 import os
+from typing import TYPE_CHECKING, Dict, List
 
+import evaluate
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
+from tqdm import tqdm
 from transformers import (
     TrainerCallback,
     TrainerControl,
@@ -13,8 +23,19 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.distributed import (
+    barrier,
+    gather_scalar_from_all_ranks,
+    get_world_size,
+    is_main_process,
+    zero_first,
+)
+
+if TYPE_CHECKING:
+    from axolotl.utils.trainer import AxolotlTrainingArguments
 
 LOG = logging.getLogger("axolotl.callbacks")
+IGNORE_INDEX = -100
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -96,3 +117,192 @@ class GPUStatsCallback(
         log_gpu_memory_usage(LOG, "while training", self.cfg.device)
         self.logged = True
         return control
+
+
+def bench_eval_callback_factory(trainer, tokenizer):
+    accuracy = evaluate.load("accuracy")
+    abcd_idx = [
+        tokenizer("A", add_special_tokens=False).input_ids[0],
+        tokenizer("B", add_special_tokens=False).input_ids[0],
+        tokenizer("C", add_special_tokens=False).input_ids[0],
+        tokenizer("D", add_special_tokens=False).input_ids[0],
+        tokenizer("E", add_special_tokens=False).input_ids[0],
+        tokenizer("F", add_special_tokens=False).input_ids[0],
+        tokenizer("G", add_special_tokens=False).input_ids[0],
+    ]
+    bench_split = "eval"
+
+    def transform_bench_subject(example):
+        # Split on ':' and trim whitespace
+        parts = example["subject"].split(":")
+        first_part = (
+            parts[0].strip().lower().replace("-", "_")
+        )  # Lowercase the first part
+        second_part = (
+            parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
+        )  # Replace hyphens with underscores
+
+        # Return the transformed values
+        return {"name": first_part, "subject": second_part}
+
+    if trainer.args.bench_dataset == "mmlu-zs":
+        bench_dataset = load_dataset(
+            "openaccess-ai-collective/mmlu-evals",
+            data_files={
+                "eval": "zero_shot_mmlu_val.json",
+                "test": "zero_shot_mmlu_test.json",
+            },
+        )
+        # bench_dataset = bench_dataset.remove_columns("subject")
+    # MMLU Five-shot (Eval/Test only)
+    elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
+        bench_dataset = load_dataset(
+            "openaccess-ai-collective/mmlu-evals",
+            data_files={
+                "eval": "five_shot_mmlu_val.json",
+                "test": "five_shot_mmlu_test.json",
+            },
+        )
+        # bench_dataset = bench_dataset.remove_columns('subject')
+    elif "/" in trainer.args.bench_dataset:
+        bench_ds = trainer.args.bench_dataset
+        bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
+        bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
+        bench_dataset = load_dataset(
+            bench_ds_name,
+            data_files={
+                "eval": bench_ds_data_file,
+            },
+        )
+        bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
+    else:
+        raise ValueError(
+            f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
+        )
+    bench_dataset = bench_dataset[trainer.args.bench_split]
+    if trainer.args.max_bench_samples is not None:
+        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
+
+    def tokenize_evals(example):
+        source = f"{tokenizer.bos_token}{example['input']}"
+        target = f"{example['output']}{tokenizer.eos_token}"
+
+        tokenized_source = tokenizer(
+            source,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        tokenized_target = tokenizer(
+            target,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
+        labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
+            "input_ids"
+        ]
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "subject": example["subject"],
+        }
+
+    with zero_first(is_main_process()):
+        bench_dataset = bench_dataset.map(tokenize_evals)
+        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
+
+    class BenchEvalCallback(TrainerCallback):
+        """
+        TrainerCallback that runs the MMLU evals
+        """
+
+        def on_evaluate(
+            self,
+            args: AxolotlTrainingArguments,
+            state: TrainerState,  # pylint: disable=unused-argument
+            control: TrainerControl,  # pylint: disable=unused-argument
+            metrics: Dict[str, float],  # pylint: disable=unused-argument
+            **kwargs,  # pylint: disable=unused-argument
+        ):
+            data_loader = trainer.get_bench_dataloader(
+                bench_dataset.remove_columns(["input", "subject", "output", "name"])
+            )
+            trainer.model.eval()
+            preds, refs = [], []
+            loss_bench = 0
+            for batch in tqdm(data_loader, total=len(data_loader)):
+                (loss, logits, labels) = trainer.prediction_step(
+                    trainer.model,
+                    batch,
+                    prediction_loss_only=False,
+                )
+                # There are two tokens, the output, and eos token.
+                for i, logit in enumerate(logits):
+                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
+                        0
+                    ][0]
+                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
+                    preds.append(torch.argmax(logit_abcd).item())
+                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
+                refs += [
+                    abcd_idx.index(label) if label in abcd_idx else -1
+                    for label in labels.tolist()
+                ]
+                loss_bench += loss.item()
+            # Extract results by subject.
+            bench_name = bench_dataset["name"]
+            bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
+            for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
+                bench_names[s]["preds"].append(p)
+                bench_names[s]["refs"].append(r)
+            barrier()
+            local_bench_names = bench_names
+            gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
+            # Gather results from all GPUs to GPU 0
+
+            loss_bench_ranks = gather_scalar_from_all_ranks(
+                lambda: loss_bench, get_world_size()
+            )
+            len_data_loader_ranks = gather_scalar_from_all_ranks(
+                lambda: len(data_loader), get_world_size()
+            )
+
+            if not is_main_process():
+                dist.gather_object(local_bench_names, dst=0)
+            else:
+                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
+                results = {"bench_loss": bench_loss}
+
+                # Combine results from all GPUs
+                combined_bench_names: Dict[str, Dict[str, List]] = {}
+                for bench_name in gathered_bench_names:
+                    for name, data in bench_name.items():
+                        if name not in combined_bench_names:
+                            combined_bench_names[name] = {"refs": [], "preds": []}
+                        combined_bench_names[name]["refs"].extend(data["refs"])
+                        combined_bench_names[name]["preds"].extend(data["preds"])
+
+                bench_scores = []
+                for (
+                    bench_name
+                ) in combined_bench_names:  # pylint: disable=consider-using-dict-items
+                    bench_score = accuracy.compute(
+                        references=combined_bench_names[bench_name]["refs"],
+                        predictions=combined_bench_names[bench_name]["preds"],
+                    )["accuracy"]
+                    if not pd.isna(bench_score):
+                        results[
+                            f"bench_{bench_split}_accuracy_{bench_name}"
+                        ] = bench_score
+                        bench_scores.append(bench_score)
+                    else:
+                        results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
+                        bench_scores.append(0.0)
+                results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
+                trainer.log(results)
+
+    return BenchEvalCallback
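
A note on the scoring logic above: `tokenize_evals` masks every prompt token to `IGNORE_INDEX`, so the only unmasked labels in each example are the answer letter followed by the eos token. That is why samples are filtered on `x["labels"][-2]`, why predictions read the logits one position before the first unmasked label, and why `labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]` recovers exactly one reference token per row. A self-contained sketch of that reference extraction, using hypothetical token ids:

import torch

IGNORE_INDEX = -100
abcd_idx = [319, 350, 315, 360]  # hypothetical token ids for "A".."D"
eos_id = 2  # hypothetical eos token id

# two collated rows: prompt positions are masked, leaving (answer, eos) pairs
labels = torch.tensor(
    [
        [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 350, eos_id],  # answer "B"
        [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 315, eos_id],  # answer "C"
    ]
)

# drop masked positions, regroup as (answer, eos) pairs, keep the answer column
answer_tokens = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
refs = [abcd_idx.index(t) if t in abcd_idx else -1 for t in answer_tokens.tolist()]
print(refs)  # [1, 2]

This also explains the `filter` on `labels[-2]`: any sample whose answer token is not one of the A-G ids (for example, because truncation clipped it) would break the two-tokens-per-row assumption, so it is dropped up front.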
src/axolotl/utils/distributed.py CHANGED
@@ -1,8 +1,10 @@
 """
 utility helpers for distributed checks
 """
+import os
 from contextlib import contextmanager
 
+import torch
 import torch.distributed as dist
 from accelerate import Accelerator
 
@@ -43,6 +45,10 @@ def is_main_process():
     return dist.get_rank() == 0
 
 
+def get_world_size():
+    return int(os.getenv("WORLD_SIZE", "1"))
+
+
 @contextmanager
 def zero_first(is_main):
     """
@@ -53,3 +59,35 @@ def zero_first(is_main):
     yield
     if is_main:  # then rank 0 waits after it has run the context
         barrier()
+
+
+def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
+    """
+    Run a callable 'fn' on all ranks and gather the results on the specified rank.
+
+    Args:
+    - fn (callable): A function that computes the value. This should not have any side effects.
+    - rank (int, optional): The rank that gathers the values. Default is 0.
+    - world_size (int, optional): Total number of processes in the current distributed setup.
+
+    Returns:
+    - A list of computed values from all ranks if on the gathering rank, otherwise None.
+    """
+    value_scalar = fn()
+    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+
+    if not is_main_process():
+        dist.gather(value_tensor, dst=0)
+    else:
+        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+        dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
+
+        # Convert tensors back to their original type (int or float)
+        gathered_values = []
+        for tensor in gathered_tensors:
+            if tensor == tensor.int():
+                gathered_values.append(int(tensor.item()))
+            else:
+                gathered_values.append(float(tensor.item()))
+        return gathered_values
+    return None
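
`gather_scalar_from_all_ranks` is what lets the callback average `bench_loss` over every GPU. A minimal usage sketch, assuming a process group launched via `torchrun` with one CUDA device per rank (the helper places its tensor on `device=dist.get_rank()`, which equates rank and device index):

import torch.distributed as dist

from axolotl.utils.distributed import (
    gather_scalar_from_all_ranks,
    get_world_size,
    is_main_process,
)

dist.init_process_group("nccl")
# each rank computes its own scalar; only rank 0 receives the gathered list,
# non-main ranks get None back
losses = gather_scalar_from_all_ranks(lambda: 1.5 + dist.get_rank(), get_world_size())
if is_main_process():
    print(losses)  # e.g. [1.5, 2.5] on a two-GPU run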
src/axolotl/utils/trainer.py CHANGED
@@ -12,9 +12,15 @@ from typing import Optional, Union
 
 import numpy as np
 import torch.cuda
+import transformers
 from datasets import Dataset, set_caching_enabled
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import (
+    DataLoader,
+    DistributedSampler,
+    RandomSampler,
+    SequentialSampler,
+)
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import SequentialDistributedSampler
 
@@ -23,6 +29,7 @@ from axolotl.utils.callbacks import (
     GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
+    bench_eval_callback_factory,
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -127,6 +134,27 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
    )
+    bench_split: Optional[str] = field(
+        default="eval", metadata={"help": "The benchmark split to run on"}
+    )
+    bench_dataset: Optional[str] = field(
+        default="pharaouk/dharma-1/dharma_1_mini.json",
+        metadata={
+            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
+        },
+    )
+    do_bench_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
+    )
+    max_bench_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
+        },
+    )
+    bench_source_max_len: int = field(
+        default=2048, metadata={"help": "Maximum source sequence length for bench."}
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -136,6 +164,10 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
+    def __init__(self, *args, bench_data_collator=None, **kwargs):
+        self.bench_data_collator = bench_data_collator
+        super().__init__(*args, **kwargs)
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
@@ -226,6 +258,31 @@ class AxolotlTrainer(Trainer):
         )
         return super().get_eval_dataloader(eval_dataset)
 
+    def _get_bench_sampler(
+        self, bench_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size <= 1:
+            return SequentialSampler(bench_dataset)
+        return None
+
+    def get_bench_dataloader(
+        self,
+        bench_dataset: Dataset,
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": self.bench_data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        return DataLoader(bench_dataset, **dataloader_params)
+        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
+
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
@@ -517,6 +574,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         "steps" if cfg.save_steps else "epoch"
     )
 
+    if cfg.do_bench_eval:
+        training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
+        if cfg.bench_dataset:
+            training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
@@ -629,8 +691,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             return_tensors="pt",
             **data_collator_kwargs,
         ),
+        bench_data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
+        ),
         callbacks=callbacks,
         **trainer_kwargs,
     )
 
+    if cfg.do_bench_eval:
+        trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
+
     return trainer
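
The default `bench_dataset` value above exercises the custom-dataset branch in `bench_eval_callback_factory`: everything up to the second `/` names the Hugging Face dataset repo and the remainder is the data file inside it. The split is plain string handling:

bench_ds = "pharaouk/dharma-1/dharma_1_mini.json"  # the default bench_dataset value
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
print(bench_ds_name)       # pharaouk/dharma-1
print(bench_ds_data_file)  # dharma_1_mini.json

End to end, the feature stays opt-in: `setup_trainer` only forwards `do_bench_eval` and `bench_dataset` into `AxolotlTrainingArguments` and registers the callback when `cfg.do_bench_eval` is set.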