Nanobit committed
Commit 919727b · 1 Parent(s): d9f713e

Refactor landmark attention patch

src/axolotl/monkeypatch/llama_landmark_attn.py CHANGED
@@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id):
     ret.extend(x[prev_idx:])
     # drop attention_mask
     return {"input_ids": ret}
+
+
+def patch_llama_with_landmark_attn():
+    import transformers
+
+    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+    transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
+    transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
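
The new helper concentrates the landmark-attention monkeypatching in one place: it swaps the stock Llama classes inside transformers.models.llama.modeling_llama for the variants defined in this module. A minimal sketch of calling it and checking that the swap took effect (illustrative only, not part of the commit):

    import transformers

    from axolotl.monkeypatch import llama_landmark_attn

    # Replace the upstream Llama classes with the landmark-attention versions.
    llama_landmark_attn.patch_llama_with_landmark_attn()

    # The module-level names in modeling_llama now point at the patched classes.
    assert (
        transformers.models.llama.modeling_llama.LlamaAttention
        is llama_landmark_attn.LlamaAttention
    )
    assert (
        transformers.models.llama.modeling_llama.LlamaForCausalLM
        is llama_landmark_attn.LlamaForCausalLM
    )

Because the assignments target the modeling_llama module itself, code that resolves the Llama classes through that module after the patch picks up the landmark-attention implementations.
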
src/axolotl/utils/models.py CHANGED
@@ -19,15 +19,6 @@ from transformers import ( # noqa: F401
     LlamaConfig,
 )
 
-try:
-    from transformers import ( # pylint: disable=unused-import # noqa: F401
-        LlamaForCausalLM,
-    )
-except ImportError:
-    logging.warning(
-        "This version of transformers does not support Llama. Consider upgrading."
-    )
-
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
@@ -118,14 +109,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
+        from axolotl.monkeypatch.llama_landmark_attn import (
             MEM_TOKEN,
-            LlamaForCausalLM,
+            patch_llama_with_landmark_attn,
         )
 
         logging.info("patching with landmark attention")
+        patch_llama_with_landmark_attn()
 
-        # TODO: Check if this would overwrite previous additional_special_tokens
+        # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -211,6 +203,13 @@ def load_model(
         )
         load_in_8bit = False
     elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
+        try:
+            from transformers import LlamaForCausalLM
+        except ImportError:
+            logging.warning(
+                "This version of transformers does not support Llama. Consider upgrading."
+            )
+
         config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
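
Taken together, load_model now applies the patch first, registers MEM_TOKEN with the tokenizer, and imports LlamaForCausalLM lazily inside the branch that actually needs it. A rough end-to-end sketch of that flow outside of load_model (the model id is a placeholder, and the embedding resize at the end is an assumption rather than something this commit does):

    from transformers import LlamaConfig, LlamaTokenizer

    from axolotl.monkeypatch.llama_landmark_attn import (
        MEM_TOKEN,
        patch_llama_with_landmark_attn,
    )

    base_model = "huggyllama/llama-7b"  # placeholder model id

    # 1. Swap in the landmark-attention classes before anything touches Llama.
    patch_llama_with_landmark_attn()

    # 2. Register the landmark memory token with the tokenizer.
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    # 3. Import LlamaForCausalLM only now, as models.py does, and load the model.
    from transformers import LlamaForCausalLM

    config = LlamaConfig.from_pretrained(base_model)
    model = LlamaForCausalLM.from_pretrained(base_model, config=config)

    # Assumption: the vocabulary grew by one token, so resize the embeddings.
    model.resize_token_embeddings(len(tokenizer))

Deferring the try/except import to the branch that uses it keeps models.py importable on transformers versions without Llama support and avoids resolving LlamaForCausalLM before the patch has run.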