winglian committed
Commit 34c99f9
1 Parent(s): 259262b

fixes to make qlora actually work
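Three spots in the existing LoRA path are widened to cover the qlora adapter: the quantized-training preparation in load_model, the adapter dispatch in load_adapter, and the PEFT checkpoint callback in setup_trainer. The load_model guard additionally now requires that the model was actually loaded quantized (load_in_8bit or cfg.load_in_4bit) before calling prepare_model_for_int8_training.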

src/axolotl/utils/models.py CHANGED
@@ -248,7 +248,7 @@ def load_model(
 
     if (
         (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
-    ) and not cfg.load_4bit:
+    ) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -297,7 +297,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
         return model, None
-    if adapter == "lora" or adapter == "qlora":
+    if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
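For context, here is a minimal sketch of the qlora path that the patched guard now admits: load the base model quantized, run PEFT's prepare_model_for_int8_training, then wrap it with LoRA adapters. This is not axolotl's code; the base model name and LoRA hyperparameters are illustrative assumptions.

# Minimal qlora-style sketch, assuming peft, transformers, and bitsandbytes
# are installed. Base model and LoRA settings are illustrative, not axolotl's.
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # hypothetical base model
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
    device_map="auto",
)

# Mirrors the patched condition: only run the PEFT preparation step when the
# adapter is lora-on-8bit or qlora, and the model really was loaded quantized.
model = prepare_model_for_int8_training(model)

model = get_peft_model(
    model,
    LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    ),
)
model.print_trainable_parameters()  # only the LoRA weights should be trainable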
src/axolotl/utils/trainer.py CHANGED
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter == "lora":  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
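SavePeftModelCallback is referenced but not shown in this diff. A sketch of what such a callback typically does (an assumption, not axolotl's actual implementation): on each checkpoint save, write only the small PEFT adapter weights rather than the full quantized model.

# Hedged sketch of a save-adapter-only callback; axolotl's real
# SavePeftModelCallback may differ. Assumes the trainer holds a PEFT-wrapped model.
import os
from transformers import TrainerCallback

class SavePeftModelCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # Save just the LoRA/qlora adapter into the current checkpoint folder.
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(os.path.join(checkpoint_dir, "adapter_model"))
        return control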