winglian committed
Commit bbfc333 (parents: 3815c05, a5bf838)

Merge pull request #62 from OpenAccess-AI-Collective/qlora-fixes

README.md CHANGED
@@ -24,7 +24,7 @@
 
 ## Quickstart ⚡
 
-**Requirements**: Python 3.9.
+**Requirements**: Python 3.9.
 
 ```bash
 git clone https://github.com/OpenAccess-AI-Collective/axolotl
@@ -45,7 +45,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
 
 ### Environment
 
-- Docker
+- Docker
 ```bash
 docker run --gpus '"all"' --rm -it winglian/axolotl:main
 ```
@@ -334,7 +334,7 @@ strict:
 
 ### Accelerate
 
-Configure accelerate
+Configure accelerate
 
 ```bash
 accelerate config
@@ -368,7 +368,7 @@ Pass the appropriate flag to the train command:
 Add below flag to train command above
 
 ```bash
---merge_lora --lora_model_dir="./completed-model"
+--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```
 
 ## Common Errors 🧰
@@ -389,7 +389,7 @@ Try set `fp16: true`
 Try to turn off xformers.
 
 ## Need help? 🙋‍♂️
-
+
 Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
 
 ## Contributing 🤝
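
The substantive README change is to the merge command: `--load_in_8bit=False --load_in_4bit=False` are now passed explicitly, since LoRA/QLoRA weights can only be merged into a base model that is not loaded in 8-bit or 4-bit (the same rule the new `validation.py` checks below enforce). As a rough, hypothetical illustration of what those flags amount to outside the training script, the model name and directories in this sketch are placeholders, not values from this commit:

```python
# Hedged sketch: merge a LoRA adapter into an unquantized fp16 base model.
# Model name and directories are placeholders.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b",  # assumed base model
    torch_dtype=torch.float16,        # no load_in_8bit / load_in_4bit
)
lora = PeftModel.from_pretrained(base, "./completed-model")  # adapter dir from the README
merged = lora.merge_and_unload()      # fold LoRA weights into the base weights
merged.save_pretrained("./merged-model")
```
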
scripts/finetune.py CHANGED
@@ -176,6 +176,7 @@ def train(
     if "merge_lora" in kwargs and cfg.adapter is not None:
         logging.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
+        model.to(dtype=torch.float16)
 
         if cfg.local_rank == 0:
             logging.info("saving merged model")
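
The only change here casts the merged model to half precision before it is saved. A tiny self-contained illustration, not from the repo, of what `.to(dtype=torch.float16)` does to the parameters:

```python
# In-place dtype cast: every floating-point parameter ends up float16,
# so the saved merged checkpoint is half the size of a float32 one.
import torch
from torch import nn

m = nn.Linear(4, 4)                    # stand-in for the merged model
print(next(m.parameters()).dtype)      # torch.float32
m.to(dtype=torch.float16)
print(next(m.parameters()).dtype)      # torch.float16
```
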
src/axolotl/utils/data.py CHANGED
@@ -1,6 +1,7 @@
 import logging
 from hashlib import md5
 from pathlib import Path
+from typing import Union
 
 from datasets import (
     load_from_disk,
@@ -80,7 +81,7 @@ def load_tokenized_prepared_datasets(
     logging.info("Loading raw datasets...")
     datasets = []
     for d in cfg.datasets:
-        ds = None
+        ds: Union[Dataset, DatasetDict] = None
         ds_from_hub = False
         try:
             load_dataset(d.path, streaming=True, use_auth_token=True)
@@ -90,36 +91,32 @@ def load_tokenized_prepared_datasets(
 
         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
-            ds: IterableDataset = load_dataset(
+            ds: Dataset = load_dataset(
                 "json", data_files=d.path, streaming=False, split=None
             )
         elif ds_from_hub:
             if d.data_files:
-                ds = load_dataset(
+                ds: Dataset = load_dataset(
                     d.path,
                     streaming=False,
                     data_files=d.data_files,
                     use_auth_token=True,
                 )
             else:
-                ds = load_dataset(d.path, streaming=False, use_auth_token=True)
+                ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True)
         else:
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
            )
-            ds = load_dataset("json", data_files=fp, streaming=False, split=None)
+            ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
-<<<<<<< Updated upstream
-            ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
-=======
             if "train" in ds:
-                ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
+                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
             else:
-                ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
+                ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
->>>>>>> Stashed changes
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
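
Besides removing merge-conflict markers that had been committed into the sharding logic (and reverting `cfg.shards` to the per-dataset `d.shards`), this file gains `Union[Dataset, DatasetDict]` annotations. The `"train" in ds` branch exists because `load_dataset(..., split=None)` returns a `DatasetDict` keyed by split rather than a plain `Dataset`. A small self-contained sketch of the two sharding paths, using toy data rather than anything from the repo:

```python
# Toy example: shard a DatasetDict by first selecting its "train" split,
# or shard a plain Dataset directly.
from datasets import Dataset, DatasetDict

rows = Dataset.from_dict({"text": [f"row {i}" for i in range(8)]})
ds = DatasetDict({"train": rows})  # what split=None typically yields

if "train" in ds:
    subset = ds.shuffle(seed=42)["train"].shard(num_shards=4, index=0)
else:
    subset = ds.shuffle(seed=42).shard(num_shards=4, index=0)

print(len(subset))  # 2 of the 8 rows (one shard of four)
```
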
src/axolotl/utils/models.py CHANGED
@@ -85,7 +85,7 @@ def load_model(
         raise e
 
     model_kwargs = {}
-    if cfg.adapter == "qlora":
+    if cfg.adapter == "qlora" and cfg.load_in_4bit:
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
@@ -247,8 +247,10 @@ def load_model(
     model.resize_token_embeddings(embeddings_len)
 
     if (
-        (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
-    ) and not cfg.load_4bit:
+        ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
+        and not cfg.load_4bit
+        and (load_in_8bit or cfg.load_in_4bit)
+    ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
        return model, None
-    if adapter == "lora" or adapter == "qlora":
+    if adapter in ["lora", "qlora"]:
        return load_lora(model, cfg)
    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg)
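
The quantization config is now built only when the run both uses the `qlora` adapter and sets `load_in_4bit`, and `prepare_model_for_int8_training` is skipped unless 8-bit or 4-bit loading is actually active, so a merge pass gets a plain unquantized model. The hunk truncates the `BitsAndBytesConfig`; for orientation, a typical QLoRA-style 4-bit load looks roughly like the sketch below, where every field other than the two visible in the diff is a common default rather than a value taken from this repo:

```python
# Hedged sketch of a 4-bit (QLoRA-style) model load with bitsandbytes.
# Model name is a placeholder; extra config fields are common defaults.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,            # shown in the diff
    llm_int8_threshold=6.0,       # shown in the diff
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b",
    quantization_config=bnb_config,
    device_map="auto",
)
```
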
src/axolotl/utils/trainer.py CHANGED
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter == "lora":  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
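
The PEFT-saving callback is now registered for QLoRA runs as well, not just LoRA. For readers unfamiliar with the pattern, here is a hypothetical callback in the same spirit; it is not the repo's `SavePeftModelCallback`, only an illustration of the `TrainerCallback` hook such a callback relies on:

```python
# Hypothetical adapter-saving callback; axolotl's SavePeftModelCallback
# differs in detail, this only illustrates the on_save hook.
import os
from transformers import TrainerCallback

class SaveAdapterCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # write just the LoRA/QLoRA adapter weights next to the full checkpoint
        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(os.path.join(ckpt_dir, "adapter"))
        return control
```
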
src/axolotl/utils/validation.py CHANGED
@@ -1,9 +1,20 @@
+import logging
+
+
 def validate_config(cfg):
     if cfg.adapter == "qlora":
-        assert cfg.load_in_8bit is False
-        assert cfg.load_4bit is False
-        assert cfg.load_in_4bit is True
-        pass
+        if cfg.merge_lora:
+            # can't merge qlora if loaded in 8bit or 4bit
+            assert cfg.load_in_8bit is False
+            assert cfg.load_4bit is False
+            assert cfg.load_in_4bit is False
+        else:
+            assert cfg.load_in_8bit is False
+            assert cfg.load_4bit is False
+            assert cfg.load_in_4bit is True
+    if cfg.load_in_8bit and cfg.adapter == "lora":
+        logging.warning("we recommend setting `load_in_8bit: true`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
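
Net effect of the new checks: a regular QLoRA run must set `load_in_4bit: true` (with `load_in_8bit` and the older `load_4bit` flag off), while a QLoRA merge run must have all quantized loading disabled. A quick sketch of configs that satisfy both branches; axolotl's real config is an attribute-style dict built from YAML, approximated here with `SimpleNamespace`:

```python
# Sketch only: SimpleNamespace stands in for axolotl's YAML-backed config
# object so validate_config's attribute accesses resolve.
from types import SimpleNamespace
from axolotl.utils.validation import validate_config

qlora_train = SimpleNamespace(
    adapter="qlora", merge_lora=False,
    load_in_8bit=False, load_4bit=False, load_in_4bit=True,
)
qlora_merge = SimpleNamespace(
    adapter="qlora", merge_lora=True,
    load_in_8bit=False, load_4bit=False, load_in_4bit=False,
)

validate_config(qlora_train)  # passes: training requires 4-bit loading
validate_config(qlora_merge)  # passes: merging requires no quantized loading
```
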