winglian commited on
Commit
a03a7d7
1 Parent(s): 931e606

add support to extend context with xpos rope

Browse files
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ """
3
+ Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
4
+ """
5
+ import torch
6
+ import transformers
7
+ import transformers.models.llama.modeling_llama
8
+ from einops import rearrange
9
+
10
+
11
+ class XposRotaryEmbedding(torch.nn.Module):
12
+ def __init__(
13
+ self,
14
+ dim,
15
+ max_position_embeddings=2048,
16
+ base=10000,
17
+ device=None,
18
+ scale_base=2048,
19
+ use_xpos=True,
20
+ ):
21
+ super().__init__()
22
+ self.max_seq_len_cached = max_position_embeddings
23
+ self.scale_base = scale_base
24
+
25
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
26
+ t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
27
+ freqs = torch.einsum("i , j -> i j", t, inv_freq)
28
+ freqs = torch.cat((freqs, freqs), dim=-1)
29
+
30
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
31
+ self.register_buffer("freqs_cached", freqs, persistent=False)
32
+
33
+ if not use_xpos:
34
+ self.register_buffer("scale", None)
35
+ self.register_buffer("scale_cached", torch.ones(1))
36
+ return
37
+
38
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
39
+ power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
40
+ scale_cached = scale ** rearrange(power, "n -> n 1")
41
+ scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
42
+
43
+ self.register_buffer("scale", scale, persistent=False)
44
+ self.register_buffer("scale_cached", scale_cached, persistent=False)
45
+
46
+ def forward(
47
+ self,
48
+ x,
49
+ seq_len,
50
+ ):
51
+ if seq_len > self.max_seq_len_cached:
52
+ self.max_seq_len_cached = seq_len
53
+ t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
54
+ self.inv_freq
55
+ )
56
+ freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
57
+ freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
58
+
59
+ self.register_buffer("freqs_cached", freqs)
60
+
61
+ if self.scale is None:
62
+ self.register_buffer(
63
+ "scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
64
+ )
65
+
66
+ return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
67
+
68
+ power = (t - (seq_len // 2)) / self.scale_base
69
+ scale = self.scale ** rearrange(power, "n -> n 1")
70
+ scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
71
+ self.register_buffer("scale_cached", scale)
72
+
73
+ return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
74
+
75
+
76
+ def rotate_half(x):
77
+ x1, x2 = x.chunk(2, dim=-1)
78
+ return torch.cat((-x2, x1), dim=-1)
79
+
80
+
81
+ def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
82
+ freqs = freqs[position_ids, :]
83
+ if scale.shape[-1] != 1:
84
+ scale = scale[position_ids, :]
85
+
86
+ q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
87
+ k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
88
+
89
+ return q_embed, k_embed
90
+
91
+
92
+ def replace_llama_rope_with_xpos_rope():
93
+ transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
94
+ transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
src/axolotl/utils/models.py CHANGED
@@ -127,6 +127,14 @@ def load_model(
127
  # TODO: Check if this would overwrite previous additional_special_tokens
128
  tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
129
 
 
 
 
 
 
 
 
 
130
  if cfg.bf16:
131
  torch_dtype = torch.bfloat16
132
  elif cfg.load_in_8bit or cfg.fp16:
 
127
  # TODO: Check if this would overwrite previous additional_special_tokens
128
  tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
129
 
130
+ if cfg.is_llama_derived_model and cfg.xpos_rope:
131
+ from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
132
+ replace_llama_rope_with_xpos_rope,
133
+ )
134
+
135
+ logging.info("patching with xpos rope")
136
+ replace_llama_rope_with_xpos_rope()
137
+
138
  if cfg.bf16:
139
  torch_dtype = torch.bfloat16
140
  elif cfg.load_in_8bit or cfg.fp16: