winglian committed on
Commit 77085ea
1 Parent(s): db2a358

qlora w flash attention fixes (#333)

Files changed (1)
  1. src/axolotl/utils/models.py +8 -0
src/axolotl/utils/models.py CHANGED
@@ -407,6 +407,14 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
+    if cfg.flash_attention:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch.float16)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch.float16)
+
     model.print_trainable_parameters()
 
     return model, peft_config
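
A minimal sketch of what the added block does, not part of the commit itself: flash-attention kernels only accept fp16/bf16 tensors, and with a 4-bit (QLoRA) base model the norm layers and the embed_tokens / lm_head weights typically remain in fp32 after PEFT preparation, so the patch casts them down after get_peft_model. The helper name and usage below are assumptions for illustration; the loop mirrors the diff above.

# Sketch only: names here are hypothetical, the loop body mirrors the patch.
import torch


def cast_for_flash_attention(model):
    """Hypothetical helper mirroring the loop added to load_llama_adapter."""
    for name, module in model.named_modules():
        if "norm" in name:  # RMSNorm / LayerNorm modules
            module.to(torch.float16)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):  # skip wrapper modules without parameters
                module.to(torch.float16)
    return model


# Usage sketch: after wrapping the 4-bit base model with get_peft_model(...)
# and before training, e.g.
#   model = cast_for_flash_attention(model)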