Nanobit committed on
Commit e44c9e0
1 Parent(s): 55b8542

Fix patching via import instead of hijacking

Files changed (1):
  src/axolotl/utils/models.py  +7 -5
src/axolotl/utils/models.py CHANGED
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
 )
 
 try:
-    from transformers import LlamaForCausalLM
+    from transformers import ( # pylint: disable=unused-import # noqa: F401
+        LlamaForCausalLM,
+    )
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -115,15 +117,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (
+        from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
             MEM_TOKEN,
-            hijack_llama_landmark_attn,
+            LlamaForCausalLM,
         )
 
         logging.info("patching with landmark attention")
-        hijack_llama_landmark_attn()
 
-        tokenizer.add_special_tokens({"mem_token": MEM_TOKEN})
+        # TODO: Check if this would overwrite previous additional_special_tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
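
For readers following along, below is a minimal sketch of how the import-based patch is exercised downstream. The helper load_llama_with_landmark is hypothetical and the plumbing is simplified; it only assumes what the diff shows, namely that axolotl.monkeypatch.llama_landmark_attn exposes MEM_TOKEN and a landmark-attention LlamaForCausalLM.

# Hypothetical helper, for illustration only -- not axolotl's load_model().
def load_llama_with_landmark(base_model, tokenizer):
    # Importing the class from the monkeypatch module rebinds the name here
    # (hence the noqa: F811 in the diff) instead of mutating transformers'
    # LlamaForCausalLM in place the way hijack_llama_landmark_attn() did.
    from axolotl.monkeypatch.llama_landmark_attn import (
        MEM_TOKEN,
        LlamaForCausalLM,
    )

    # MEM_TOKEN has to be registered under "additional_special_tokens";
    # add_special_tokens() does not understand an arbitrary "mem_token" key.
    tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    # Any construction from here on resolves to the landmark-attention class.
    model = LlamaForCausalLM.from_pretrained(base_model)
    # Newly added special tokens usually require resizing the embedding table.
    model.resize_token_embeddings(len(tokenizer))
    return model

The difference from the hijack approach is that nothing global is mutated: code that still imports LlamaForCausalLM from transformers keeps the vanilla attention, while code that imports it from the monkeypatch module gets the landmark variant.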