winglian committed on
Commit 87d7825
1 Parent(s): eb80890

Tokenization open assistant (#1)


* refactor prompt tokenization to more easily support open assistant

* add open assistant handling, more logging, black formatting
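
The refactor below boils down to a new InstructionPromptTokenizingStrategy base class: subclasses only map their dataset's field names via parse_instruction_fields, while prompt building, tokenization, and input masking stay shared. As a minimal sketch of what that buys, another instruction-style dataset could be added on top of this commit like so; the QuestionAnswerPromptTokenizingStrategy class and its question/answer fields are hypothetical, not part of the change:

    from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy


    class QuestionAnswerPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
        # Hypothetical example: only the field mapping is dataset-specific.
        def parse_instruction_fields(self, prompt) -> (str, str, str):
            # The base class expects an (instruction, input, response) tuple.
            return (prompt["question"], "", prompt["answer"])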

Files changed (2)
  1. scripts/finetune.py +115 -39
  2. src/axolotl/prompt_tokenizers.py +34 -12
scripts/finetune.py CHANGED
@@ -37,6 +37,7 @@ from axolotl.prompt_tokenizers import (
     ShareGPTPromptTokenizingStrategy,
     LLAMA_DEFAULT_PAD_TOKEN,
     GPTeacherPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
 
@@ -56,7 +57,15 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
 
 
-def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
+def load_model(
+    base_model,
+    base_model_config,
+    model_type,
+    tokenizer_type,
+    cfg,
+    adapter="lora",
+    inference: bool = False,
+):
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
@@ -67,13 +76,17 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+
             logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
 
-    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
+    torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
     try:
         if cfg.load_4bit:
-            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
+            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
+                replace_peft_model_with_int4_lora_model,
+            )
+
             replace_peft_model_with_int4_lora_model()
 
             from peft import (
@@ -92,18 +105,26 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
             from huggingface_hub import snapshot_download
 
             cache_model_path = Path(snapshot_download(base_model))
-            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
+            files = (
+                list(cache_model_path.glob("*.pt"))
+                + list(cache_model_path.glob("*.safetensors"))
+                + list(cache_model_path.glob("*.bin"))
+            )
             if len(files) > 0:
                 model_path = str(files[0])
             else:
-                logging.warning("unable to find a cached model file, this will likely fail...")
+                logging.warning(
+                    "unable to find a cached model file, this will likely fail..."
+                )
                 model_path = str(cache_model_path)
             model, tokenizer = load_llama_model_4bit_low_ram(
                 base_model_config if base_model_config else base_model,
                 model_path,
                 device_map=cfg.device_map,
                 groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
-                is_v1_model=cfg.gptq_model_v1 if cfg.gptq_model_v1 is not None else True,
+                is_v1_model=cfg.gptq_model_v1
+                if cfg.gptq_model_v1 is not None
+                else True,
             )
             load_in_8bit = False
         elif is_llama_derived_model:
@@ -120,7 +141,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
-    except:
+    except Exception as e:
+        logging.error(
+            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
+        )
+        logging.exception(e)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
@@ -145,7 +170,6 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
-
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -165,7 +189,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     )
 
     if cfg.lora_model_dir:
-        model = PeftModel.from_pretrained(model, cfg.lora_model_dir, device_map = cfg.device_map, torch_dtype=torch.float16)
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
     else:
         model = get_peft_model(model, lora_config)
 
@@ -174,9 +203,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
 
     if cfg.load_4bit:
         # Scales to half
-        logging.info('Fitting 4bit scales and zeros to half')
+        logging.info("Fitting 4bit scales and zeros to half")
         for n, m in model.named_modules():
-            if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
+            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
+                type(m)
+            ):
                 if hasattr(m, "is_v1_model") and m.is_v1_model:
                     m.zeros = m.zeros.half()
                     m.scales = m.scales.half()
@@ -236,37 +267,44 @@ def check_dataset_labels(dataset, tokenizer):
 
 
 def do_inference(cfg, model, tokenizer):
-    tokenizer.add_special_tokens({'unk_token': '<unk>'})
-    tokenizer.add_special_tokens({'bos_token': '<s>'})
-    tokenizer.add_special_tokens({'eos_token': '</s>'})
+    tokenizer.add_special_tokens({"unk_token": "<unk>"})
+    tokenizer.add_special_tokens({"bos_token": "<s>"})
+    tokenizer.add_special_tokens({"eos_token": "</s>"})
 
     instruction = "Tell me a joke about dromedaries."
     input = ""
-    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input)
+    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(
+        instruction=instruction, input=input
+    )
     batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
     model.eval()
     with torch.no_grad():
         # gc = GenerationConfig() # TODO swap out and use this
-        generated = model.generate(inputs=batch["input_ids"].to("cuda"),
-                                   do_sample=True, use_cache=True,
-                                   repetition_penalty=1.1,
-                                   max_new_tokens=100,
-                                   temperature=0.9,
-                                   top_p=0.95,
-                                   top_k=40,
-                                   return_dict_in_generate=True,
-                                   output_attentions=False,
-                                   output_hidden_states=False,
-                                   output_scores=False)
-    print(tokenizer.decode(generated['sequences'].cpu().tolist()[0]))
+        generated = model.generate(
+            inputs=batch["input_ids"].to("cuda"),
+            do_sample=True,
+            use_cache=True,
+            repetition_penalty=1.1,
+            max_new_tokens=100,
+            temperature=0.9,
+            top_p=0.95,
+            top_k=40,
+            return_dict_in_generate=True,
+            output_attentions=False,
+            output_hidden_states=False,
+            output_scores=False,
+        )
+    print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 
 
 def choose_config(path: Path):
     yaml_files = [file for file in path.glob("*.yml")]
 
     if not yaml_files:
-        raise ValueError("No YAML config files found in the specified directory. Are you using a .yml extension?")
+        raise ValueError(
+            "No YAML config files found in the specified directory. Are you using a .yml extension?"
+        )
 
     print("Choose a YAML file:")
     for idx, file in enumerate(yaml_files):
@@ -376,6 +414,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     return trainer
 
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -420,7 +459,13 @@ def train(
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and lora_config...")
     model, tokenizer, lora_config = load_model(
-        cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
+        cfg.base_model,
+        cfg.base_model_config,
+        cfg.model_type,
+        cfg.tokenizer_type,
+        cfg,
+        adapter=cfg.adapter,
+        inference=("inference" in kwargs),
     )
 
     if "inference" in kwargs:
@@ -428,10 +473,26 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
-    max_packed_sequence_len = cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len)  # make sure we don't accidentally set it larger than sequence_len
-    ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
-    prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+    ds_hash = str(
+        md5(
+            (
+                str(max_packed_sequence_len)
+                + "@"
+                + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+            ).encode("utf-8")
+        ).hexdigest()
+    )
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    )
 
     if any(prepared_ds_path.glob("*")):
         logging.info("Loading prepared dataset from disk...")
@@ -464,9 +525,18 @@ def train(
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            elif d.type == "oasst":
+                ds_strategy = OpenAssistantPromptTokenizingStrategy(
+                    AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
             elif d.type == "gpteacher":
                 ds_strategy = GPTeacherPromptTokenizingStrategy(
-                    GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    GPTeacherPrompter(),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
@@ -476,13 +546,17 @@ def train(
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            else:
+                logging.error(f"unhandled prompt tokenization strategy: {d.type}")
         constant_len_dataset = ConstantLengthDataset(
-            tokenizer, datasets, seq_length=max_packed_sequence_len,
+            tokenizer,
+            datasets,
+            seq_length=max_packed_sequence_len,
        )
        logging.info("merging, packing, shuffling, and splitting master dataset")
-        dataset = Dataset.from_list(
-            [_ for _ in constant_len_dataset]
-        ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
+            test_size=cfg.val_set_size, shuffle=True, seed=42
+        )
 
         if cfg.local_rank == 0:
             logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
@@ -525,7 +599,9 @@ def train(
 
     if cfg.local_rank == 0:
         # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+        logging.info(
+            f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+        )
         model.save_pretrained(cfg.output_dir)

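For reference, the prepared-dataset cache key introduced above can be reproduced in isolation. This is a standalone sketch, not part of the commit; the dataset paths are placeholders and "last_run_prepared" stands in for DEFAULT_DATASET_PREPARED_PATH:

    from hashlib import md5
    from pathlib import Path

    # Example value; in train() this is min(cfg.max_packed_sequence_len, cfg.sequence_len).
    max_packed_sequence_len = 2048
    datasets = [
        ("data/alpaca_data.json", "alpaca"),    # placeholder path:type pairs
        ("data/oasst_export.jsonl", "oasst"),
    ]
    # Mirrors the ds_hash computation: the key mixes the packed sequence length
    # with the sorted "path:type" pairs, so changing either yields a new cache dir.
    ds_hash = md5(
        (
            str(max_packed_sequence_len)
            + "@"
            + "|".join(sorted([f"{path}:{dtype}" for path, dtype in datasets]))
        ).encode("utf-8")
    ).hexdigest()
    print(Path("last_run_prepared") / ds_hash)
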
src/axolotl/prompt_tokenizers.py CHANGED
@@ -31,14 +31,18 @@ class PromptTokenizingStrategy(abc.ABC):
         pass
 
 
-class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
+class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        raise NotImplementedError
+
     def tokenize_prompt(self, prompt):
-        full_prompt = self._tokenize_full_prompt(prompt)
+        instruction, input, response = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
             user_prompt = self.prompter.build_prompt(
-                prompt["instruction"],
-                prompt["input"] if "input" in prompt else "",
+                instruction,
+                input,
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -49,11 +53,11 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
 
         return tokenized_full_prompt
 
-    def _tokenize_full_prompt(self, prompt):
+    def _build_full_prompt(self, instruction, input, response):
         return self.prompter.build_prompt(
-            prompt["instruction"],
-            prompt["input"] if "input" in prompt else "",
-            prompt["output"],
+            instruction,
+            input,
+            response,
         )
 
     def _tokenize(self, prompt, add_eos_token=True):
@@ -76,11 +80,29 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         return result
 
 
-class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
-    def _tokenize_full_prompt(self, prompt):
-        return self.prompter.build_prompt(
+class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
             prompt["instruction"],
-            prompt["input"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+        )
+
+
+class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["INSTRUCTION"],
+            "",
+            prompt["RESPONSE"],
+        )
+
+
+class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
             prompt["response"],
         )