"""
Patches to support multipack for mixtral
"""
import transformers


def replace_mixtral_attn_with_multipack_flash_attn():
    # Deferred import: .modeling_mixtral (and its flash-attn dependency) is
    # only loaded when the patch is actually applied.
    from .modeling_mixtral import (
        MixtralMultipackFlashAttention2,
        mixtral_decoder_layer_forward,
        mixtral_model_forward,
    )

    # Replace the decoder-layer and model forward passes with the
    # multipack-aware versions.
    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
        mixtral_decoder_layer_forward
    )
    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
        mixtral_model_forward
    )
    # Register the multipack flash attention implementation. The registry is
    # (somewhat confusingly) named MISTRAL_ATTENTION_CLASSES in the
    # transformers release this patch targets; newer releases rename it to
    # MIXTRAL_ATTENTION_CLASSES, so this lookup must track the pinned version.
    transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
        "flash_attention_2"
    ] = MixtralMultipackFlashAttention2
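

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): the patch has
# to run before the model is constructed so `from_pretrained` wires up the
# patched attention class. The import path assumes this module lives at
# axolotl.monkeypatch.mixtral, and the checkpoint id is an assumption chosen
# for illustration.
#
#   from axolotl.monkeypatch.mixtral import (
#       replace_mixtral_attn_with_multipack_flash_attn,
#   )
#
#   replace_mixtral_attn_with_multipack_flash_attn()
#   model = transformers.AutoModelForCausalLM.from_pretrained(
#       "mistralai/Mixtral-8x7B-v0.1",
#       attn_implementation="flash_attention_2",
#   )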