nyanko7 commited on
Commit
2a05fdc
1 Parent(s): 5e461b1

Update modules/model.py

Browse files
Files changed (1) hide show
  1. modules/model.py +4 -6
modules/model.py CHANGED
@@ -203,14 +203,14 @@ class CrossAttnProcessor(nn.Module):
203
  k_bucket_size = 1024
204
 
205
  # use flash-attention
206
- hidden_states = FlashAttn.apply(
207
  query.contiguous(),
208
  key.contiguous(),
209
  value.contiguous(),
210
  attention_mask,
211
- causal=False,
212
- q_bucket_size=q_bucket_size,
213
- k_bucket_size=k_bucket_size,
214
  )
215
  hidden_states = hidden_states.to(query.dtype)
216
 
@@ -1021,5 +1021,3 @@ class FlashAttentionFunction(Function):
1021
  dvc.add_(dv_chunk)
1022
 
1023
  return dq, dk, dv, None, None, None, None
1024
-
1025
- FlashAttn = FlashAttentionFunction()
 
203
  k_bucket_size = 1024
204
 
205
  # use flash-attention
206
+ hidden_states = FlashAttentionFunction.apply(
207
  query.contiguous(),
208
  key.contiguous(),
209
  value.contiguous(),
210
  attention_mask,
211
+ False,
212
+ q_bucket_size,
213
+ k_bucket_size,
214
  )
215
  hidden_states = hidden_states.to(query.dtype)
216
 
 
1021
  dvc.add_(dv_chunk)
1022
 
1023
  return dq, dk, dv, None, None, None, None