jrc, joecummings, and winglian committed
Commit 1d70f24
Parent: 317fa25

Add shifted sparse attention (#973) [skip-ci]


* Add s2_attn to hijack flash code

* Refactor code to account for s2_attn

* Add test for models utils

* Add ``s2_attention`` option to llama configs

* Add ``s2_attention`` option to README config

* Format code to appease linter

* chore: lint

* Remove xpos and llama-landmark [bad merge]

* add e2e smoke tests for shifted sparse attention

* remove stray patch from merge

* update yml with link to paper for s2_attention/longlora

* fix assertion check for full fine tune

* increase sequence len for tests and PR feedback updates

* reduce context len to 16k for tests

* reduce context len to 16k for tests

* reduce batch size for larger context len and update test to check message

* fix test for message

---------

Co-authored-by: joecummings <jrcummings@devvm050.nha0.facebook.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>

README.md CHANGED
@@ -834,7 +834,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
 # Whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
-
+# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
+s2_attention:
 # Resume from a specific checkpoint dir
 resume_from_checkpoint:
 # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
examples/code-llama/13b/lora.yml CHANGED
@@ -52,6 +52,7 @@ local_rank:
 logging_steps: 1
 xformers_attention:
 flash_attention: true
+s2_attention:

 warmup_steps: 10
 evals_per_epoch: 4
examples/code-llama/34b/lora.yml CHANGED
@@ -52,6 +52,7 @@ local_rank:
 logging_steps: 1
 xformers_attention:
 flash_attention: true
+s2_attention:

 warmup_steps: 10
 evals_per_epoch: 4
examples/code-llama/7b/lora.yml CHANGED
@@ -52,6 +52,7 @@ local_rank:
 logging_steps: 1
 xformers_attention:
 flash_attention: true
+s2_attention:

 warmup_steps: 10
 evals_per_epoch: 4
examples/llama-2/lora.yml CHANGED
@@ -52,6 +52,7 @@ local_rank:
 logging_steps: 1
 xformers_attention:
 flash_attention: true
+s2_attention:

 warmup_steps: 10
 evals_per_epoch: 4
examples/openllama-3b/lora.yml CHANGED
@@ -52,6 +52,7 @@ logging_steps: 1
 xformers_attention:
 flash_attention: true
 gptq_groupsize:
+s2_attention:
 gptq_model_v1:
 warmup_steps: 20
 evals_per_epoch: 4
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -70,11 +70,20 @@ def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
     rms_norm: Optional[bool] = False,
+    use_shifted_sparse_attn: Optional[bool] = False,
 ):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+    if use_shifted_sparse_attn:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+            flashattn_forward_with_s2attn
+        )
+    else:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+            flashattn_forward
+        )
+
     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -213,6 +222,136 @@ def _prepare_decoder_attention_mask(
     return attention_mask


+GROUP_SIZE_RATIO = 1 / 4
+
+
+def flashattn_forward_with_s2attn(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
+    cu_seqlens: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+    max_seqlen: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel
+
+    From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
+
+    attention_mask: [bsz, q_len]
+
+    `cu_seqlens` will be ignored if provided
+    `max_seqlen` will be ignored if provided
+    """
+    if output_attentions:
+        warnings.warn(
+            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    # [bsz, q_len, nh, hd]
+    # [bsz, nh, q_len, hd]
+    # pylint: disable=duplicate-code
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+
+    # Past Key value support
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    # Flash attention codes from
+    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+    # transform the data into the format required by flash attention
+    qkv = torch.stack(
+        [query_states, key_states, value_states], dim=2
+    )  # [bsz, nh, 3, q_len, hd]
+    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+
+    # We have disabled _prepare_decoder_attention_mask in LlamaModel
+    # the attention_mask should be the same as the key_padding_mask
+
+    key_padding_mask = attention_mask.repeat(2, 1)
+    nheads = qkv.shape[-2]
+    # shift
+
+    group_size = int(q_len * GROUP_SIZE_RATIO)
+    if q_len % group_size > 0:
+        raise ValueError(
+            f"q_len {q_len} should be divisible by group size {group_size}."
+        )
+
+    qkv = (
+        qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim)
+        .permute(0, 3, 1, 2, 4, 5)
+        .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim)
+    )
+    x = rearrange(  # pylint: disable=invalid-name
+        qkv, "b s three h d -> b s (three h d)"
+    )
+    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+    cu_q_len_tmp = torch.arange(
+        0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype
+    )
+    cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(
+        bsz, 1
+    ) + cu_q_lens[:-1].unsqueeze(-1)
+    cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
+
+    x_unpad = rearrange(
+        x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
+    )
+    output_unpad = flash_attn_varlen_qkvpacked_func(
+        x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
+    )
+    output = rearrange(
+        pad_input(
+            rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
+        ),
+        "b s (h d) -> b s h d",
+        h=nheads // 2,
+    )
+    output = (
+        output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim)
+        .transpose(1, 2)
+        .reshape(bsz, q_len, nheads, self.head_dim)
+    )
+    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
+
+
 def flashattn_forward(
     self,
     hidden_states: torch.Tensor,
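
The patched `flashattn_forward_with_s2attn` above implements the LongLoRA shift with flash-attn's varlen API: the heads are split into two groups across a doubled batch, and the second group's `cu_seqlens` are offset by `group_size // 2`, so every token attends only within its (possibly shifted) group while the shifted heads carry information across group boundaries. As a companion to the diff, here is a minimal roll-based sketch of the same idea on plain tensors, following the formulation in the LongLoRA paper; the function and variable names are illustrative, and causal masking, padding, and rotary embeddings are omitted.

```python
import torch


def s2_attention_sketch(q, k, v, group_size):
    """Toy shifted-sparse attention: q, k, v are [bsz, n_heads, seq_len, head_dim]."""
    bsz, n_heads, seq_len, head_dim = q.shape
    assert seq_len % group_size == 0
    half, offset = n_heads // 2, group_size // 2

    def shift(x, sign):
        # roll the sequence dimension for the second half of the heads only
        x = x.clone()
        x[:, half:] = torch.roll(x[:, half:], shifts=sign * offset, dims=2)
        return x

    q, k, v = (shift(t, -1) for t in (q, k, v))

    def group(x):
        # [bsz, nh, n_groups, group_size, hd]: each group attends only within itself
        return x.reshape(bsz, n_heads, seq_len // group_size, group_size, head_dim)

    scores = group(q) @ group(k).transpose(-1, -2) / head_dim**0.5
    out = (torch.softmax(scores, dim=-1) @ group(v)).reshape(
        bsz, n_heads, seq_len, head_dim
    )
    # undo the shift so outputs line up with the input positions again
    return shift(out, 1)


# usage: 8 heads, 1024 tokens, group_size = seq_len * GROUP_SIZE_RATIO = 256
q = k = v = torch.randn(2, 8, 1024, 64)
print(s2_attention_sketch(q, k, v, group_size=256).shape)  # [2, 8, 1024, 64]
```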
src/axolotl/utils/models.py CHANGED
@@ -256,31 +256,55 @@ def load_model(

         replace_stablelm_attn_with_flash_attn(cfg.base_model)

-    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
-        if cfg.device not in ["mps", "cpu"] and not inference:
-            from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                replace_llama_attn_with_flash_attn,
-            )
-
-            LOG.info("patching with flash attention for sample packing")
-            replace_llama_attn_with_flash_attn(
-                packed=cfg.sample_packing,
-                cross_entropy=cfg.flash_attn_cross_entropy,
-                rms_norm=cfg.flash_attn_rms_norm,
-            )
-    elif cfg.is_llama_derived_model and cfg.xformers_attention:
-        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-            hijack_llama_attention,
-        )
-
-        LOG.info("patching with xformers attention")
-        hijack_llama_attention()
-    elif cfg.is_llama_derived_model and cfg.sdp_attention:
-        from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
-
-        LOG.info("patching with sdp attention")
-        hijack_llama_sdp_attention()
+    if cfg.sample_packing and cfg.s2_attention:
+        raise ValueError(
+            "Received `sample_packing=true` and `s2_attention=true`; however, \
+            shifted-sparse attention does not currently support sample packing."
+        )
+
+    # Modify all llama derived models in one block
+    if cfg.is_llama_derived_model:
+        if cfg.flash_attention:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
+
+            if cfg.sample_packing:
+                if cfg.device not in ["mps", "cpu"] and not inference:
+                    LOG.info("patching with flash attention for sample packing")
+                    replace_llama_attn_with_flash_attn(
+                        packed=True,
+                        cross_entropy=cfg.flash_attn_cross_entropy,
+                        rms_norm=cfg.flash_attn_rms_norm,
+                    )
+            elif cfg.s2_attention:
+                LOG.info("patching w/ flash-enabled, shifted-sparse attention")
+                replace_llama_attn_with_flash_attn(
+                    packed=False,
+                    cross_entropy=cfg.flash_attn_cross_entropy,
+                    rms_norm=cfg.flash_attn_rms_norm,
+                    use_shifted_sparse_attn=True,
+                )
+        elif cfg.xformers_attention:
+            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
+                hijack_llama_attention,
+            )
+
+            LOG.info("patching with xformers attention")
+            hijack_llama_attention()
+        elif cfg.sdp_attention:
+            from axolotl.monkeypatch.llama_attn_hijack_sdp import (
+                hijack_llama_sdp_attention,
+            )
+
+            LOG.info("patching with sdp attention")
+            hijack_llama_sdp_attention()
+        elif cfg.s2_attention:
+            raise NotImplementedError(
+                "Shifted-sparse attention not currently implemented without flash attention."
+            )

+    # Modify mistral derived models
     if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
@@ -387,9 +411,12 @@ def load_model(
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             **bnb_config,
         )
+
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
         if not cfg.sample_packing:
+            if cfg.s2_attention:
+                pass
             if (
                 cfg.is_llama_derived_model
                 or cfg.is_falcon_derived_model
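
The new branch in `load_model` above only applies the shifted-sparse patch when `flash_attention` and `s2_attention` are both set and sample packing is disabled. For reference, a direct call to the patch added in this commit would look roughly like the sketch below (normally `load_model` does this for you based on the YAML config); `cross_entropy` and `rms_norm` are left at their defaults here rather than read from a config.

```python
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

# Swap LlamaAttention.forward for the flash-enabled, shifted-sparse variant.
# Must run before the llama model is instantiated so the patched forward is used.
replace_llama_attn_with_flash_attn(
    packed=False,
    cross_entropy=False,
    rms_norm=False,
    use_shifted_sparse_attn=True,
)
```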
tests/e2e/patched/test_llama_s2_attention.py ADDED
@@ -0,0 +1,111 @@
+"""
+E2E tests for llama w/ S2 attn
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestLlamaShiftedSparseAttention(unittest.TestCase):
+    """
+    Test case for Llama models using S2 Attn
+    """
+
+    @with_temp_dir
+    def test_lora_s2_attn(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 16384,
+                "sample_packing": False,
+                "flash_attention": True,
+                "s2_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "Yukang/LongAlpaca-12k",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 10,
+                "save_steps": 5,
+                "eval_steps": 5,
+                "bf16": "auto",
+            }
+        )
+
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_fft_s2_attn(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 16384,
+                "sample_packing": False,
+                "flash_attention": True,
+                "s2_attention": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "Yukang/LongAlpaca-12k",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 10,
+                "save_steps": 5,
+                "eval_steps": 5,
+                "bf16": "auto",
+            }
+        )
+
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/utils/test_models.py ADDED
@@ -0,0 +1,37 @@
+"""Module for testing models utils file."""
+
+
+import unittest
+from unittest.mock import patch
+
+import pytest
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model
+
+
+class ModelsUtilsTest(unittest.TestCase):
+    """Testing module for models utils."""
+
+    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
+        cfg = DictDefault(
+            {
+                "s2_attention": True,
+                "sample_packing": True,
+                "base_model": "",
+                "model_type": "LlamaForCausalLM",
+            }
+        )
+
+        # Mock out call to HF hub
+        with patch(
+            "axolotl.utils.models.load_model_config"
+        ) as mocked_load_model_config:
+            mocked_load_model_config.return_value = {}
+            with pytest.raises(ValueError) as exc:
+                # Should error before hitting tokenizer, so we pass in an empty str
+                load_model(cfg, tokenizer="")
+            assert (
+                "shifted-sparse attention does not currently support sample packing"
+                in str(exc.value)
+            )