File size: 2,365 Bytes
7576105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"


class SPEECH_LLM(nn.Module):
    """
    Container that groups a speech encoder, an encoder projector and a
    language model under fixed attribute names.

    Only construction is defined here. The three sub-modules are stored as
    ``encoder``, ``llm`` and ``encoder_projector``, so any of them that is an
    ``nn.Module`` is registered as a child. Its parameters then appear in
    ``state_dict()`` under that attribute name as a key prefix
    (e.g. ``llm.``). Arguments left as ``None`` are stored as plain ``None``
    attributes.

    Intended roles, per the original design: the encoder extracts speech
    features, the projector maps encoder outputs to the language model's
    dimension, and the language model generates text from those features.

    Args:
        encoder (:obj:`nn.Module`, optional): The speech encoder.
        llm (:obj:`nn.Module`, optional): The language model.
        encoder_projector (:obj:`nn.Module`, optional): The encoder projector.
    """

    def __init__(
        self,
        encoder: nn.Module = None,
        llm: nn.Module = None,
        encoder_projector: nn.Module = None,
    ):
        super().__init__()
        # Assign in this fixed order. Child registration order determines
        # the ordering of named_children() / state_dict() keys.
        components = (
            ("encoder", encoder),
            ("llm", llm),
            ("encoder_projector", encoder_projector),
        )
        for attr_name, sub_module in components:
            setattr(self, attr_name, sub_module)



def _parse_args():
    """Parse command-line options; the defaults reproduce the original hard-coded settings."""
    parser = argparse.ArgumentParser(
        description=(
            "Load a trained LoRA checkpoint into a causal LM, merge the LoRA "
            "weights into the base weights and save the merged model."
        )
    )
    parser.add_argument(
        "--adapter-path",
        "--adapter-dir",
        dest="adapter_path",
        default="/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt",
        help="Checkpoint file (.pt) holding the trained LoRA weights.",
    )
    parser.add_argument(
        "--llm-dir",
        default="/home/scratch.yuekaiz_wwfo_1/Qwen2-1.5B-Instruct",
        help="Directory of the base causal LM (passed to from_pretrained).",
    )
    parser.add_argument(
        "--target-dir",
        default="/home/scratch.yuekaiz_wwfo_1/Qwen2_1.5B_merged",
        help="Output directory for the merged model.",
    )
    parser.add_argument(
        "--lora-r", type=int, default=64, help="LoRA rank used during training."
    )
    parser.add_argument(
        "--lora-alpha", type=int, default=16, help="LoRA alpha used during training."
    )
    return parser.parse_args()


def main():
    """Merge the trained LoRA adapter into the base LLM and save the result.

    Raises:
        RuntimeError: if any LoRA parameter of the wrapped model was not found
            in the checkpoint. In that case the merge would silently produce
            the untrained base model.
    """
    args = _parse_args()

    llm = AutoModelForCausalLM.from_pretrained(
        args.llm_dir,
        torch_dtype=torch.float16,
    )
    # r, lora_alpha and target_modules must match the training configuration.
    # A different target_modules list shows up as missing LoRA keys below.
    # A different lora_alpha or r cannot be detected from the state dict and
    # would silently rescale the merged delta (scaling = lora_alpha / r).
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "gate_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )
    llm = get_peft_model(llm, lora_config)
    # Wrapping in SPEECH_LLM gives the LLM parameters an "llm." key prefix,
    # presumably matching the checkpoint saved from the full speech model.
    model = SPEECH_LLM(
        llm=llm,
    )

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    checkpoint = torch.load(args.adapter_path, map_location="cpu")
    # strict=False: presumably the checkpoint holds only trained parameters
    # (no frozen base weights) and may include weights for modules that are
    # not built here.
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)

    # With PEFT's default initialisation the LoRA B matrices are zero. A LoRA
    # weight that failed to load therefore merges as a no-op instead of
    # raising, so fail loudly here.
    missing_lora = [key for key in missing_keys if "lora_" in key]
    if missing_lora:
        raise RuntimeError(
            f"{len(missing_lora)} LoRA parameters were not found in "
            f"{args.adapter_path} (e.g. {missing_lora[:3]}); check the key "
            "prefix and target_modules against the training configuration."
        )

    print(f"missing keys (non-LoRA, not loaded from checkpoint): {len(missing_keys)}")
    print(f"unexpected keys ({len(unexpected_keys)}): {unexpected_keys}")

    llm_merged = model.llm.merge_and_unload()

    llm_merged.save_pretrained(args.target_dir)


if __name__ == "__main__":
    main()