tmm1 committed
Commit 248bf90
1 Parent(s): 77085ea

ensure flash-attn fixes happen in both adapter/lora modes, and use torch_dtype

Files changed (1): src/axolotl/utils/models.py (+8 -8)
src/axolotl/utils/models.py CHANGED

@@ -331,6 +331,14 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    if cfg.flash_attention:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -407,14 +415,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()
 
     return model, peft_config
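
For context, a minimal standalone sketch of the block this commit moves into load_model. The helper name apply_flash_attn_dtype_fix is hypothetical, and torch_dtype is assumed to be resolved earlier in load_model from the run config (e.g. torch.bfloat16 or torch.float16) rather than hard-coded:

    import torch
    import torch.nn as nn

    def apply_flash_attn_dtype_fix(model: nn.Module, torch_dtype: torch.dtype) -> None:
        # Cast norm layers plus the embedding and output-head weights to the
        # configured training dtype so flash-attn kernels see consistent
        # dtypes. Mirrors the block moved into load_model by this commit,
        # where it runs before load_adapter and therefore covers both the
        # llama-adapter and LoRA code paths.
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(torch_dtype)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    module.to(torch_dtype)

Running the cast once in load_model, before load_adapter, means it no longer has to be duplicated in each adapter/LoRA load path, and using torch_dtype instead of a hard-coded torch.float16 keeps the cast consistent with bf16 runs.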