winglian committed
Commit c67fb71
Parent: 25e037f

Peft deepspeed resume (#1227)


* import deepspeed integration

* monkeypatch peft adapter with deepspeed for resume from checkpoint

* fix patch

* fix patches attempt 2

* make sure to set lora_model_dir

* skip pylint for deepspeed.utils

* pick up upstream fix in transformers

* remove monkeypatch for deepspeed/peft fix

* no need to set the lora_model_dir on resume

* unset load_in_*bit when using quant config

* guard before del

* better handling of load_in* kwargs

requirements.txt CHANGED
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft @ git+https://github.com/huggingface/peft.git
-transformers==4.37.0
+transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
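
Note: with the pin switched from the 4.37.0 wheel to a specific transformers git revision, it can be worth confirming which build actually got installed. A minimal sketch of such a check (the ".dev0" detail is an assumption about how source installs of transformers report their version, not something this commit defines):

```python
# Quick, hypothetical sanity check that transformers was installed from the
# pinned git revision rather than a released wheel; not part of axolotl.
import transformers

print(transformers.__version__)
# Source installs of transformers typically carry a ".dev0" suffix,
# whereas the previously pinned release reports exactly "4.37.0".
```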
src/axolotl/cli/train.py CHANGED
@@ -6,8 +6,9 @@ from pathlib import Path
 from typing import Tuple
 
 import fire
-import transformers
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers.hf_argparser import HfArgumentParser
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -27,7 +28,7 @@ LOG = logging.getLogger("axolotl.cli.train")
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
-    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parser = HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
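
For context on the import change above, `HfArgumentParser` turns a dataclass into an argparse-based CLI, which is how `do_cli` picks up `TrainerCliArgs`. A minimal sketch of the same pattern, using a toy dataclass as a stand-in (the `ToyCliArgs` fields are invented for illustration and are not axolotl's real CLI arguments):

```python
# Sketch of the HfArgumentParser pattern used in do_cli above.
# ToyCliArgs is a hypothetical stand-in for TrainerCliArgs.
from dataclasses import dataclass, field

from transformers.hf_argparser import HfArgumentParser


@dataclass
class ToyCliArgs:
    debug: bool = field(default=False)
    merge_lora: bool = field(default=False)


parser = HfArgumentParser((ToyCliArgs,))
# return_remaining_strings=True collects unrecognized CLI args instead of
# raising, which is why do_cli can discard them with `_`.
parsed_args, remaining = parser.parse_args_into_dataclasses(
    return_remaining_strings=True
)
print(parsed_args, remaining)
```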
src/axolotl/train.py CHANGED
@@ -57,6 +57,21 @@ def train(
     eval_dataset = dataset_meta.eval_dataset
     total_num_steps = dataset_meta.total_num_steps
 
+    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
+        if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
+            )
+            cfg.resume_from_checkpoint = sorted_paths[-1]
+            LOG.info(
+                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            )
+    resume_from_checkpoint = cfg.resume_from_checkpoint
+
     # Load the model and tokenizer
     msg = "loading model"
     if cfg.adapter:
@@ -79,21 +94,6 @@ def train(
 
     safe_serialization = cfg.save_safetensors is True
 
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-
     if cfg.unfrozen_parameters:
         freeze_parameters_except(model, cfg.unfrozen_parameters)
 
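
The auto-resume block itself is unchanged, just moved ahead of model loading, presumably so the resolved checkpoint path is already known when the model and PEFT adapter are loaded under DeepSpeed. Its selection logic in isolation, as a stand-alone sketch (the `outputs/my-run` path is a made-up example):

```python
# Stand-alone sketch of the auto-resume selection: pick the checkpoint-*
# directory with the highest step suffix. "outputs/my-run" is hypothetical.
from pathlib import Path

output_dir = Path("outputs/my-run")
possible_checkpoints = [str(cp) for cp in output_dir.glob("checkpoint-*")]
if possible_checkpoints:
    # directories are named "checkpoint-<global_step>", so compare numerically
    latest = max(possible_checkpoints, key=lambda path: int(path.split("-")[-1]))
    print(f"would resume from {latest}")
else:
    print("no checkpoints found; starting fresh")
```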
 
src/axolotl/utils/models.py CHANGED
@@ -473,6 +473,18 @@ def load_model(
             **bnb_config,
         )
 
+    if cfg.load_in_8bit and cfg.adapter is not None:
+        model_kwargs["load_in_8bit"] = True
+    if cfg.load_in_4bit and cfg.adapter is not None:
+        model_kwargs["load_in_4bit"] = True
+
+    # no longer needed per https://github.com/huggingface/transformers/pull/26610
+    if "quantization_config" in model_kwargs or cfg.gptq:
+        if "load_in_8bit" in model_kwargs:
+            del model_kwargs["load_in_8bit"]
+        if "load_in_4bit" in model_kwargs:
+            del model_kwargs["load_in_4bit"]
+
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
         if not cfg.sample_packing:
@@ -506,8 +518,6 @@ def load_model(
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
 
@@ -575,8 +585,6 @@ def load_model(
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -608,8 +616,6 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -678,7 +684,9 @@ def load_model(
     skip_prepare_model_for_kbit_training = False
 
     if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
-        from deepspeed.utils import set_z3_leaf_modules
+        from deepspeed.utils import (  # pylint: disable=no-name-in-module
+            set_z3_leaf_modules,
+        )
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
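
The kwarg handling added to `load_model` boils down to two rules: only set `load_in_8bit`/`load_in_4bit` when an adapter is configured, and strip them again whenever a `quantization_config` (or GPTQ) already describes the quantization, since transformers (per the PR referenced in the diff) folds the bare flags into a quantization config and errors if both are supplied. A condensed sketch of that resolution with a hypothetical `cfg` and `model_kwargs`:

```python
# Condensed sketch of the load_in_* bookkeeping above. SimpleNamespace is a
# stand-in for axolotl's real cfg object; the values are made up.
from types import SimpleNamespace

from transformers import BitsAndBytesConfig

cfg = SimpleNamespace(load_in_8bit=False, load_in_4bit=True, adapter="qlora", gptq=False)
model_kwargs = {"quantization_config": BitsAndBytesConfig(load_in_4bit=True)}

if cfg.load_in_8bit and cfg.adapter is not None:
    model_kwargs["load_in_8bit"] = True
if cfg.load_in_4bit and cfg.adapter is not None:
    model_kwargs["load_in_4bit"] = True

# quantization_config already carries this information; dict.pop with a
# default is an equivalent, slightly terser form of the guard-before-del
# used in the diff.
if "quantization_config" in model_kwargs or cfg.gptq:
    model_kwargs.pop("load_in_8bit", None)
    model_kwargs.pop("load_in_4bit", None)

print(sorted(model_kwargs))  # only "quantization_config" survives here
```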