winglian commited on
Commit
176b888
1 Parent(s): 3392270

ensure enable_input_require_grads is called on model before getting the peft model (#345)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +2 -0
src/axolotl/utils/models.py CHANGED
@@ -391,6 +391,8 @@ def load_adapter(model, cfg, adapter):
391
 
392
  if adapter is None:
393
  return model, None
 
 
394
  if adapter in ["lora", "qlora"]:
395
  return load_lora(model, cfg)
396
  if adapter == "llama-adapter":
 
391
 
392
  if adapter is None:
393
  return model, None
394
+ if hasattr(model, "enable_input_require_grads"):
395
+ model.enable_input_require_grads()
396
  if adapter in ["lora", "qlora"]:
397
  return load_lora(model, cfg)
398
  if adapter == "llama-adapter":