# Uploaded by yuekai with huggingface_hub
# ("Upload folder using huggingface_hub", commit 7576105, verified).
from typing import Optional

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.trainer_pt_utils import LabelSmoother
# Label id to be ignored in loss computation; reuses transformers'
# LabelSmoother.ignore_index so it stays in sync with the trainer utilities.
IGNORE_TOKEN_ID: int = LabelSmoother.ignore_index
# Literal text of the speech placeholder token. Not referenced in this script;
# presumably it marks where speech features go in a prompt (confirm with callers).
DEFAULT_SPEECH_TOKEN: str = "<speech>"
class SPEECH_LLM(nn.Module):
    """Speech-to-text model: speech encoder + encoder projector + language model.

    The encoder extracts speech features from the input speech signal, the
    encoder projector maps those features to the language model's hidden
    dimension, and the language model generates text from them.

    This class only holds the three sub-modules. Because they are assigned as
    attributes of an ``nn.Module``, any non-``None`` component is registered as
    a child module. Its parameters therefore appear in ``state_dict()`` under
    the ``encoder.``, ``llm.`` or ``encoder_projector.`` prefix. A component
    left as ``None`` is stored as a plain attribute and contributes no keys.

    Args:
        encoder: The speech encoder module, or ``None``.
        llm: The language model module, or ``None``.
        encoder_projector: The encoder projector module, or ``None``.
    """

    def __init__(
        self,
        encoder: Optional[nn.Module] = None,
        llm: Optional[nn.Module] = None,
        encoder_projector: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector
def _lora_keys(keys):
    """Return, in order, the state-dict keys that name LoRA adapter tensors."""
    # PEFT names adapter tensors e.g. "...q_proj.lora_A.default.weight".
    return [k for k in keys if "lora_" in k]


def main(argv=None):
    """Merge trained LoRA weights into the base LLM and save the merged model.

    Builds the base causal LM in float16 and wraps it with a LoRA config that
    must match the one used in training. It then places the LLM inside
    ``SPEECH_LLM`` so the checkpoint's ``llm.``-prefixed keys line up. Next it
    loads the checkpoint non-strictly, merges the adapters into the base
    weights, and saves the merged model plus the base tokenizer to the target
    directory.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``). With no
            arguments the original hard-coded paths are used.

    Raises:
        RuntimeError: If any LoRA parameter of the model was not found in the
            checkpoint. Without this check, the zero-initialised adapters would
            be merged and the base model saved silently.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Merge LoRA weights from a SPEECH_LLM checkpoint into the base LLM."
    )
    parser.add_argument(
        "--adapter-dir",
        default="/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt",
        help="Path to the trained checkpoint (.pt file) containing the LoRA weights.",
    )
    parser.add_argument(
        "--llm-dir",
        default="/home/scratch.yuekaiz_wwfo_1/Qwen2-1.5B-Instruct",
        help="Directory of the base LLM (and its tokenizer).",
    )
    parser.add_argument(
        "--target-dir",
        default="/home/scratch.yuekaiz_wwfo_1/Qwen2_1.5B_merged",
        help="Output directory for the merged model.",
    )
    args = parser.parse_args(argv)

    llm = AutoModelForCausalLM.from_pretrained(
        args.llm_dir,
        torch_dtype=torch.float16,
    )
    # These values must match the LoRA config used in training. Otherwise the
    # checkpoint's adapter tensors will not line up with the injected modules.
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "gate_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )
    llm = get_peft_model(llm, lora_config)
    # Wrapping reproduces the "llm." key prefix used by the training checkpoint.
    model = SPEECH_LLM(
        llm=llm,
    )

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(args.adapter_dir, map_location="cpu")
    # strict=False: encoder/projector weights in the checkpoint have no
    # counterpart here (both are None) and show up as unexpected keys.
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
    print("missing keys:", missing_keys)
    print("unexpected keys:", unexpected_keys)

    missing_lora = _lora_keys(missing_keys)
    if missing_lora:
        raise RuntimeError(
            f"{len(missing_lora)} LoRA weights were not found in "
            f"{args.adapter_dir}; check the key prefix and the LoRA config. "
            f"First missing key: {missing_lora[0]}"
        )

    llm_merged = model.llm.merge_and_unload()
    llm_merged.save_pretrained(args.target_dir)
    # Save the tokenizer too, so target_dir can be loaded on its own.
    AutoTokenizer.from_pretrained(args.llm_dir).save_pretrained(args.target_dir)


if __name__ == "__main__":
    main()