xzuyn committed
Commit 8487b97
Parent: 9cd27b2

Add `layers_to_transform` for `lora_config` (#1118)
README.md CHANGED
@@ -677,7 +677,8 @@ lora_target_modules:
 # - gate_proj
 # - down_proj
 # - up_proj
-lora_target_linear: # If true, will target all linear layers
+lora_target_linear: # If true, will target all linear modules
+peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
 
 # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
 # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
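The two options map one-to-one onto the config dict that axolotl validates. Below is a minimal sketch, not part of this commit, of how they might be combined to confine LoRA to the first few decoder layers; it reuses the `DictDefault` wrapper and `validate_config` that appear in the diffs further down, the import paths are assumed from the repository layout, and the layer indices and option values are illustrative only.

# Illustrative sketch (not from this commit); values are arbitrary examples.
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "adapter": "lora",
        "lora_target_linear": True,                # adapt every linear module...
        "peft_layers_to_transform": [0, 1, 2, 3],  # ...but only in layers 0-3
    }
)

# Expected to pass validation, since `unfrozen_parameters` is not set alongside
# `peft_layers_to_transform` (assuming no other validation rule is triggered).
validate_config(cfg)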
src/axolotl/utils/config.py CHANGED
@@ -257,6 +257,11 @@ def validate_config(cfg):
     if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
         raise ValueError("Fused modules are not supported with LoRA")
 
+    if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
+        raise ValueError(
+            "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
+        )
+
     if cfg.relora_steps:
         if cfg.adapter not in ("lora", "qlora"):
             raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
src/axolotl/utils/models.py CHANGED
@@ -733,6 +733,7 @@ def load_lora(model, cfg, inference=False):
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
         target_modules=lora_target_modules,
+        layers_to_transform=cfg.peft_layers_to_transform,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
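The new keyword argument is forwarded straight to PEFT's `LoraConfig`, whose `layers_to_transform` parameter limits adapter injection to the given layer indices (leaving it `None` applies LoRA to all matching layers, which is why the axolotl option is documented as applying to all layers when unset). A minimal sketch of that underlying call follows; it is not axolotl code, and the rank, alpha, dropout, and target modules are illustrative.

# Illustrative sketch of the underlying PEFT call; hyperparameters are examples only.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    layers_to_transform=[0, 1],  # inject LoRA only into layer indices 0 and 1
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)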
tests/test_validation.py CHANGED
@@ -694,6 +694,21 @@ class ValidationTest(BaseValidation):
 
         validate_config(cfg)
 
+    def test_unfrozen_parameters_w_peft_layers_to_transform(self):
+        cfg = DictDefault(
+            {
+                "adapter": "lora",
+                "unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
+                "peft_layers_to_transform": [0, 1],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*can have unexpected behavior*",
+        ):
+            validate_config(cfg)
+
 
 class ValidationCheckModelConfig(BaseValidation):
     """