x54-729 committed
Commit 8586def
Parent: d1913f2

remove unnecessary attention_drop

Files changed (1)
  1. modeling_internlm2.py +1 -3
modeling_internlm2.py CHANGED
@@ -480,10 +480,8 @@ class InternLM2FlashAttention2(InternLM2Attention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        dropout_rate = 0.0 if not self.training else self.attention_dropout
-
         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+            query_states, key_states, value_states, attention_mask, q_len
        )
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.wo(attn_output)
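For context on why the dropout argument could be dropped: the removed line only ever enabled dropout during training, and only if the module's configured attention dropout rate is non-zero. Below is a minimal, self-contained sketch of that pattern, assuming a toy module whose attention_dropout defaults to 0.0; the names here are illustrative and are not the repository's actual classes.

import torch.nn as nn

class ToyAttention(nn.Module):
    # Toy stand-in for an attention module. "attention_dropout" mirrors the
    # attribute referenced by the removed line but is purely illustrative.
    def __init__(self, attention_dropout: float = 0.0):
        super().__init__()
        self.attention_dropout = attention_dropout

    def effective_dropout(self) -> float:
        # The pattern removed in this commit: dropout is disabled at eval
        # time and equals the configured rate during training. With a rate
        # of 0.0, both branches return 0.0, so passing it explicitly to the
        # flash-attention call adds nothing.
        return 0.0 if not self.training else self.attention_dropout

attn = ToyAttention()
attn.train()
assert attn.effective_dropout() == 0.0
attn.eval()
assert attn.effective_dropout() == 0.0

If _flash_attention_forward already defaults its dropout argument to 0.0 (a common convention in Hugging Face Flash Attention 2 integrations, though its signature is not shown in this diff), omitting the argument preserves behavior for a zero dropout rate.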