tmm1 committed
Commit 5fe30b1
1 Parent(s): 44454ae

use flash_attn xentropy when available (#525)


* use flash_attn xentropy when available

* log when xentropy is not found
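
For context, a minimal sketch of the loss this commit opts into: flash-attn's fused cross-entropy is a drop-in replacement for `torch.nn.CrossEntropyLoss`. This example is not part of the commit; it assumes flash-attn is installed with its xentropy CUDA extension, a CUDA device is available, and the tensor shapes are illustrative.

```python
# Sketch (assumptions noted above): use flash-attn's fused cross-entropy directly.
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

# inplace_backward=True lets the backward pass reuse the logits buffer for gradients,
# which is the same option the patch below passes via functools.partial.
loss_fct = CrossEntropyLoss(inplace_backward=True)

vocab_size = 32000  # illustrative
logits = torch.randn(8, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, vocab_size, (8,), device="cuda")

loss = loss_fct(logits, labels)  # same call signature as torch.nn.CrossEntropyLoss
loss.backward()
```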

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -2,7 +2,9 @@
 
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
+import logging
 import warnings
+from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -33,6 +35,9 @@ except ImportError:
     )
 
 
+LOG = logging.getLogger("axolotl")
+
+
 def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
@@ -44,6 +49,18 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
         llama_model_forward
     )
 
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy`)"
+        )
+
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
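
A quick way to confirm the patch took effect is to inspect the `CrossEntropyLoss` symbol that `modeling_llama` resolves after calling the patch function. A minimal sketch, assuming axolotl and flash-attn (with the xentropy extension) are installed so the import inside the `try` block succeeds:

```python
# Sketch: verify that modeling_llama now builds its loss from flash-attn's fused
# CrossEntropyLoss wrapped in functools.partial(..., inplace_backward=True).
import functools

import transformers
from axolotl.monkeypatch.llama_attn_hijack_flash import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()

loss_cls = transformers.models.llama.modeling_llama.CrossEntropyLoss
if isinstance(loss_cls, functools.partial):
    # patched: func is flash-attn's CrossEntropyLoss, keywords include inplace_backward=True
    print("patched:", loss_cls.func, loss_cls.keywords)
else:
    # not patched: still the stock torch.nn.CrossEntropyLoss
    print("not patched:", loss_cls)
```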