winglian committed on
Commit a032c9f (1 parent: b06d3e3)

fix sdp attention to use the flash/mem-efficient context manager

src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -184,14 +184,15 @@ def sdp_attention_forward(

     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
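
For reference, the pattern this patch introduces can be exercised in isolation. The sketch below is not part of the repository; the tensor shapes, dtype, and backend flags are illustrative assumptions. torch.backends.cuda.sdp_kernel() (since superseded by torch.nn.attention.sdpa_kernel in newer PyTorch releases) is a context manager that controls which scaled-dot-product-attention backends may run. Called with no arguments, as in the diff above, it leaves the flash, memory-efficient, and math backends all enabled and lets PyTorch pick; the keyword flags shown below would restrict the choice to the fused kernels only.

# Standalone sketch (assumes a CUDA device and PyTorch 2.x);
# q/k/v shapes and dtype are hypothetical, not taken from axolotl.
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the fused flash / memory-efficient kernels; the patched
# code calls sdp_kernel() with defaults, which keeps the math fallback enabled too.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=True
):
    attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

print(attn_output.shape)  # torch.Size([2, 8, 128, 64])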