from __future__ import print_function

import copy
import os
import pathlib
import typing

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import time
from queue import Queue
from threading import Thread

import sys
sys.path.append('/usr/lib/lyralib')
import lyraOp

str_type_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


class LlamaModel(nn.Module):
    def __init__(self,
                 head_num,
                 size_per_head,
                 inter_size,
                 vocab_size,
                 rotary_embedding_dim,
                 start_id,
                 end_id,
                 layer_num,
                 max_seq_len: int,
                 layernorm_eps,
                 tensor_para_size: int,
                 pipeline_para_size: int,
                 use_gptj_residual,
                 lib_path: typing.Union[str, pathlib.Path],
                 model_path,
                 kvqparams_fpath: str = "",
                 memopt_mode: int = 0,
                 quant_data_type: str = "int8",
                 inference_data_type: str = "fp16",
                 weights_data_type: typing.Union[str, np.dtype] = np.float32,
                 kv_head_num=0,
                 rope_theta=10000.0):
        super().__init__()
        self.head_num = head_num
        self.kv_head_num = kv_head_num
        self.size_per_head = size_per_head
        self.inter_size = inter_size
        self.vocab_size = vocab_size
        self.rotary_embedding_dim = rotary_embedding_dim
        self.start_id = start_id
        self.end_id = end_id
        self.max_seq_len = max_seq_len
        self.layer_num = layer_num
        self.use_gptj_residual = use_gptj_residual
        self.layernorm_eps = layernorm_eps
        self.memopt_mode = memopt_mode
        self.quant_data_type = quant_data_type
        self.rope_theta = rope_theta

        # multi-gpu params
        self.tensor_para_size = tensor_para_size
        self.pipeline_para_size = pipeline_para_size
        self.build_model = False
        self.weights_data_type = weights_data_type
        self.inference_data_type = inference_data_type

        # queue for streaming
        self.que = Queue()
        self.threads = [None] * self.tensor_para_size

        assert torch.cuda.is_available(), "CUDA is required for this model."
        assert head_num % tensor_para_size == 0, "head_num must be a multiple of tensor_para_size."
        assert layer_num % pipeline_para_size == 0, "layer_num must be a multiple of pipeline_para_size."

        # Load the C++ model into Pytorch model.
        # torch.classes.load_library(os.path.abspath(lib_path))

        # Prepare for tensor/pipeline parallel
        try:
            dist.init_process_group(backend='mpi')
        except Exception:
            print("[INFO] WARNING: The process group has already been initialized.")

        self.rank = dist.get_rank()
        self.device_count = torch.cuda.device_count()
        self.device = self.rank % self.device_count
        torch.cuda.set_device(self.device)

        world_size = dist.get_world_size()
        # print(tensor_para_size * pipeline_para_size)
        assert world_size == tensor_para_size * pipeline_para_size, "tensor_para_size * pipeline_para_size must be equal to world_size."
        self.tensor_para_rank = self.rank % self.tensor_para_size
        self.pipeline_para_rank = self.rank // self.tensor_para_size

        if self.kv_head_num == 0:
            self.kv_head_num = self.head_num

        self.model = lyraOp.LyraLlama(
            self.head_num,
            self.size_per_head,
            self.inter_size,
            self.layer_num,
            self.vocab_size,
            self.rotary_embedding_dim,
            self.layernorm_eps,
            self.start_id,
            self.end_id,
            self.tensor_para_size,
            self.pipeline_para_size,
            self.max_seq_len,
            self.use_gptj_residual,
            self.memopt_mode,
            self.quant_data_type,
            model_path,
            kvqparams_fpath,
            self.weights_data_type,
            self.inference_data_type,
            self.kv_head_num,
            self.rope_theta)
        self.build_model = True
        torch.cuda.empty_cache()

    def forward(self,
                start_ids: torch.Tensor,
                start_lengths: torch.Tensor,
                output_len,
                beam_width=1,
                top_k: torch.Tensor = None,
                top_p: torch.Tensor = None,
                beam_search_diversity_rate: torch.Tensor = None,
                temperature: torch.Tensor = None,
                len_penalty: torch.Tensor = None,
                repetition_penalty: torch.Tensor = None,
                random_seed: torch.Tensor = None,
                return_output_length=False,
                return_cum_log_probs=0):
        input_len = start_ids.size(1)
        assert input_len > 0, "input len must be larger than zero. For an unconditional case, use start_id as the first token."

        # Inputs to device
        input_ids = start_ids.cuda(self.device)
        input_lengths = start_lengths.cuda(self.device)

        # outputs: output_ids, output_lengths, output_cum_log_probs (optional)
        outputs = self.model.forward(input_ids,
                                     input_lengths,
                                     output_len,
                                     beam_width,                  # optional, can be None
                                     top_k,                       # optional, can be None
                                     top_p,                       # optional, can be None
                                     beam_search_diversity_rate,  # optional, can be None
                                     temperature,                 # optional, can be None
                                     len_penalty,                 # optional, can be None
                                     repetition_penalty,          # optional, can be None
                                     random_seed,                 # optional, can be None
                                     return_cum_log_probs)        # optional, can be None

        if return_cum_log_probs == 0:
            output_ids, output_lengths = outputs
        else:
            output_ids, output_lengths, output_cum_log_probs = outputs

        if return_output_length:
            if return_cum_log_probs > 0:
                return output_ids, output_lengths, output_cum_log_probs
            else:
                return output_ids, output_lengths
        else:
            return output_ids

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous stage
        comes from communication, not from the input, so the model's
        forward_step_func won't have it. This function is thus used by
        internal code to bypass the input provided by the forward_step_func.
        """
        self.input_tensor = input_tensor

    def _forward_callback(self, output_ids, seq_lengths, ctx):
        self.que.put((False, (list(output_ids), list(seq_lengths))))

    def _tensormap_dict_to_py_dict(self, tensormap_dict: lyraOp.TensorMap):
        """map torch tensormap to py dict."""
        ret = dict()
        for k, v in tensormap_dict.items():
            ret[k] = v
        return ret

    def stream_forward(self,
                       start_ids: torch.Tensor,
                       start_lengths: torch.Tensor,
                       output_len,
                       beam_width=1,
                       top_k: torch.Tensor = None,
                       top_p: torch.Tensor = None,
                       beam_search_diversity_rate: torch.Tensor = None,
                       temperature: torch.Tensor = None,
                       len_penalty: torch.Tensor = None,
                       repetition_penalty: torch.Tensor = None,
                       random_seed: torch.Tensor = None,
                       return_output_length=False,
                       return_cum_log_probs=0):
        # Register callback func to model
        self.model.registerCallback(self._forward_callback)

        batch_size = start_ids.size(0)
        input_len = start_ids.size(1)
        assert input_len > 0, "input len must be larger than zero. For an unconditional case, use start_id as the first token."
        # Inputs to device
        input_ids = start_ids.cuda(self.device)
        input_lengths = start_lengths.cuda(self.device)

        # outputs: output_ids, output_lengths, output_cum_log_probs (optional)
        # Init thread of model inference
        def _func(enque_output):
            outputs = self.model.forward(input_ids,
                                         input_lengths,
                                         output_len,
                                         beam_width,                  # optional, can be None
                                         top_k,                       # optional, can be None
                                         top_p,                       # optional, can be None
                                         beam_search_diversity_rate,  # optional, can be None
                                         temperature,                 # optional, can be None
                                         len_penalty,                 # optional, can be None
                                         repetition_penalty,          # optional, can be None
                                         random_seed,                 # optional, can be None
                                         return_cum_log_probs)        # optional, can be None
            if enque_output:
                # Final result: mark the queue entry as finished.
                self.que.put((True, (outputs[0].tolist(), outputs[1].tolist())))

        # Start thread of model inference
        t = Thread(target=_func, args=(True,), daemon=True)
        t.start()
        self.threads[0] = t

        # Generate streaming output
        while True:
            # Drain stale intermediate results; only the newest entry matters.
            while self.que.qsize() > 1:
                self.que.get()

            finish, outputs = self.que.get()
            output_ids, sequence_length = outputs
            output_ids_tensor = torch.tensor(output_ids).view(batch_size, beam_width, -1)
            sequence_length_tensor = torch.tensor(sequence_length).view(batch_size, beam_width)

            # Cumulative log probs are not surfaced in streaming mode, so the
            # last element of the yielded tuple is always None.
            if return_output_length:
                yield finish, output_ids_tensor, sequence_length_tensor, None
            else:
                yield finish, output_ids_tensor, None, None

            if finish:
                for worker in self.threads:
                    # Only the first slot is populated with a worker thread;
                    # skip unused slots so join() is not called on None.
                    if worker is not None:
                        worker.join()
                while self.que.qsize() > 0:
                    self.que.get()
                break

        self.model.unRegisterCallback()
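

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal example of how this wrapper
# might be driven for a single-GPU, LLaMA-7B-sized setup. Every concrete
# value below (hyperparameters, lib_path, model_path, the hard-coded token
# ids) is an assumption made for the example, not something shipped with
# lyraOp; replace them with the values of the actual converted checkpoint
# and a real tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = LlamaModel(
        head_num=32, size_per_head=128, inter_size=11008,
        vocab_size=32000, rotary_embedding_dim=128,
        start_id=1, end_id=2, layer_num=32,
        max_seq_len=2048, layernorm_eps=1e-6,
        tensor_para_size=1, pipeline_para_size=1,
        use_gptj_residual=False,
        lib_path="/usr/lib/lyralib",            # assumed install location
        model_path="/path/to/converted/model",  # assumed checkpoint path
        inference_data_type="fp16")

    # Token ids would normally come from a tokenizer; hard-coded here.
    start_ids = torch.tensor([[1, 29871, 450]], dtype=torch.int32)
    start_lengths = torch.tensor([start_ids.size(1)], dtype=torch.int32)

    # Blocking generation: returns the full output ids.
    output_ids = model.forward(start_ids, start_lengths, output_len=64, beam_width=1)
    print(output_ids)

    # Streaming generation: yields partial outputs until `finish` is True.
    for finish, ids, seq_len, _ in model.stream_forward(
            start_ids, start_lengths, output_len=64, return_output_length=True):
        print(finish, ids.shape)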