winglian committed
Commit 9b790d3
Parent: 06c61d6

flash attention 2

docker/Dockerfile-base CHANGED
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

 RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
-    git checkout v1.0.9 && \
+    git checkout v2.0.0 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
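
Since the base image now builds FlashAttention from the v2.0.0 tag instead of v1.0.9, a stale or cached wheel is the most likely failure mode. A hypothetical smoke test, not part of this commit, that could be run inside the built image to confirm the installed wheel is actually 2.x:

# Hypothetical smoke test (not in this commit): verify the installed
# flash-attn wheel is FlashAttention 2.x before using the image.
import flash_attn

assert flash_attn.__version__.startswith("2."), flash_attn.__version__
print("flash-attn", flash_attn.__version__)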
src/axolotl/flash_attn.py CHANGED
@@ -8,7 +8,7 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


@@ -79,7 +79,7 @@ def forward(
             dtype=torch.int32,
             device=qkv.device,
         )
-        output = flash_attn_unpadded_qkvpacked_func(
+        output = flash_attn_varlen_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
@@ -95,7 +95,7 @@ def forward(
             three=3,
             h=nheads,
         )
-        output_unpad = flash_attn_unpadded_qkvpacked_func(
+        output_unpad = flash_attn_varlen_qkvpacked_func(
             x_unpad,
             cu_q_lens,
             max_s,
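
FlashAttention 2.0.0 renamed the unpadded packed-QKV entry point from flash_attn_unpadded_qkvpacked_func to flash_attn_varlen_qkvpacked_func; the positional arguments the patched forward() passes are unchanged. A minimal standalone sketch of the renamed call, with illustrative tensor shapes that are not taken from this commit (requires flash-attn >= 2.0.0, a CUDA device, and fp16/bf16 inputs):

# Minimal sketch (not part of this commit) of the renamed v2 varlen API.
# All sizes below are illustrative assumptions.
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

bsz, seqlen, nheads, headdim = 2, 128, 8, 64

# Packed QKV for every token in the batch: (total_tokens, 3, nheads, headdim)
qkv = torch.randn(
    bsz * seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda"
)

# Cumulative sequence lengths marking each sequence boundary: (bsz + 1,)
cu_seqlens = torch.arange(
    0, (bsz + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda"
)

# Same positional arguments as the patched forward() above:
# (qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal)
output = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens, seqlen, 0.0, softmax_scale=None, causal=True
)
# output: (total_tokens, nheads, headdim)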