tmm1 committed on
Commit 312a9fa • 1 Parent(s): 248bf90

move flash-attn monkey patch alongside the others

src/axolotl/{flash_attn.py → monkeypatch/llama_attn_hijack_flash.py} RENAMED
File without changes
src/axolotl/utils/models.py CHANGED
@@ -92,7 +92,9 @@ def load_model(
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
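
The renamed module still exposes replace_llama_attn_with_flash_attn, which load_model now imports lazily from axolotl.monkeypatch.llama_attn_hijack_flash only when flash attention is enabled on a supported device. As a rough illustration of how a patch like this is usually wired (a minimal sketch, not the actual axolotl code; the placeholder forward and the module alias are assumptions), it reassigns the forward method on the upstream LlamaAttention class so every instance created afterwards routes through the patched implementation:

# Minimal sketch of a llama attention "hijack" (illustrative only,
# not the axolotl implementation).
import transformers.models.llama.modeling_llama as llama_modeling


def _flash_attn_forward(self, *args, **kwargs):
    # Hypothetical replacement forward; a real patch would call
    # flash-attn's fused kernels here instead of the stock attention.
    raise NotImplementedError("illustrative placeholder only")


def replace_llama_attn_with_flash_attn():
    # Monkey patch: assign the new forward onto the class itself, so
    # models built after this call pick up the patched attention.
    llama_modeling.LlamaAttention.forward = _flash_attn_forward

Importing inside the cfg.flash_attention branch keeps flash-attn an optional dependency: the patch module is only loaded when the feature is actually requested.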