winglian committed
Commit 2d60ba3
Parent: eb480df

flash_attention + sample packing for stablelm 3b (#671)


* stablelm epoch flash attention patch

* set causal=True for flash attention

* working stablelm flash attention with sample packing

* chore: pre-commit linting
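
The gist of the change: the remote `modeling_stablelm_epoch` module shipped with stabilityai/stablelm-3b-4e1t gets its `Attention`, `DecoderLayer`, and `StableLMEpochModel` forwards monkeypatched so that, when sample packing is enabled, attention runs through flash-attn's varlen kernel instead of the stock eager path. A rough sketch of applying the patch by hand, mirroring only the model_type gate this commit adds to `load_model` below (illustrative, not part of the commit; the cfg.flash_attention / cfg.sample_packing checks are dropped here):

# Illustrative sketch -- mirrors the model_type gate added to
# src/axolotl/utils/models.py, not axolotl's implementation.
from transformers import AutoConfig, AutoModelForCausalLM

base_model = "stabilityai/stablelm-3b-4e1t"
model_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)

if getattr(model_config, "model_type", None) == "stablelm_epoch":
    from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
        replace_stablelm_attn_with_flash_attn,
    )

    # patch the remote module's forwards before the real model is instantiated
    replace_stablelm_attn_with_flash_attn(base_model)

model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True)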

src/axolotl/monkeypatch/btlm_attn_hijack_flash.py CHANGED
@@ -7,6 +7,7 @@ import logging
 from typing import Optional, Tuple
 
 import torch
+from accelerate import init_empty_weights
 from flash_attn.flash_attn_interface import flash_attn_func
 from transformers import AutoConfig, AutoModelForCausalLM
 
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
     # this is a wonky hack to get the remotely loaded module
     model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     # we need to load the model here in order for modeling_btlm to be available
-    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
     module_name = model_config.__class__.__module__.replace(
         ".configuration_btlm", ".modeling_btlm"
     )
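
For the btlm patch, the only change is wrapping the throwaway `from_pretrained` call in accelerate's `init_empty_weights` context manager: the call exists purely so the remote `modeling_btlm` module gets imported and registered, and under `init_empty_weights` that no longer materializes a full copy of the weights. A minimal stand-alone sketch of the effect (assumes `accelerate` is installed; it uses `from_config` for brevity, whereas the patch keeps `from_pretrained`):

# Minimal sketch: under init_empty_weights, parameters are created on the
# "meta" device, so constructing the model costs essentially no memory.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("cerebras/btlm-3b-8k-base", trust_remote_code=True)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

print(next(model.parameters()).device)  # meta
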
src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py ADDED
@@ -0,0 +1,415 @@
+# coding=utf-8
+# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This code is based off the following work:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+""" PyTorch StableLM Epoch model. """
+import importlib
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from accelerate import init_empty_weights
+from einops import rearrange
+from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+    flash_attn_varlen_qkvpacked_func,
+)
+from torch import nn
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils import logging
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
+logger = logging.get_logger(__name__)
+
+
+def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
+    # this is a wonky hack to get the remotely loaded module
+    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    # we need to load the model here in order for modeling_stablelm_epoch to be available
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    module_name = model_config.__class__.__module__.replace(
+        ".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
+    )
+    modeling_stablelm = importlib.import_module(module_name)
+    modeling_stablelm.Attention.forward = (  # pylint: disable=protected-access
+        flashattn_attn
+    )
+    modeling_stablelm.StableLMEpochModel.forward = (  # pylint: disable=protected-access
+        stablelm_model_forward
+    )
+    modeling_stablelm.DecoderLayer.forward = (  # pylint: disable=protected-access
+        decoder_layer_forward
+    )
+
+
+def rotate_half(x: torch.Tensor):
+    """Rotates half the hidden dims of the input."""
+    # pylint: disable=invalid-name
+    x1, x2 = torch.chunk(x, 2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    # pylint: disable=invalid-name
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def flashattn_attn(
+    self,
+    hidden_states: torch.FloatTensor,
+    attention_mask: torch.FloatTensor,
+    position_ids: torch.LongTensor,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,  # pylint: disable=unused-argument
+    use_cache: Optional[bool] = False,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+
+    query_rot = query_states[..., : self.rotary_ndims]
+    query_pass = query_states[..., self.rotary_ndims :]
+    key_rot = key_states[..., : self.rotary_ndims]
+    key_pass = key_states[..., self.rotary_ndims :]
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_rot, key_rot, cos, sin, position_ids
+    )
+
+    # [batch_size, num_heads, seq_len, head_dim]
+    query_states = torch.cat((query_states, query_pass), dim=-1)
+    key_states = torch.cat((key_states, key_pass), dim=-1)
+
+    if past_key_value is not None:
+        # Reuse k, v, self_attention
+        key_states = torch.cat((past_key_value[0], key_states), dim=2)
+        value_states = torch.cat((past_key_value[1], value_states), dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # Repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
+        # special handling using sample packing
+        qkv = torch.stack(
+            [query_states, key_states, value_states], dim=2
+        )  # [bsz, nh, 3, q_len, hd]
+        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+        softmax_scale = None
+
+        output = flash_attn_varlen_qkvpacked_func(
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
+        )
+
+        attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+        attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
+    else:
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # Upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        # Merge heads
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    # Final linear projection
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
+def decoder_layer_forward(
+    self,
+    hidden_states: Optional[torch.FloatTensor],
+    attention_mask: Optional[torch.FloatTensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[torch.Tensor] = None,
+) -> Union[
+    Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
+]:
+    # pylint: disable=duplicate-code
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        cu_seqlens=cu_seqlens,
+        max_seqlen=max_seqlen,
+    )
+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.mlp(hidden_states)
+    hidden_states = residual + hidden_states
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
+
+
+def stablelm_model_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    # pylint: disable=duplicate-code
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # Retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError(
+            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+        )
+    if input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        raise ValueError(
+            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+        )
+
+    seq_length_with_past = seq_length
+    past_key_values_length = 0
+
+    if past_key_values is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+        seq_length_with_past = seq_length_with_past + past_key_values_length
+
+    cu_seqlens = None
+    max_seqlen = None
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length,
+            seq_length + past_key_values_length,
+            dtype=torch.long,
+            device=device,
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    else:
+        position_ids = position_ids.view(-1, seq_length).long()
+        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
+        cu_seqlens = cu_seqlens.squeeze()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+    # Embed positions
+    if attention_mask is None:
+        attention_mask = torch.ones(
+            (batch_size, seq_length_with_past),
+            dtype=torch.bool,
+            device=inputs_embeds.device,
+        )
+    attention_mask = (
+        self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+        )
+    )
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+    # Decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    for idx, decoder_layer in enumerate(self.layers):
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                None,
+                cu_seqlens,
+                max_seqlen,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # Add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(
+            v
+            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+            if v is not None
+        )
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
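
Note how the packing plumbing works in the file above: when the caller passes `position_ids`, `stablelm_model_forward` derives `cu_seqlens` and `max_seqlen` from them via `get_cu_seqlens_from_pos_ids` and threads both through `DecoderLayer` down to the attention, which then takes the `flash_attn_varlen_qkvpacked_func` branch. The trick is that packed samples restart their position ids at 0, so sample boundaries are recoverable from `position_ids` alone. A stand-alone illustration of that idea (not axolotl's `get_cu_seqlens_from_pos_ids`):

# Illustrative only: recover cumulative sequence lengths from packed
# position ids, where each packed sample restarts its positions at 0.
import torch


def cu_seqlens_from_pos_ids(position_ids: torch.Tensor):
    pos = position_ids.view(-1)
    starts = torch.nonzero(pos == 0).view(-1)  # one start index per packed sample
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
    seqlens = ends - starts
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0).to(torch.int32)]
    )
    return cu_seqlens, int(seqlens.max())


# three samples of lengths 4, 3 and 5 packed into a single row
pos = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]])
print(cu_seqlens_from_pos_ids(pos))  # (tensor([0, 4, 7, 12], dtype=torch.int32), 5)
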
src/axolotl/utils/models.py CHANGED
@@ -124,6 +124,17 @@ def load_model(
 
         replace_btlm_attn_with_flash_attn(cfg.base_model)
 
+    if (
+        hasattr(model_config, "model_type")
+        and model_config.model_type == "stablelm_epoch"
+    ):
+        if cfg.flash_attention and cfg.sample_packing:
+            from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
+                replace_stablelm_attn_with_flash_attn,
+            )
+
+            replace_stablelm_attn_with_flash_attn(cfg.base_model)
+
     if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
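
Once this gate fires and the forwards are patched, the packed attention path ultimately calls `flash_attn_varlen_qkvpacked_func` on a flattened `(total_tokens, 3, n_heads, head_dim)` tensor plus the `cu_seqlens` boundaries. A hedged sketch of that call shape (needs a CUDA device with flash-attn installed; the sizes are made up for illustration and reuse the toy cu_seqlens from the earlier example):

# Sketch of the varlen call made by the patched attention when
# cu_seqlens/max_seqlen are available; shapes are illustrative.
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

n_heads, head_dim = 32, 80  # e.g. 32 heads of dim 80
cu_seqlens = torch.tensor([0, 4, 7, 12], dtype=torch.int32, device="cuda")
qkv = torch.randn(12, 3, n_heads, head_dim, dtype=torch.float16, device="cuda")

out = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen=5, dropout_p=0.0, causal=True
)
print(out.shape)  # torch.Size([12, 32, 80]); attention never crosses sample boundaries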