"""Hugging Face Transformers configuration class for Mamba models."""

from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """Configuration for a Mamba (state space model) network.

    The constructor records the model hyperparameters as instance
    attributes and hands the special-token ids, the embedding-tying flag
    and any extra keyword arguments to ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Stored as ``self.vocab_size``. Presumably the tokenizer
            vocabulary size -- confirm against the model code.
        d_model: Stored as ``self.d_model``. Presumably the hidden width of
            the model -- confirm against the model code.
        n_layer: Stored as ``self.n_layer``. Presumably the number of
            stacked blocks -- confirm against the model code.
        rms_norm: Stored as ``self.rms_norm``. Presumably selects RMSNorm
            over LayerNorm -- confirm against the model code.
        residual_in_fp32: Stored as ``self.residual_in_fp32``. Presumably
            keeps the residual stream in float32 -- confirm.
        fused_add_norm: Stored as ``self.fused_add_norm``. Presumably
            enables a fused add+norm kernel -- confirm.
        pad_vocab_size_multiple: Stored as ``self.pad_vocab_size_multiple``.
            This class does not itself pad ``vocab_size``; any rounding is
            expected to happen in the consumer -- TODO confirm.
        pad_token_id: Forwarded to ``PretrainedConfig.__init__``.
        bos_token_id: Forwarded to ``PretrainedConfig.__init__``.
        eos_token_id: Forwarded to ``PretrainedConfig.__init__``.
        tie_word_embeddings: Forwarded to ``PretrainedConfig.__init__``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size: int = 50280,
        d_model: int = 2560,
        n_layer: int = 64,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        pad_token_id: int = 50277,
        bos_token_id: int = 0,
        eos_token_id: int = 0,
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        # Model-specific settings are set before delegating to the base
        # class, matching the original statement order.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layer = n_layer

        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm

        self.pad_vocab_size_multiple = pad_vocab_size_multiple

        # Token ids, embedding tying and remaining options are handled by
        # the base configuration class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )