File size: 313 Bytes
7fabc4d
 
 
 
 
6910e6a
7fabc4d
 
6910e6a
 
 
7fabc4d
1
2
3
4
5
6
7
8
9
10
11
12
13
"""
Patches to support multipack for mixtral
"""
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


def replace_mixtral_attn_with_multipack_flash_attn():
    """Swap Mixtral's private unpad helper for axolotl's multipack-aware one.

    Rebinds ``_get_unpad_data`` on ``transformers.models.mixtral.modeling_mixtral``
    so that it points at ``axolotl.monkeypatch.utils.get_unpad_data``. Any code
    in that module that looks the helper up by name after this call will use
    axolotl's version.

    NOTE(review): only the Mixtral module's binding is replaced. Other modules
    that hold their own reference to a ``_get_unpad_data`` function are not
    affected. Confirm this matches the transformers version in use.
    """
    mixtral_modeling = transformers.models.mixtral.modeling_mixtral
    # The target is a private (underscore-prefixed) name, so assign it through
    # setattr.
    setattr(mixtral_modeling, "_get_unpad_data", get_unpad_data)