"""
Patches to support multipack for mixtral
"""
import transformers

# Candidate names for the attention-implementation registry in transformers'
# Mixtral modeling module. The registry has been published under the
# Mistral-derived name and under a Mixtral-specific name, depending on the
# transformers release; the patch accepts either.
_ATTENTION_REGISTRY_NAMES = ("MIXTRAL_ATTENTION_CLASSES", "MISTRAL_ATTENTION_CLASSES")


def _get_attention_classes(module):
    """Return the attention-class registry dict exposed by ``module``.

    Raises:
        AttributeError: if ``module`` exposes none of the known registry names.
    """
    for attr_name in _ATTENTION_REGISTRY_NAMES:
        registry = getattr(module, attr_name, None)
        if registry is not None:
            return registry
    raise AttributeError(
        f"{module.__name__} has none of {_ATTENTION_REGISTRY_NAMES}; "
        "cannot register the multipack flash-attention class"
    )


def replace_mixtral_attn_with_multipack_flash_attn():
    """Monkeypatch transformers' Mixtral implementation for multipack training.

    Replaces ``MixtralDecoderLayer.forward`` and ``MixtralModel.forward`` with
    the versions from the local ``modeling_mixtral`` module. Registers
    ``MixtralMultipackFlashAttention2`` as the ``"flash_attention_2"``
    attention implementation.

    Raises:
        AttributeError: if the installed transformers Mixtral module exposes no
            known attention-class registry. Nothing is patched in that case.
    """
    # Deferred import: the multipack modeling code is only loaded when the
    # patch is actually requested.
    from .modeling_mixtral import (
        MixtralMultipackFlashAttention2,
        mixtral_decoder_layer_forward,
        mixtral_model_forward,
    )

    # Import the submodule explicitly instead of reaching it via attribute
    # access on the lazily-initialised top-level ``transformers`` package.
    from transformers.models.mixtral import modeling_mixtral

    # Resolve the registry before mutating anything, so a lookup failure does
    # not leave transformers half-patched.
    attention_classes = _get_attention_classes(modeling_mixtral)

    modeling_mixtral.MixtralDecoderLayer.forward = mixtral_decoder_layer_forward
    modeling_mixtral.MixtralModel.forward = mixtral_model_forward
    attention_classes["flash_attention_2"] = MixtralMultipackFlashAttention2