File size: 339 Bytes
40a6362
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
"""
Modeling module for Mamba models
"""


def fix_mamba_attn_for_loss():
    """
    Replace ``mamba_ssm``'s ``MambaLMHeadModel`` with the local implementation.

    The class defined in the sibling ``modeling_mamba`` module is assigned onto
    ``mamba_ssm.models.mixer_seq_simple`` under the name ``MambaLMHeadModel``,
    so later lookups of that attribute on the module resolve to the local
    class. Both imports run at call time rather than at module import time.

    Returns:
        The object bound to ``mixer_seq_simple.MambaLMHeadModel`` after the
        reassignment.
    """
    # Third-party module whose attribute is being overridden.
    from mamba_ssm.models import mixer_seq_simple as upstream_module

    # Local replacement class shipped alongside this package.
    from .modeling_mamba import MambaLMHeadModel as replacement_cls

    upstream_module.MambaLMHeadModel = replacement_cls

    # Read the attribute back from the module rather than returning the local
    # alias, so the caller receives exactly what the module now exposes.
    patched_cls = upstream_module.MambaLMHeadModel
    return patched_cls