winglian committed on
Commit c56818b
1 Parent(s): 2675fb7

don't worry about dupes

src/axolotl/flash_attn.py CHANGED
@@ -25,6 +25,7 @@ def forward(
 
     attention_mask: [bsz, q_len]
     """
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -35,6 +35,7 @@ def xformers_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
@@ -143,6 +144,7 @@ def sdp_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
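
Note: pylint's duplicate-code checker (message id R0801) flags blocks of similar lines that recur across modules. The attention forward passes patched here intentionally share their setup boilerplate, so this commit suppresses the report with an inline disable comment rather than rewriting the code. A minimal sketch of the pattern, with hypothetical module and function names (recent pylint releases honor inline disables for this check):

# attn_impl_a.py (hypothetical module; the same block would appear in attn_impl_b.py)
from typing import Tuple

import torch


def forward_a(hidden_states: torch.Tensor) -> Tuple[int, int]:
    # pylint: disable=duplicate-code
    # This setup is repeated verbatim across attention implementations;
    # the inline disable keeps pylint from reporting R0801 for this block.
    bsz, q_len, _ = hidden_states.size()
    return bsz, q_len

Disabling the check per block keeps R0801 active for the rest of the codebase, unlike switching it off globally in the pylint configuration.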