tmm1, winglian committed
Commit 06edf17 (1 parent: 0a22847)

standardize attn hijack patches (#381)

* split sdp attn into its own patch

* sync xformers patch to follow shared format and be diffable

* update flash-attn patch for 70B/GQA and inference using helper from flash-attn tests

* speed up flash-attn inference

* fix patch to check position ids and don't use multipack for evals

* copy LlamaModel.forward and LlamaDecoderLayer.forward into monkeypatch

* update forwards so we only calculate cu_seqlens once

* enable eval dataloader using multipack again

* fix the patch to work properly and work with FSDP

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
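
For context on how these patches get enabled, here is a minimal usage sketch based on the src/axolotl/utils/models.py change at the bottom of this diff (the call site and config handling are simplified; only the function names and the packed keyword come from this commit):

    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        replace_llama_attn_with_flash_attn,
    )

    # Apply before the model is instantiated. packed=True additionally swaps in
    # the patched LlamaModel.forward and LlamaDecoderLayer so that cu_seqlens and
    # max_seqlen are computed once per batch and passed to every decoder layer.
    replace_llama_attn_with_flash_attn(packed=True)

    # The SDP patch now lives in its own module (llama_attn_hijack_sdp.py below):
    # from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
    # hijack_llama_sdp_attention()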

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -2,26 +2,63 @@
2
 
3
  # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
4
 
5
- from typing import Optional, Tuple
 
6
 
7
  import torch
 
8
  import transformers
9
  from einops import rearrange
10
  from flash_attn.bert_padding import pad_input, unpad_input
11
 
12
  try:
13
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
14
  except ImportError:
15
  from flash_attn.flash_attn_interface import (
16
  flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
17
  )
18
 
19
- from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
20
 
21
- from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
22
 
23
 
24
- def forward(
25
  self,
26
  hidden_states: torch.Tensor,
27
  attention_mask: Optional[torch.Tensor] = None,
@@ -29,6 +66,8 @@ def forward(
29
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
30
  output_attentions: bool = False,
31
  use_cache: bool = False,
 
 
32
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
33
  """Input shape: Batch x Time x Channel
34
 
@@ -37,124 +76,523 @@ def forward(
37
  # pylint: disable=duplicate-code
38
  bsz, q_len, _ = hidden_states.size()
39
 
40
- query_states = (
41
- self.q_proj(hidden_states)
42
- .view(bsz, q_len, self.num_heads, self.head_dim)
43
- .transpose(1, 2)
44
- )
45
- key_states = (
46
- self.k_proj(hidden_states)
47
- .view(bsz, q_len, self.num_heads, self.head_dim)
48
- .transpose(1, 2)
49
- )
50
- value_states = (
51
- self.v_proj(hidden_states)
52
- .view(bsz, q_len, self.num_heads, self.head_dim)
53
- .transpose(1, 2)
54
- )
55
  # [bsz, q_len, nh, hd]
56
  # [bsz, nh, q_len, hd]
57
 
58
  kv_seq_len = key_states.shape[-2]
59
- assert past_key_value is None, "past_key_value is not supported"
 
60
 
61
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
62
  query_states, key_states = apply_rotary_pos_emb(
63
  query_states, key_states, cos, sin, position_ids
64
  )
65
  # [bsz, nh, t, hd]
66
- assert not output_attentions, "output_attentions is not supported"
67
- assert not use_cache, "use_cache is not supported"
68
-
69
- # Flash attention codes from
70
- # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
71
-
72
- # transform the data into the format required by flash attention
73
- qkv = torch.stack(
74
- [query_states, key_states, value_states], dim=2
75
- ) # [bsz, nh, 3, q_len, hd]
76
- qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
77
- # We have disabled _prepare_decoder_attention_mask in LlamaModel
78
- # the attention_mask should be the same as the key_padding_mask
79
- key_padding_mask = attention_mask
80
-
81
- if key_padding_mask is None:
82
- qkv = rearrange(qkv, "b s ... -> (b s) ...")
83
- max_s = q_len
84
- cu_q_lens = torch.arange(
85
- 0,
86
- (bsz + 1) * q_len,
87
- step=q_len,
88
- dtype=torch.int32,
89
- device=qkv.device,
90
- )
91
- output = flash_attn_varlen_qkvpacked_func(
92
- qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
93
  )
94
- output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
95
- elif attention_mask.shape[0] == 1:
96
  # special handling using sample packing
97
  qkv = rearrange(qkv, "b s ... -> (b s) ...")
98
- cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
99
- cu_q_lens = cu_q_lens.squeeze()
100
 
101
  output = flash_attn_varlen_qkvpacked_func(
102
- qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
103
  )
104
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
105
- else:
106
- nheads = qkv.shape[-2]
107
-
108
- # pylint: disable=invalid-name
109
- x = rearrange(qkv, "b s three h d -> b s (three h d)")
110
- x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
111
- x_unpad = rearrange(
112
- x_unpad,
113
- "nnz (three h d) -> nnz three h d",
114
- three=3,
115
- h=nheads,
116
  )
117
  output_unpad = flash_attn_varlen_qkvpacked_func(
118
- x_unpad,
119
- cu_q_lens,
120
- max_s,
121
  0.0,
122
  softmax_scale=None,
123
- causal=True,
124
  )
125
- output = rearrange(
126
- pad_input(
127
- rearrange(output_unpad, "nnz h d -> nnz (h d)"),
128
- indices,
129
- bsz,
130
- q_len,
131
- ),
132
- "b s (h d) -> b s h d",
133
- h=nheads,
134
  )
135
 
136
  return (
137
- self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
138
- None,
139
- None,
140
  )
141
 
142
 
143
- # Disable the transformation of the attention mask in LlamaModel as the flash attention
144
- # requires the attention mask to be the same as the key_padding_mask
145
- def _prepare_decoder_attention_mask(
146
  self,
147
- attention_mask,
148
- input_shape,
149
- inputs_embeds,
150
- past_key_values_length,
151
- ): # pylint: disable=unused-argument
152
- # [bsz, seq_len]
153
- return attention_mask
154
 
 
 
 
155
 
156
- def replace_llama_attn_with_flash_attn():
157
- transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
158
- _prepare_decoder_attention_mask
159
  )
160
- transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
2
 
3
  # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
4
 
5
+ import warnings
6
+ from typing import List, Optional, Tuple, Union
7
 
8
  import torch
9
+ import torch.nn.functional as F
10
  import transformers
11
  from einops import rearrange
12
  from flash_attn.bert_padding import pad_input, unpad_input
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast
14
+ from transformers.models.llama.modeling_llama import (
15
+ LlamaDecoderLayer as OriginalLlamaDecoderLayer,
16
+ )
17
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
18
+
19
+ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
20
 
21
  try:
22
+ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
23
+ flash_attn_kvpacked_func,
24
+ flash_attn_varlen_kvpacked_func,
25
+ flash_attn_varlen_qkvpacked_func,
26
+ )
27
  except ImportError:
28
+ from flash_attn.flash_attn_interface import (
29
+ flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
30
+ )
31
  from flash_attn.flash_attn_interface import (
32
  flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
33
  )
34
 
 
35
 
36
+ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
37
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
38
+ _prepare_decoder_attention_mask
39
+ )
40
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
41
+ if packed:
42
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
43
+ transformers.models.llama.modeling_llama.LlamaModel.forward = (
44
+ llama_model_forward
45
+ )
46
 
47
 
48
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
49
+ # requires the attention mask to be the same as the key_padding_mask
50
+ def _prepare_decoder_attention_mask(
51
+ self,
52
+ attention_mask,
53
+ input_shape,
54
+ inputs_embeds,
55
+ past_key_values_length,
56
+ ): # pylint: disable=unused-argument
57
+ # [bsz, seq_len]
58
+ return attention_mask
59
+
60
+
61
+ def flashattn_forward(
62
  self,
63
  hidden_states: torch.Tensor,
64
  attention_mask: Optional[torch.Tensor] = None,
 
66
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
67
  output_attentions: bool = False,
68
  use_cache: bool = False,
69
+ cu_seqlens: Optional[torch.Tensor] = None,
70
+ max_seqlen: Optional[torch.Tensor] = None,
71
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
72
  """Input shape: Batch x Time x Channel
73
 
 
76
  # pylint: disable=duplicate-code
77
  bsz, q_len, _ = hidden_states.size()
78
 
79
+ if not hasattr(self, "pretraining_tp"):
80
+ self.pretraining_tp = 1
81
+
82
+ if self.pretraining_tp > 1:
83
+ key_value_slicing = (
84
+ self.num_key_value_heads * self.head_dim
85
+ ) // self.pretraining_tp
86
+ query_slices = self.q_proj.weight.split(
87
+ (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
88
+ )
89
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
90
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
91
+
92
+ query_states = [
93
+ F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
94
+ ]
95
+ query_states = torch.cat(query_states, dim=-1)
96
+
97
+ key_states = [
98
+ F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
99
+ ]
100
+ key_states = torch.cat(key_states, dim=-1)
101
+
102
+ value_states = [
103
+ F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
104
+ ]
105
+ value_states = torch.cat(value_states, dim=-1)
106
+
107
+ else:
108
+ query_states = self.q_proj(hidden_states)
109
+ key_states = self.k_proj(hidden_states)
110
+ value_states = self.v_proj(hidden_states)
111
+
112
+ query_states = query_states.view(
113
+ bsz, q_len, self.num_heads, self.head_dim
114
+ ).transpose(1, 2)
115
+ key_states = key_states.view(
116
+ bsz, q_len, self.num_key_value_heads, self.head_dim
117
+ ).transpose(1, 2)
118
+ value_states = value_states.view(
119
+ bsz, q_len, self.num_key_value_heads, self.head_dim
120
+ ).transpose(1, 2)
121
  # [bsz, q_len, nh, hd]
122
  # [bsz, nh, q_len, hd]
123
 
124
  kv_seq_len = key_states.shape[-2]
125
+ if past_key_value is not None:
126
+ kv_seq_len += past_key_value[0].shape[-2]
127
 
128
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
129
  query_states, key_states = apply_rotary_pos_emb(
130
  query_states, key_states, cos, sin, position_ids
131
  )
132
  # [bsz, nh, t, hd]
133
+
134
+ if past_key_value is not None:
135
+ # reuse k, v, self_attention
136
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
137
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
138
+
139
+ past_key_value = (key_states, value_states) if use_cache else None
140
+
141
+ # repeat k/v heads if n_kv_heads < n_heads
142
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
143
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
144
+
145
+ if output_attentions:
146
+ warnings.warn(
147
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
148
  )
149
+
150
+ #
151
+ # flash-attn v2 start
152
+ #
153
+
154
+ if self.training:
155
+ # during training q,k,v always have same seqlen
156
+ assert key_states.shape == query_states.shape
157
+ is_causal = True
158
+ else:
159
+ # turn off FA causal mask after first inference autoregressive iteration
160
+ # only on first autoregressive step q,k,v have same seqlen
161
+ is_causal = past_key_value is not None
162
+
163
+ if cu_seqlens is not None and max_seqlen is not None:
164
  # special handling using sample packing
165
+ qkv = torch.stack(
166
+ [query_states, key_states, value_states], dim=2
167
+ ) # [bsz, nh, 3, q_len, hd]
168
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
169
  qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
 
170
 
171
  output = flash_attn_varlen_qkvpacked_func(
172
+ qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
173
  )
174
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
175
+ elif query_states.shape == key_states.shape:
176
+ query_states = query_states.transpose(1, 2)
177
+ key_states = key_states.transpose(1, 2)
178
+ value_states = value_states.transpose(1, 2)
179
+ qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
180
+ query_states,
181
+ key_states,
182
+ value_states,
183
+ qkvpacked=True,
184
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
185
+ # the attention_mask should be the same as the key_padding_mask
186
+ key_padding_mask=attention_mask,
187
+ query_padding_mask=attention_mask[:, -query_states.size(1) :]
188
+ if attention_mask is not None
189
+ else None,
190
  )
191
  output_unpad = flash_attn_varlen_qkvpacked_func(
192
+ qkv_unpad,
193
+ cu_seqlens_q,
194
+ max_seqlen_q,
195
  0.0,
196
  softmax_scale=None,
197
+ causal=is_causal,
198
  )
199
+ output = output_pad_fn(output_unpad)
200
+ else:
201
+ query_states = query_states.transpose(1, 2)
202
+ key_states = key_states.transpose(1, 2)
203
+ value_states = value_states.transpose(1, 2)
204
+ if attention_mask is None or attention_mask.all().item():
205
+ output = flash_attn_kvpacked_func(
206
+ query_states,
207
+ torch.stack([key_states, value_states], 2),
208
+ causal=is_causal,
209
+ )
210
+ else:
211
+ ( # pylint: disable=unbalanced-tuple-unpacking
212
+ q_unpad,
213
+ kv_unpad,
214
+ cu_seqlens_q,
215
+ cu_seqlens_k,
216
+ max_seqlen_q,
217
+ max_seqlen_k,
218
+ _,
219
+ _,
220
+ output_pad_fn,
221
+ ) = generate_qkv(
222
+ query_states,
223
+ key_states,
224
+ value_states,
225
+ kvpacked=True,
226
+ key_padding_mask=attention_mask,
227
+ query_padding_mask=attention_mask[:, -query_states.size(1) :]
228
+ if attention_mask is not None
229
+ else None,
230
+ )
231
+ output_unpad = flash_attn_varlen_kvpacked_func(
232
+ q_unpad,
233
+ kv_unpad,
234
+ cu_seqlens_q,
235
+ cu_seqlens_k,
236
+ max_seqlen_q,
237
+ max_seqlen_k,
238
+ 0.0,
239
+ softmax_scale=None,
240
+ causal=is_causal,
241
+ )
242
+ output = output_pad_fn(output_unpad)
243
+
244
+ attn_output = output
245
+ if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
246
+ raise ValueError(
247
+ f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
248
+ f" {attn_output.size()}"
249
+ )
250
+ attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
251
+
252
+ #
253
+ # flash-attn v2 end
254
+ #
255
+
256
+ if self.pretraining_tp > 1:
257
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
258
+ o_proj_slices = self.o_proj.weight.split(
259
+ self.hidden_size // self.pretraining_tp, dim=1
260
+ )
261
+ attn_output = sum(
262
+ F.linear(attn_output[i], o_proj_slices[i])
263
+ for i in range(self.pretraining_tp)
264
+ )
265
+ else:
266
+ attn_output = self.o_proj(attn_output)
267
+
268
+ return attn_output, None, past_key_value
269
+
270
+
271
+ # based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
272
+ def generate_qkv(
273
+ q,
274
+ k,
275
+ v,
276
+ query_padding_mask=None,
277
+ key_padding_mask=None,
278
+ kvpacked=False,
279
+ qkvpacked=False,
280
+ ): # pylint: disable=invalid-name,unnecessary-lambda-assignment
281
+ """
282
+ Arguments:
283
+ q: (batch_size, seqlen_q, nheads, d)
284
+ k: (batch_size, seqlen_k, nheads_k, d)
285
+ v: (batch_size, seqlen_k, nheads_k, d)
286
+ query_padding_mask: (batch_size, seqlen), bool
287
+ key_padding_mask: (batch_size, seqlen), bool
288
+ """
289
+ assert not (kvpacked and qkvpacked)
290
+ batch_size, seqlen_q, nheads, d = q.shape
291
+ _, seqlen_k, nheads_k, _ = k.shape
292
+ assert k.shape == (batch_size, seqlen_k, nheads_k, d)
293
+ assert v.shape == (batch_size, seqlen_k, nheads_k, d)
294
+
295
+ if query_padding_mask is not None:
296
+ q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
297
+ q, query_padding_mask
298
+ )
299
+
300
+ output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
301
+ output_unpad, indices_q, batch_size, seqlen_q
302
+ )
303
+
304
+ else:
305
+ q_unpad = rearrange(q, "b s h d -> (b s) h d")
306
+ cu_seqlens_q = torch.arange(
307
+ 0,
308
+ (batch_size + 1) * seqlen_q,
309
+ step=seqlen_q,
310
+ dtype=torch.int32,
311
+ device=q_unpad.device,
312
+ )
313
+ max_seqlen_q = seqlen_q
314
+
315
+ output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
316
+ output_unpad, "(b s) h d -> b s h d", b=batch_size
317
+ )
318
+
319
+ if key_padding_mask is not None:
320
+ k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
321
+ v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
322
+ else:
323
+ k_unpad = rearrange(k, "b s h d -> (b s) h d")
324
+ v_unpad = rearrange(v, "b s h d -> (b s) h d")
325
+ cu_seqlens_k = torch.arange(
326
+ 0,
327
+ (batch_size + 1) * seqlen_k,
328
+ step=seqlen_k,
329
+ dtype=torch.int32,
330
+ device=k_unpad.device,
331
+ )
332
+ max_seqlen_k = seqlen_k
333
+
334
+ if qkvpacked:
335
+ assert nheads == nheads_k
336
+ qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
337
+ qkv = torch.stack([q, k, v], dim=2)
338
+ return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
339
+
340
+ if kvpacked:
341
+ kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
342
+ kv = torch.stack([k, v], dim=2)
343
+ return (
344
+ q_unpad,
345
+ kv_unpad,
346
+ cu_seqlens_q,
347
+ cu_seqlens_k,
348
+ max_seqlen_q,
349
+ max_seqlen_k,
350
+ q,
351
+ kv,
352
+ output_pad_fn,
353
  )
354
 
355
  return (
356
+ q_unpad,
357
+ k_unpad,
358
+ v_unpad,
359
+ cu_seqlens_q,
360
+ cu_seqlens_k,
361
+ max_seqlen_q,
362
+ max_seqlen_k,
363
+ q,
364
+ k,
365
+ v,
366
+ output_pad_fn,
367
  )
368
 
369
 
370
+ def llama_model_forward(
 
 
371
  self,
372
+ input_ids: torch.LongTensor = None,
373
+ attention_mask: Optional[torch.Tensor] = None,
374
+ position_ids: Optional[torch.LongTensor] = None,
375
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
376
+ inputs_embeds: Optional[torch.FloatTensor] = None,
377
+ use_cache: Optional[bool] = None,
378
+ output_attentions: Optional[bool] = None,
379
+ output_hidden_states: Optional[bool] = None,
380
+ return_dict: Optional[bool] = None,
381
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
382
+ output_attentions = (
383
+ output_attentions
384
+ if output_attentions is not None
385
+ else self.config.output_attentions
386
+ )
387
+ output_hidden_states = (
388
+ output_hidden_states
389
+ if output_hidden_states is not None
390
+ else self.config.output_hidden_states
391
+ )
392
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
393
 
394
+ return_dict = (
395
+ return_dict if return_dict is not None else self.config.use_return_dict
396
+ )
397
 
398
+ # retrieve input_ids and inputs_embeds
399
+ if input_ids is not None and inputs_embeds is not None:
400
+ raise ValueError(
401
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
402
+ )
403
+ if input_ids is not None:
404
+ batch_size, seq_length = input_ids.shape
405
+ elif inputs_embeds is not None:
406
+ batch_size, seq_length, _ = inputs_embeds.shape
407
+ else:
408
+ raise ValueError(
409
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
410
+ )
411
+
412
+ seq_length_with_past = seq_length
413
+ past_key_values_length = 0
414
+
415
+ if past_key_values is not None:
416
+ past_key_values_length = past_key_values[0][0].shape[2]
417
+ seq_length_with_past = seq_length_with_past + past_key_values_length
418
+
419
+ cu_seqlens = None
420
+ max_seqlen = None
421
+ if position_ids is None:
422
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
423
+ position_ids = torch.arange(
424
+ past_key_values_length,
425
+ seq_length + past_key_values_length,
426
+ dtype=torch.long,
427
+ device=device,
428
+ )
429
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
430
+ else:
431
+ position_ids = position_ids.view(-1, seq_length).long()
432
+ cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
433
+ cu_seqlens = cu_seqlens.squeeze()
434
+
435
+ if inputs_embeds is None:
436
+ inputs_embeds = self.embed_tokens(input_ids)
437
+ # embed positions
438
+ if attention_mask is None:
439
+ attention_mask = torch.ones(
440
+ (batch_size, seq_length_with_past),
441
+ dtype=torch.bool,
442
+ device=inputs_embeds.device,
443
+ )
444
+ attention_mask = (
445
+ self._prepare_decoder_attention_mask( # pylint: disable=protected-access
446
+ attention_mask,
447
+ (batch_size, seq_length),
448
+ inputs_embeds,
449
+ past_key_values_length,
450
+ )
451
+ )
452
+
453
+ hidden_states = inputs_embeds
454
+
455
+ if self.gradient_checkpointing and self.training:
456
+ if use_cache:
457
+ transformers.logger.warning_once(
458
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
459
+ )
460
+ use_cache = False
461
+
462
+ # decoder layers
463
+ all_hidden_states = () if output_hidden_states else None
464
+ all_self_attns = () if output_attentions else None
465
+ next_decoder_cache = () if use_cache else None
466
+
467
+ for idx, decoder_layer in enumerate(self.layers):
468
+ if output_hidden_states:
469
+ all_hidden_states += (hidden_states,)
470
+
471
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
472
+
473
+ if self.gradient_checkpointing and self.training:
474
+
475
+ def create_custom_forward(module):
476
+ def custom_forward(*inputs):
477
+ # None for past_key_value
478
+ return module(*inputs)
479
+
480
+ return custom_forward
481
+
482
+ layer_outputs = torch.utils.checkpoint.checkpoint(
483
+ create_custom_forward(decoder_layer),
484
+ hidden_states,
485
+ attention_mask,
486
+ position_ids,
487
+ None,
488
+ output_attentions,
489
+ None,
490
+ cu_seqlens,
491
+ max_seqlen,
492
+ )
493
+ else:
494
+ layer_outputs = decoder_layer(
495
+ hidden_states,
496
+ attention_mask=attention_mask,
497
+ position_ids=position_ids,
498
+ past_key_value=past_key_value,
499
+ output_attentions=output_attentions,
500
+ use_cache=use_cache,
501
+ cu_seqlens=cu_seqlens,
502
+ max_seqlen=max_seqlen,
503
+ )
504
+
505
+ hidden_states = layer_outputs[0]
506
+
507
+ if use_cache:
508
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
509
+
510
+ if output_attentions:
511
+ all_self_attns += (layer_outputs[1],)
512
+
513
+ hidden_states = self.norm(hidden_states)
514
+
515
+ # add hidden states from the last decoder layer
516
+ if output_hidden_states:
517
+ all_hidden_states += (hidden_states,)
518
+
519
+ next_cache = next_decoder_cache if use_cache else None
520
+ if not return_dict:
521
+ return tuple(
522
+ v
523
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
524
+ if v is not None
525
+ )
526
+ return BaseModelOutputWithPast(
527
+ last_hidden_state=hidden_states,
528
+ past_key_values=next_cache,
529
+ hidden_states=all_hidden_states,
530
+ attentions=all_self_attns,
531
  )
532
+
533
+
534
+ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
535
+ """
536
+ patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
537
+ """
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states: torch.Tensor,
542
+ attention_mask: Optional[torch.Tensor] = None,
543
+ position_ids: Optional[torch.LongTensor] = None,
544
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
545
+ output_attentions: Optional[bool] = False,
546
+ use_cache: Optional[bool] = False,
547
+ cu_seqlens: Optional[torch.Tensor] = None,
548
+ max_seqlen: Optional[torch.Tensor] = None,
549
+ ) -> Tuple[
550
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
551
+ ]:
552
+ """
553
+ Args:
554
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
555
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
556
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
557
+ output_attentions (`bool`, *optional*):
558
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
559
+ returned tensors for more detail.
560
+ use_cache (`bool`, *optional*):
561
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
562
+ (see `past_key_values`).
563
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
564
+ cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
565
+ """
566
+
567
+ residual = hidden_states
568
+
569
+ hidden_states = self.input_layernorm(hidden_states)
570
+
571
+ # Self Attention
572
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
573
+ hidden_states=hidden_states,
574
+ attention_mask=attention_mask,
575
+ position_ids=position_ids,
576
+ past_key_value=past_key_value,
577
+ output_attentions=output_attentions,
578
+ use_cache=use_cache,
579
+ cu_seqlens=cu_seqlens,
580
+ max_seqlen=max_seqlen,
581
+ )
582
+ hidden_states = residual + hidden_states
583
+
584
+ # Fully Connected
585
+ residual = hidden_states
586
+ hidden_states = self.post_attention_layernorm(hidden_states)
587
+ hidden_states = self.mlp(hidden_states)
588
+ hidden_states = residual + hidden_states
589
+
590
+ outputs = (hidden_states,)
591
+
592
+ if output_attentions:
593
+ outputs += (self_attn_weights,)
594
+
595
+ if use_cache:
596
+ outputs += (present_key_value,)
597
+
598
+ return outputs
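
The packed code path above hands flash_attn_varlen_qkvpacked_func cumulative sequence lengths that llama_model_forward now derives once from position_ids via get_cu_seqlens_from_pos_ids. As a rough standalone illustration of that idea (a simplified stand-in, not the actual helper in axolotl.monkeypatch.utils, which is outside this diff):

    import torch

    def cu_seqlens_from_pos_ids(position_ids: torch.Tensor):
        """Toy equivalent: packed samples restart their position ids at 0."""
        pos = position_ids.view(-1)
        # a new sequence begins wherever the position counter resets to 0
        starts = torch.nonzero(pos == 0, as_tuple=False).view(-1)
        ends = torch.cat([starts[1:], torch.tensor([pos.numel()], device=pos.device)])
        seqlens = (ends - starts).to(torch.int32)
        cu_seqlens = torch.cat(
            [torch.zeros(1, dtype=torch.int32, device=pos.device), seqlens.cumsum(0, dtype=torch.int32)]
        )
        return cu_seqlens, int(seqlens.max())

    # three packed sequences of lengths 3, 2 and 4 in a single row
    pos = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
    print(cu_seqlens_from_pos_ids(pos))  # (tensor([0, 3, 5, 9], dtype=torch.int32), 4)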
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py ADDED
@@ -0,0 +1,140 @@
1
+ """
2
+ Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
3
+ """
4
+
5
+ import warnings
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import transformers.models.llama.modeling_llama
11
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
12
+
13
+
14
+ def hijack_llama_sdp_attention():
15
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = (
16
+ sdp_attention_forward
17
+ )
18
+
19
+
20
+ def sdp_attention_forward(
21
+ self,
22
+ hidden_states: torch.Tensor,
23
+ attention_mask: Optional[torch.Tensor] = None,
24
+ position_ids: Optional[torch.LongTensor] = None,
25
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
26
+ output_attentions: bool = False,
27
+ use_cache: bool = False,
28
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
29
+ # pylint: disable=duplicate-code
30
+ bsz, q_len, _ = hidden_states.size()
31
+
32
+ if not hasattr(self, "pretraining_tp"):
33
+ self.pretraining_tp = 1
34
+
35
+ if self.pretraining_tp > 1:
36
+ key_value_slicing = (
37
+ self.num_key_value_heads * self.head_dim
38
+ ) // self.pretraining_tp
39
+ query_slices = self.q_proj.weight.split(
40
+ (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
41
+ )
42
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
43
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
44
+
45
+ query_states = [
46
+ F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
47
+ ]
48
+ query_states = torch.cat(query_states, dim=-1)
49
+
50
+ key_states = [
51
+ F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
52
+ ]
53
+ key_states = torch.cat(key_states, dim=-1)
54
+
55
+ value_states = [
56
+ F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
57
+ ]
58
+ value_states = torch.cat(value_states, dim=-1)
59
+
60
+ else:
61
+ query_states = self.q_proj(hidden_states)
62
+ key_states = self.k_proj(hidden_states)
63
+ value_states = self.v_proj(hidden_states)
64
+
65
+ query_states = query_states.view(
66
+ bsz, q_len, self.num_heads, self.head_dim
67
+ ).transpose(1, 2)
68
+ key_states = key_states.view(
69
+ bsz, q_len, self.num_key_value_heads, self.head_dim
70
+ ).transpose(1, 2)
71
+ value_states = value_states.view(
72
+ bsz, q_len, self.num_key_value_heads, self.head_dim
73
+ ).transpose(1, 2)
74
+ # [bsz, q_len, nh, hd]
75
+ # [bsz, nh, q_len, hd]
76
+
77
+ kv_seq_len = key_states.shape[-2]
78
+ if past_key_value is not None:
79
+ kv_seq_len += past_key_value[0].shape[-2]
80
+
81
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
82
+ query_states, key_states = apply_rotary_pos_emb(
83
+ query_states, key_states, cos, sin, position_ids
84
+ )
85
+ # [bsz, nh, t, hd]
86
+
87
+ if past_key_value is not None:
88
+ # reuse k, v, self_attention
89
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
90
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
91
+
92
+ past_key_value = (key_states, value_states) if use_cache else None
93
+
94
+ # repeat k/v heads if n_kv_heads < n_heads
95
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
96
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
97
+
98
+ if output_attentions:
99
+ warnings.warn(
100
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
101
+ )
102
+
103
+ #
104
+ # sdp-attn start
105
+ #
106
+
107
+ with torch.backends.cuda.sdp_kernel():
108
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
109
+ query_states,
110
+ key_states,
111
+ value_states,
112
+ attn_mask=attention_mask,
113
+ is_causal=False,
114
+ )
115
+
116
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
117
+ raise ValueError(
118
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
119
+ f" {attn_output.size()}"
120
+ )
121
+ attn_output = attn_output.transpose(1, 2)
122
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
123
+
124
+ #
125
+ # sdp-attn end
126
+ #
127
+
128
+ if self.pretraining_tp > 1:
129
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
130
+ o_proj_slices = self.o_proj.weight.split(
131
+ self.hidden_size // self.pretraining_tp, dim=1
132
+ )
133
+ attn_output = sum(
134
+ F.linear(attn_output[i], o_proj_slices[i])
135
+ for i in range(self.pretraining_tp)
136
+ )
137
+ else:
138
+ attn_output = self.o_proj(attn_output)
139
+
140
+ return attn_output, None, past_key_value
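
Both patched forwards (the flash-attn one above and the SDP one here) now call transformers' repeat_kv before attention, which is what makes grouped-query models such as Llama-2 70B work: the shared K/V heads are expanded to match the number of query heads. A minimal sketch of that shape transformation (illustrative only, not the transformers implementation; the head counts are made up):

    import torch

    def repeat_kv_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
        """(bsz, num_kv_heads, seqlen, head_dim) -> (bsz, num_kv_heads * n_rep, seqlen, head_dim)"""
        if n_rep == 1:
            return kv
        bsz, num_kv_heads, seqlen, head_dim = kv.shape
        kv = kv[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seqlen, head_dim)
        return kv.reshape(bsz, num_kv_heads * n_rep, seqlen, head_dim)

    # e.g. 8 KV heads shared across 64 query heads -> n_rep = num_key_value_groups = 8
    k = torch.randn(1, 8, 16, 128)
    print(repeat_kv_sketch(k, 8).shape)  # torch.Size([1, 64, 16, 128])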
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
3
  """
4
 
5
  import logging
6
- import math
7
  from typing import Optional, Tuple
8
 
9
  import torch
10
  import torch.nn.functional as F
11
  import transformers.models.llama.modeling_llama
12
- from torch import nn
13
 
14
  try:
15
  import xformers.ops
@@ -21,12 +21,6 @@ def hijack_llama_attention():
21
  transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
22
 
23
 
24
- def hijack_llama_sdp_attention():
25
- transformers.models.llama.modeling_llama.LlamaAttention.forward = (
26
- sdp_attention_forward
27
- )
28
-
29
-
30
  def xformers_forward(
31
  self,
32
  hidden_states: torch.Tensor,
@@ -81,15 +75,15 @@ def xformers_forward(
81
  value_states = value_states.view(
82
  bsz, q_len, self.num_key_value_heads, self.head_dim
83
  ).transpose(1, 2)
 
 
84
 
85
  kv_seq_len = key_states.shape[-2]
86
  if past_key_value is not None:
87
  kv_seq_len += past_key_value[0].shape[-2]
 
88
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
89
- (
90
- query_states,
91
- key_states,
92
- ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
93
  query_states, key_states, cos, sin, position_ids
94
  )
95
  # [bsz, nh, t, hd]
@@ -102,74 +96,50 @@ def xformers_forward(
102
  past_key_value = (key_states, value_states) if use_cache else None
103
 
104
  # repeat k/v heads if n_kv_heads < n_heads
105
- key_states = transformers.models.llama.modeling_llama.repeat_kv(
106
- key_states, self.num_key_value_groups
107
- )
108
- value_states = transformers.models.llama.modeling_llama.repeat_kv(
109
- value_states, self.num_key_value_groups
110
- )
111
 
112
- # We only apply xformers optimizations if we don't need to output the whole attention matrix
113
- if not output_attentions:
114
- query_states = query_states.transpose(1, 2)
115
- key_states = key_states.transpose(1, 2)
116
- value_states = value_states.transpose(1, 2)
117
-
118
- # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
119
- # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
120
- if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
121
- # input and output should be of form (bsz, q_len, num_heads, head_dim)
122
- attn_output = xformers.ops.memory_efficient_attention(
123
- query_states, key_states, value_states, attn_bias=None
124
- )
125
- else:
126
- # input and output should be of form (bsz, q_len, num_heads, head_dim)
127
- attn_output = xformers.ops.memory_efficient_attention(
128
- query_states,
129
- key_states,
130
- value_states,
131
- # attn_bias=attention_mask,
132
- attn_bias=xformers.ops.LowerTriangularMask(),
133
- )
134
- attn_weights = None
135
- else:
136
- attn_weights = torch.matmul(
137
- query_states, key_states.transpose(2, 3)
138
- ) / math.sqrt(self.head_dim)
139
-
140
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
141
- raise ValueError(
142
- f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
143
- f" {attn_weights.size()}"
144
- )
145
-
146
- if attention_mask is not None:
147
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
148
- raise ValueError(
149
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
150
- )
151
- attn_weights = attn_weights + attention_mask
152
- attn_weights = torch.max(
153
- attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
154
- )
155
 
156
- # upcast attention to fp32
157
- attn_weights = nn.functional.softmax(
158
- attn_weights, dim=-1, dtype=torch.float32
159
- ).to(query_states.dtype)
160
- attn_output = torch.matmul(attn_weights, value_states)
161
 
162
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
163
- raise ValueError(
164
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
165
- f" {attn_output.size()}"
166
- )
167
 
168
- attn_output = attn_output.transpose(1, 2).contiguous()
169
- # end x-formers vs. not x-formers if-else block
170
 
171
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
172
 
173
  if self.pretraining_tp > 1:
174
  attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
175
  o_proj_slices = self.o_proj.weight.split(
@@ -182,103 +152,4 @@ def xformers_forward(
182
  else:
183
  attn_output = self.o_proj(attn_output)
184
 
185
- return attn_output, attn_weights, past_key_value
186
-
187
-
188
- def sdp_attention_forward(
189
- self,
190
- hidden_states: torch.Tensor,
191
- attention_mask: Optional[torch.Tensor] = None,
192
- position_ids: Optional[torch.LongTensor] = None,
193
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
194
- output_attentions: bool = False,
195
- use_cache: bool = False,
196
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
197
- # pylint: disable=duplicate-code
198
- bsz, q_len, _ = hidden_states.size()
199
-
200
- query_states = (
201
- self.q_proj(hidden_states)
202
- .view(bsz, q_len, self.num_heads, self.head_dim)
203
- .transpose(1, 2)
204
- )
205
- key_states = (
206
- self.k_proj(hidden_states)
207
- .view(bsz, q_len, self.num_heads, self.head_dim)
208
- .transpose(1, 2)
209
- )
210
- value_states = (
211
- self.v_proj(hidden_states)
212
- .view(bsz, q_len, self.num_heads, self.head_dim)
213
- .transpose(1, 2)
214
- )
215
-
216
- kv_seq_len = key_states.shape[-2]
217
- if past_key_value is not None:
218
- kv_seq_len += past_key_value[0].shape[-2]
219
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
220
- (
221
- query_states,
222
- key_states,
223
- ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
224
- query_states, key_states, cos, sin, position_ids
225
- )
226
- # [bsz, nh, t, hd]
227
-
228
- if past_key_value is not None:
229
- # reuse k, v, self_attention
230
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
231
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
232
-
233
- past_key_value = (key_states, value_states) if use_cache else None
234
-
235
- # We only apply sdp attention if we don't need to output the whole attention matrix
236
- if not output_attentions:
237
- with torch.backends.cuda.sdp_kernel():
238
- attn_output = torch.nn.functional.scaled_dot_product_attention(
239
- query_states,
240
- key_states,
241
- value_states,
242
- attn_mask=attention_mask,
243
- is_causal=False,
244
- )
245
- attn_weights = None
246
- else:
247
- attn_weights = torch.matmul(
248
- query_states, key_states.transpose(2, 3)
249
- ) / math.sqrt(self.head_dim)
250
-
251
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
252
- raise ValueError(
253
- f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
254
- f" {attn_weights.size()}"
255
- )
256
-
257
- if attention_mask is not None:
258
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
259
- raise ValueError(
260
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
261
- )
262
- attn_weights = attn_weights + attention_mask
263
- attn_weights = torch.max(
264
- attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
265
- )
266
-
267
- # upcast attention to fp32
268
- attn_weights = nn.functional.softmax(
269
- attn_weights, dim=-1, dtype=torch.float32
270
- ).to(query_states.dtype)
271
- attn_output = torch.matmul(attn_weights, value_states)
272
-
273
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
274
- raise ValueError(
275
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
276
- f" {attn_output.size()}"
277
- )
278
-
279
- attn_output = attn_output.transpose(1, 2)
280
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
281
-
282
- attn_output = self.o_proj(attn_output)
283
-
284
- return attn_output, attn_weights, past_key_value
 
3
  """
4
 
5
  import logging
6
+ import warnings
7
  from typing import Optional, Tuple
8
 
9
  import torch
10
  import torch.nn.functional as F
11
  import transformers.models.llama.modeling_llama
12
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
13
 
14
  try:
15
  import xformers.ops
 
21
  transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
22
 
23
 
24
  def xformers_forward(
25
  self,
26
  hidden_states: torch.Tensor,
 
75
  value_states = value_states.view(
76
  bsz, q_len, self.num_key_value_heads, self.head_dim
77
  ).transpose(1, 2)
78
+ # [bsz, q_len, nh, hd]
79
+ # [bsz, nh, q_len, hd]
80
 
81
  kv_seq_len = key_states.shape[-2]
82
  if past_key_value is not None:
83
  kv_seq_len += past_key_value[0].shape[-2]
84
+
85
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
86
+ query_states, key_states = apply_rotary_pos_emb(
 
 
 
87
  query_states, key_states, cos, sin, position_ids
88
  )
89
  # [bsz, nh, t, hd]
 
96
  past_key_value = (key_states, value_states) if use_cache else None
97
 
98
  # repeat k/v heads if n_kv_heads < n_heads
99
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
100
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
 
 
 
 
101
 
102
+ if output_attentions:
103
+ warnings.warn(
104
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
105
+ )
106
 
107
+ #
108
+ # xformers-attn start
109
+ #
 
 
110
 
111
+ query_states = query_states.transpose(1, 2)
112
+ key_states = key_states.transpose(1, 2)
113
+ value_states = value_states.transpose(1, 2)
 
 
114
 
115
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
116
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
117
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
118
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
119
+ attn_output = xformers.ops.memory_efficient_attention(
120
+ query_states, key_states, value_states, attn_bias=None
121
+ )
122
+ else:
123
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
124
+ attn_output = xformers.ops.memory_efficient_attention(
125
+ query_states,
126
+ key_states,
127
+ value_states,
128
+ # attn_bias=attention_mask,
129
+ attn_bias=xformers.ops.LowerTriangularMask(),
130
+ )
131
 
132
+ if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
133
+ raise ValueError(
134
+ f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
135
+ f" {attn_output.size()}"
136
+ )
137
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
138
 
139
+ #
140
+ # xformers-attn end
141
+ #
142
+
143
  if self.pretraining_tp > 1:
144
  attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
145
  o_proj_slices = self.o_proj.weight.split(
 
152
  else:
153
  attn_output = self.o_proj(attn_output)
154
 
155
+ return attn_output, None, past_key_value
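
After this refactor the xformers forward reduces to a single memory_efficient_attention call in the (batch, seqlen, heads, head_dim) layout that the patch transposes into. A standalone sketch of that call (assumes a CUDA device, fp16 tensors, and a working xformers install; the shapes are made up):

    import torch
    import xformers.ops

    bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64
    query = torch.randn(bsz, seqlen, n_heads, head_dim, device="cuda", dtype=torch.float16)
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    # causal self-attention without materializing the full attention matrix
    attn_output = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=xformers.ops.LowerTriangularMask()
    )
    print(attn_output.shape)  # torch.Size([2, 16, 8, 64])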
src/axolotl/utils/models.py CHANGED
@@ -103,7 +103,7 @@ def load_model(
         )
 
         LOG.info("patching with flash attention")
-        replace_llama_attn_with_flash_attn()
+        replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,
@@ -112,9 +112,7 @@
         LOG.info("patching with xformers attention")
         hijack_llama_attention()
     elif cfg.is_llama_derived_model and cfg.sdp_attention:
-        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-            hijack_llama_sdp_attention,
-        )
+        from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
 
         LOG.info("patching with sdp attention")
         hijack_llama_sdp_attention()