winglian committed
Commit 8746b70
1 Parent(s): 6045345

attempt xformers hijack attention

Files changed (1)
  1. src/axolotl/utils/models.py +4 -0
src/axolotl/utils/models.py CHANGED
```diff
@@ -43,6 +43,10 @@ def load_model(
 
         logging.info("patching with flash attention")
         replace_llama_attn_with_flash_attn()
+    elif is_llama_derived_model and cfg.xformers_attention:
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+        logging.info("patching with xformers attention")
+        hijack_llama_attention()
 
     torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
     try:
```
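For context, the added `elif` only fires when the model is LLaMA-derived and `cfg.xformers_attention` is truthy. Below is a minimal sketch of that flow; the helper name `apply_attention_patch` and the shape of `cfg` are assumptions for illustration, while the import path and the `hijack_llama_attention()` call are taken directly from the diff.

```python
# Sketch only: `cfg` is assumed to be an axolotl-style config object exposing a
# boolean `xformers_attention` field, and `is_llama_derived_model` a precomputed
# flag; the xformers import and call mirror the lines added in this commit.
import logging


def apply_attention_patch(cfg, is_llama_derived_model: bool) -> None:
    """Monkeypatch LLaMA attention with the xformers kernel when configured."""
    if is_llama_derived_model and cfg.xformers_attention:
        # Lazy import, so the alpaca_lora_4bit dependency is only needed
        # when the xformers patch is actually requested.
        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        logging.info("patching with xformers attention")
        hijack_llama_attention()
```

Keeping the import inside the branch means `alpaca_lora_4bit` only has to be installed when the xformers attention patch is actually enabled.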