winglian committed
Commit 895f0a0
1 Parent(s): e7d3e2d

skip some flash attn patches unless explicitly enabled (#643)


* skip some flash attn patches unless explicitly enabled

* make the other patches optional

README.md CHANGED
@@ -636,6 +636,8 @@ flash_optimum:
 xformers_attention:
 # whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
 flash_attention:
+flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
+flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
 # whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
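
For reference, a minimal, illustrative sketch of how the two new keys might be set in an axolotl config YAML (values are examples only; both options are advanced, opt-in, and stay off unless explicitly enabled):

```yaml
# illustrative values only - the two new options default to off
flash_attention: true
flash_attn_cross_entropy: true   # opt in to flash-attention's fused CrossEntropyLoss
flash_attn_rms_norm: true        # opt in to flash-attention's fused RMSNorm
```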
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -38,7 +38,11 @@ except ImportError:
 LOG = logging.getLogger("axolotl")


-def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
+def replace_llama_attn_with_flash_attn(
+    packed: Optional[bool] = False,
+    cross_entropy: Optional[bool] = False,
+    rms_norm: Optional[bool] = False,
+):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
@@ -49,33 +53,37 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
         llama_model_forward
     )

-    try:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-        LOG.info("patching with flash_attn.losses.cross_entropy")
-        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-            CrossEntropyLoss, inplace_backward=True
-        )
-    except ImportError:
-        LOG.info(
-            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
-        )
-
-    try:
-        from flash_attn.ops.rms_norm import RMSNorm
-
-        class LlamaRMSNorm(RMSNorm):
-            """Patched LLamaRMSNorm"""
-
-            def __init__(self, hidden_size, eps=1e-6):
-                super().__init__(hidden_size, eps=eps)
-
-        LOG.info("patching with flash_attn.ops.rms_norm")
-        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-    except ImportError:
-        LOG.info(
-            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-        )
+    # skip only if explicitly disabled
+    if cross_entropy:
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+            LOG.info("patching with flash_attn.losses.cross_entropy")
+            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+                CrossEntropyLoss, inplace_backward=True
+            )
+        except ImportError:
+            LOG.info(
+                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+            )
+
+    # skip only if explicitly disabled
+    if rms_norm:
+        try:
+            from flash_attn.ops.rms_norm import RMSNorm
+
+            class LlamaRMSNorm(RMSNorm):
+                """Patched LLamaRMSNorm"""
+
+                def __init__(self, hidden_size, eps=1e-6):
+                    super().__init__(hidden_size, eps=eps)
+
+            LOG.info("patching with flash_attn.ops.rms_norm")
+            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+        except ImportError:
+            LOG.info(
+                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+            )


 # Disable the transformation of the attention mask in LlamaModel as the flash attention
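
For illustration, a minimal sketch of calling the reworked function directly with the new keyword arguments (hypothetical values; in axolotl they are supplied from the config, as the src/axolotl/utils/models.py hunk below shows). A flag left at its False default leaves the corresponding transformers class unpatched:

```python
# Hypothetical direct call against the signature shown above.
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn(
    packed=True,         # flash attention with sample packing support
    cross_entropy=True,  # also patch in flash_attn's fused CrossEntropyLoss (logs a hint if not installed)
    rms_norm=False,      # keep transformers' stock LlamaRMSNorm
)
```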
src/axolotl/utils/models.py CHANGED
@@ -121,7 +121,11 @@ def load_model(
             )

             LOG.info("patching with flash attention for sample packing")
-            replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
+            replace_llama_attn_with_flash_attn(
+                packed=cfg.sample_packing,
+                cross_entropy=cfg.flash_attn_cross_entropy,
+                rms_norm=cfg.flash_attn_rms_norm,
+            )
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,