winglian committed on
Commit 87e073d
1 Parent(s): 4131183

fix lora target module, require explicit flash attention, fix min logging steps, don't use adam8bit for int4, hash prepared datasets, support hf hub datasets
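
Of these changes, the prepared-dataset caching is the one most easily summarized outside the diff: the prepared-dataset directory is now keyed on an md5 hash of the packed sequence length plus every configured dataset's path:type pair. Below is a minimal standalone sketch of that scheme, mirroring the scripts/finetune.py hunk further down; the helper name prepared_dataset_path and the printed example are illustrative, not part of the commit.

from hashlib import md5
from pathlib import Path

DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


def prepared_dataset_path(datasets, max_packed_sequence_len, dataset_prepared_path=None):
    # hash the packed sequence length plus each dataset's path:type pair,
    # sorted so the result is stable regardless of config ordering
    key = str(max_packed_sequence_len) + "@" + "|".join(
        sorted(f"{d['path']}:{d['type']}" for d in datasets)
    )
    ds_hash = md5(key.encode("utf-8")).hexdigest()
    base = dataset_prepared_path if dataset_prepared_path else DEFAULT_DATASET_PREPARED_PATH
    return Path(base) / ds_hash


# example using the values from configs/llama_7B_4bit.yml
print(prepared_dataset_path(
    [{"path": "vicgalle/alpaca-gpt4", "type": "alpaca"}],
    max_packed_sequence_len=1024,
    dataset_prepared_path="data/last_run_prepared",
))

If that hashed directory already contains files, finetune.py skips re-tokenizing and loads the prepared dataset from disk; otherwise it tokenizes, packs, and saves to the hashed path.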

configs/llama_65B_alpaca.yml CHANGED
@@ -21,7 +21,7 @@ lora_alpha: 16
 lora_dropout: 0.05
 lora_target_modules:
 - q_proj
-- w_proj
+- v_proj
 lora_fan_in_fan_out: false
 wandb_project: llama-65b-lora
 wandb_watch:
configs/llama_7B_4bit.yml ADDED
@@ -0,0 +1,41 @@
+base_model: decapoda-research/llama-7b-hf-int4
+base_model_config: decapoda-research/llama-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: true
+datasets:
+- path: vicgalle/alpaca-gpt4
+  type: alpaca
+dataset_prepared_path: data/last_run_prepared
+val_set_size: 0.04
+adapter: lora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 1024
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+- q_proj
+- v_proj
+# - k_proj
+# - o_proj
+lora_fan_in_fan_out: false
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./lora-test
+batch_size: 8
+micro_batch_size: 2
+num_epochs: 3
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+gradient_checkpointing: false
+early_stopping_patience: 3
+resume_from_checkpoint:
+local_rank:
+load_4bit: true
configs/llama_7B_alpaca.yml CHANGED
@@ -21,7 +21,7 @@ lora_alpha: 16
 lora_dropout: 0.05
 lora_target_modules:
 - q_proj
-- w_proj
+- v_proj
 lora_fan_in_fan_out: false
 wandb_project: llama-7b-lora
 wandb_watch:
scripts/finetune.py CHANGED
@@ -4,6 +4,7 @@ import os
 import random
 import signal
 import sys
+from hashlib import md5
 from pathlib import Path
 
 import bitsandbytes as bnb
@@ -13,6 +14,7 @@ import transformers
 import yaml
 from attrdict import AttrDefault
 from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
+from huggingface_hub.hf_api import DatasetInfo
 from torch import nn
 from transformers import (
     AutoModelForCausalLM,
@@ -20,6 +22,7 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     EarlyStoppingCallback,
+    GenerationConfig,
 )
 
 # add src to the pythonpath so we don't need to pip install this
@@ -43,7 +46,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
 
 def setup_wandb_env_vars(cfg):
-    if len(cfg.wandb_project) > 0:
+    if cfg.wandb_project and len(cfg.wandb_project) > 0:
         os.environ["WANDB_PROJECT"] = cfg.wandb_project
         cfg.use_wandb = True
     if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
@@ -61,7 +64,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
 
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
-    if "llama" in base_model:
+    if "llama" in base_model and cfg.flash_attention:
        if cfg.device not in ["mps", "cpu"] and inference is False:
            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
            replace_llama_attn_with_flash_attn()
@@ -138,11 +141,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
+
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-    if load_in_8bit:
+    if load_in_8bit and not cfg.load_4bit:
         model = prepare_model_for_int8_training(model)
 
     lora_config = LoraConfig(
@@ -227,14 +231,19 @@ def check_dataset_labels(dataset, tokenizer):
 
 
 def do_inference(cfg, model, tokenizer):
+    tokenizer.add_special_tokens({'unk_token': '<unk>'})
+    tokenizer.add_special_tokens({'bos_token': '<s>'})
+    tokenizer.add_special_tokens({'eos_token': '</s>'})
+
     instruction = "Tell me a joke about dromedaries."
     input = ""
     prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input)
-    batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
+    batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
     model.eval()
     with torch.no_grad():
-        generated = model.generate(inputs=batch["input_ids"],
+        # gc = GenerationConfig()  # TODO swap out and use this
+        generated = model.generate(inputs=batch["input_ids"].to("cuda"),
                                    do_sample=True, use_cache=True,
                                    repetition_penalty=1.1,
                                    max_new_tokens=100,
@@ -277,7 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
     warmup_steps = min(int(0.03 * total_num_steps), 100)
-    logging_steps = min(int(0.005 * total_num_steps), 10)
+    logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
@@ -325,21 +334,24 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         },
     ]
 
-    adam_bnb_optim = bnb.optim.Adam8bit(
-        optimizer_grouped_parameters,
-        betas=(training_args.adam_beta1, training_args.adam_beta2),
-        eps=training_args.adam_epsilon,
-        lr=training_args.learning_rate,
-    )
+    trainer_kwargs = {}
 
-    # TODO optionally use torch.optim.OneCycleLR
-    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
-        adam_bnb_optim,
-        training_args.warmup_steps,
-        total_num_steps,
-    )
+    if cfg.load_in_8bit and not cfg.load_4bit:
+        adam_bnb_optim = bnb.optim.Adam8bit(
+            optimizer_grouped_parameters,
+            betas=(training_args.adam_beta1, training_args.adam_beta2),
+            eps=training_args.adam_epsilon,
+            lr=training_args.learning_rate,
+        )
+
+        # TODO optionally use torch.optim.OneCycleLR
+        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
+            adam_bnb_optim,
+            training_args.warmup_steps,
+            total_num_steps,
+        )
+        trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
 
-    trainer_kwargs = {}
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
@@ -351,7 +363,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
-        optimizers=(adam_bnb_optim, lr_scheduler),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
@@ -412,7 +423,11 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
-    if cfg.dataset_prepared_path and any(Path(cfg.dataset_prepared_path).glob("*")):
+    max_packed_sequence_len = cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len)  # make sure we don't accidentally set it larger than sequence_len
+    ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
+    prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    if any(prepared_ds_path.glob("*")):
         logging.info("Loading prepared dataset from disk...")
         dataset = load_from_disk(cfg.dataset_prepared_path)
         logging.info("Prepared dataset loaded from disk...")
@@ -420,13 +435,20 @@
         logging.info("Loading raw datasets...")
         datasets = []
         for d in cfg.datasets:
+            ds_from_hub = False
+            try:
+                ds = load_dataset(d.path, streaming=True)
+                ds_from_hub = True
+            except FileNotFoundError:
+                pass
+
+            # prefer local dataset, even if hub exists
            if Path(d.path).exists():
                ds: IterableDataset = load_dataset(
                    "json", data_files=d.path, streaming=True, split=None
                )
-            # elif d.name and d.path:
-            #     # TODO load from huggingface hub, but it only seems to support arrow or parquet atm
-            #     ds = load_dataset(d.path, split=None, data_files=d.name)
+            elif ds_from_hub:
+                ds = load_dataset(d.path, streaming=True)
            else:
                raise Exception("unhandled dataset load")
 
@@ -449,7 +471,7 @@
                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                datasets.append(ds_wrapper)
        constant_len_dataset = ConstantLengthDataset(
-            tokenizer, datasets, seq_length=cfg.sequence_len
+            tokenizer, datasets, seq_length=max_packed_sequence_len,
        )
        logging.info("merging, packing, shuffling, and splitting master dataset")
        dataset = Dataset.from_list(
@@ -457,11 +479,8 @@
        ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
 
        if cfg.local_rank == 0:
-            logging.info("Saving prepared dataset to disk...")
-            if cfg.dataset_prepared_path:
-                dataset.save_to_disk(cfg.dataset_prepared_path)
-            else:
-                dataset.save_to_disk(DEFAULT_DATASET_PREPARED_PATH)
+            logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
+            dataset.save_to_disk(prepared_ds_path)
 
     if prepare_ds_only:
        logging.info("Finished preparing dataset. Exiting...")