tmm1 committed
Commit f144e98
2 Parent(s): 3513071 3a011ea

Merge pull request #485 from maximegmd/patch-4

fix: finetune model inference needs the dtype fix to work with flash-attn

Files changed (1)
  1. src/axolotl/utils/models.py +12 -9
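
Before this patch, the fp16/bf16 cast for flash-attn lived inside the k-bit (8-bit LoRA / 4-bit QLoRA) preparation branch, so it only ran when prepare_model_for_kbit_training did; loading a finetuned model for inference without k-bit quantization skipped it, and flash-attn then saw fp32 norm/embedding weights. The patch gates the cast on a new needs_fa2_dtype flag that is set whenever an adapter is configured, and still forced on after k-bit preparation. A minimal stand-alone sketch of that gating follows; "cfg" is a plain stand-in for axolotl's config object and both helpers are hypothetical, not axolotl API (the real logic is inline in load_model, shown in the diff below).

```python
# Hedged, stand-alone restatement of the gating this patch adds; not axolotl API.
# Attribute names (adapter, flash_attention, is_llama_derived_model) match the
# ones used in the diff below.
import torch


def needs_fa2_dtype_cast(cfg, kbit_prepared: bool) -> bool:
    """True when norm/lm_head/embed_tokens should be cast back to fp16/bf16."""
    needs_fa2_dtype = cfg.adapter is not None or kbit_prepared
    return bool(needs_fa2_dtype and cfg.flash_attention and cfg.is_llama_derived_model)


def cast_for_flash_attn(model: torch.nn.Module, torch_dtype: torch.dtype) -> None:
    """Mirror of the cast loop in load_model(): move norm and embedding-tied
    modules to the configured fp16/bf16 dtype so flash-attn kernels accept them."""
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch_dtype)
```

In load_model itself the flag starts as `cfg.adapter is not None` and is set to True right after `prepare_model_for_kbit_training`, which is exactly what the diff shows.
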
src/axolotl/utils/models.py CHANGED
@@ -355,6 +355,7 @@ def load_model(
         if hasattr(module, "weight"):
             module.to(torch.float32)
 
+    needs_fa2_dtype = cfg.adapter is not None
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -363,16 +364,18 @@ def load_model(
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
-
-        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-        # convert them back to fp16/bf16 for flash-attn compatibility.
-        if cfg.flash_attention and cfg.is_llama_derived_model:
-            for name, module in model.named_modules():
-                if "norm" in name:
+        needs_fa2_dtype = True
+
+    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
-                if "lm_head" in name or "embed_tokens" in name:
-                    if hasattr(module, "weight"):
-                        module.to(cfg.torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
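
To sanity-check the fix on a loaded model, a small dtype audit is enough. The helper below is a hypothetical sketch using only public PyTorch APIs (`model.named_modules()`, `Parameter.dtype`); it is not part of axolotl.

```python
import torch


def report_fp32_modules(model: torch.nn.Module) -> None:
    # Hypothetical helper: list modules whose weights are still float32.
    # With flash attention enabled on a llama-derived model, norm, lm_head and
    # embed_tokens entries should no longer appear here after this change.
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if weight is not None and weight.dtype == torch.float32:
            print(f"{name}: {weight.dtype}")
```
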