Maxime committed
Commit f311df9 (parent: c500d02)

fix: finetune model inference needs the dtype fix to work with flash-attn

Files changed (1)
  1. src/axolotl/utils/models.py +12 -8
src/axolotl/utils/models.py CHANGED
@@ -355,6 +355,7 @@ def load_model(
                 if hasattr(module, "weight"):
                     module.to(torch.float32)
 
+    fix_dtype = False
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -363,16 +364,19 @@ def load_model(
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
+        fix_dtype = True
 
-        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-        # convert them back to fp16/bf16 for flash-attn compatibility.
-        if cfg.flash_attention and cfg.is_llama_derived_model:
-            for name, module in model.named_modules():
-                if "norm" in name:
+    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if (fix_dtype or cfg.adapter == "" or cfg.adapter == None) and (
+        cfg.flash_attention and cfg.is_llama_derived_model
+    ):
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
-                if "lm_head" in name or "embed_tokens" in name:
-                    if hasattr(module, "weight"):
-                        module.to(cfg.torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
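
For readers who want to see the effect of this change outside of `load_model`: the fix boils down to running the same dtype-normalization pass even when no LoRA/QLoRA adapter is configured (full finetune, adapter left empty). Below is a minimal sketch of that pass as a standalone helper. The function name `normalize_norm_dtypes` and the `TinyLlamaLike` stand-in module are illustrative only and not part of axolotl; the loop body mirrors the one added in the diff above.

```python
import torch
from torch import nn


def normalize_norm_dtypes(model: nn.Module, dtype: torch.dtype) -> None:
    """Cast norm, lm_head and embed_tokens modules to `dtype`.

    Mirrors the loop in the diff above: flash-attn expects fp16/bf16 weights,
    but these modules may still be in fp32, either because
    prepare_model_for_kbit_training upcast them or because the checkpoint
    comes from a full finetune.
    """
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(dtype)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(dtype)


class TinyLlamaLike(nn.Module):
    """Toy stand-in for a Llama-style model, only to exercise the helper."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(32, 8)
        self.input_layernorm = nn.LayerNorm(8)
        self.lm_head = nn.Linear(8, 32)


model = TinyLlamaLike()
normalize_norm_dtypes(model, torch.bfloat16)
print(model.input_layernorm.weight.dtype)  # torch.bfloat16
print(model.lm_head.weight.dtype)          # torch.bfloat16
```

Running the equivalent pass after loading an adapter-less fine-tuned Llama checkpoint, with `dtype` set to the same `torch_dtype` used by flash-attn, is the inference scenario the commit message refers to.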