winglian committed on
Commit
78c5b19
1 Parent(s): 23495a8

add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083)

src/axolotl/utils/lora_embeddings.py CHANGED
@@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
     returns the linear embedding layers needed for loras, dependent on the model arch
     """
     if model_type == "phi-msft":
-        return ["embd", "lm_head.linear"]
-    return ["lm_head", "embed_tokens"]
+        return ["embd.wte", "lm_head.linear"]
+    if model_type == "gpt_neox":
+        return ["embed_in", "embed_out"]
+    return ["embed_tokens", "lm_head"]
src/axolotl/utils/models.py CHANGED
@@ -588,13 +588,14 @@ def load_model(
     log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
+    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
         if "norm" in name:
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(m in name for m in embedding_modules):
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
@@ -619,15 +620,12 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (
-        cfg.flash_attention
-        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
-    ):
+    if needs_fa2_dtype or cfg.flash_attention:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(m in name for m in embedding_modules):
                 if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
 
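The effect of the casting change can be sketched outside of load_model. The snippet below is a minimal illustration, assuming a toy gpt_neox-style module tree rather than axolotl's real model loader; it only mirrors the upcast loop shown above.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a gpt_neox-style module tree; the submodule names
# match what get_linear_embedding_layers("gpt_neox") returns.
model = nn.ModuleDict(
    {
        "embed_in": nn.Embedding(32, 8),
        "final_layer_norm": nn.LayerNorm(8),
        "embed_out": nn.Linear(8, 32, bias=False),
    }
).half()

embedding_modules = ["embed_in", "embed_out"]  # get_linear_embedding_layers("gpt_neox")

# Mirror of the upcast loop in load_model: norms and embedding/output layers go to fp32.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch.float32)
    if any(m in name for m in embedding_modules):
        if hasattr(module, "weight"):
            module.to(torch.float32)

assert model["embed_in"].weight.dtype == torch.float32
assert model["embed_out"].weight.dtype == torch.float32
assert model["final_layer_norm"].weight.dtype == torch.float32
```

With the old string checks ("lm_head"/"embed_tokens"), the gpt_neox layers above would have stayed in fp16; driving the loop from get_linear_embedding_layers makes the upcast follow the model architecture.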
 
tests/core/test_trainer_builder.py CHANGED
@@ -30,6 +30,7 @@ def fixture_cfg():
         "adam_epsilon": 0.00001,
         "dataloader_num_workers": 1,
         "dataloader_pin_memory": True,
+        "model_config_type": "llama",
     }
 )
 
tests/test_validation.py CHANGED
@@ -770,7 +770,7 @@ class ValidationCheckModelConfig(BaseValidation):
                 "adapter": "qlora",
                 "load_in_4bit": True,
                 "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd", "lm_head.linear"],
+                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
             }
         )
 
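
In user-facing terms, the updated check means a phi-msft config that adds tokens under a qlora adapter should list the new layer names in lora_modules_to_save. A minimal sketch mirroring the test fixture follows; the base_model value and any keys beyond those visible in the diff are assumptions for illustration.

```python
from axolotl.utils.dict import DictDefault  # same helper the test suite uses

# A config fragment in the spirit of the updated test: adding tokens with a
# qlora adapter on a phi-msft model now expects "embd.wte" (was "embd")
# alongside "lm_head.linear" in lora_modules_to_save.
cfg = DictDefault(
    {
        "base_model": "microsoft/phi-2",  # illustrative; any phi-msft checkpoint
        "adapter": "qlora",
        "load_in_4bit": True,
        "tokens": ["<|imstart|>"],
        "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
    }
)
```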