File size: 305 Bytes
f5a828a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
"""
Monkeypatches that enable multipack support for Qwen2 models in Hugging Face ``transformers``.
"""
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


def replace_qwen2_attn_with_multipack_flash_attn():
    """Point Qwen2's unpadding helper at axolotl's multipack-aware version.

    Rebinds the module-level name ``_get_unpad_data`` inside
    ``transformers.models.qwen2.modeling_qwen2`` so that it refers to
    ``axolotl.monkeypatch.utils.get_unpad_data``.

    Despite the function name, the attention implementation itself is not
    replaced; only this one helper is swapped. The patch is global to the
    process. Code that resolves ``_get_unpad_data`` from the module namespace
    after this call sees the replacement. References captured earlier (e.g.
    ``from ... import _get_unpad_data``) are unaffected.
    """
    qwen2_modeling = transformers.models.qwen2.modeling_qwen2
    # Overwriting a private transformers helper is the whole point of this patch.
    qwen2_modeling._get_unpad_data = get_unpad_data  # pylint: disable=protected-access