from __future__ import print_function

import copy
import os
import pathlib
import sys
import typing
from queue import Queue
from threading import Thread

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn

sys.path.append('/usr/lib/lyralib')
import lyraOp

str_type_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


class BaichuanModel(nn.Module):
    def __init__(self,
                 head_num,
                 size_per_head,
                 inter_size,
                 vocab_size,
                 rotary_embedding_dim,
                 start_id, end_id, layer_num,
                 max_seq_len: int,
                 layernorm_eps,
                 tensor_para_size: int,
                 pipeline_para_size: int,
                 use_gptj_residual,
                 lib_path: typing.Union[str, pathlib.Path],
                 model_path,
                 memopt_mode: int = 0,
                 quant_data_type: str = "int8",
                 inference_data_type: str = "fp16",
                 weights_data_type: typing.Union[str, np.dtype] = np.float32):
        super().__init__()
        self.head_num = head_num
        self.size_per_head = size_per_head
        self.inter_size = inter_size
        self.vocab_size = vocab_size
        self.rotary_embedding_dim = rotary_embedding_dim
        self.start_id = start_id
        self.end_id = end_id
        self.max_seq_len = max_seq_len
        self.layer_num = layer_num
        self.use_gptj_residual = use_gptj_residual
        self.layernorm_eps = layernorm_eps
        self.memopt_mode = memopt_mode
        self.quant_data_type = quant_data_type

        # Multi-GPU (tensor/pipeline parallel) params
        self.tensor_para_size = tensor_para_size
        self.pipeline_para_size = pipeline_para_size
        self.build_model = False
        self.weights_data_type = weights_data_type
        self.inference_data_type = inference_data_type

        assert torch.cuda.is_available(), "CUDA is required for this model."
        assert head_num % tensor_para_size == 0, "head_num must be a multiple of tensor_para_size."
        assert layer_num % pipeline_para_size == 0, "layer_num must be a multiple of pipeline_para_size."

        # Queue and worker threads for streaming generation
        self.que = Queue()
        self.threads = [None] * self.tensor_para_size
        # Load the C++ model into the PyTorch module.
        # torch.classes.load_library(os.path.abspath(lib_path))

        # Prepare for tensor/pipeline parallelism
        try:
            dist.init_process_group(backend='mpi')
        except Exception:
            print("[WARNING] The process group has already been initialized.")
        self.rank = dist.get_rank()
        self.device_count = torch.cuda.device_count()
        self.device = self.rank % self.device_count
        torch.cuda.set_device(self.device)

        world_size = dist.get_world_size()
        assert world_size == tensor_para_size * pipeline_para_size, \
            "tensor_para_size * pipeline_para_size must be equal to world_size."

        self.tensor_para_rank = self.rank % self.tensor_para_size
        self.pipeline_para_rank = self.rank // self.tensor_para_size
        self.model = lyraOp.LyraBaichuan(
            self.head_num, self.size_per_head, self.inter_size,
            self.layer_num,
            self.vocab_size,
            self.rotary_embedding_dim,
            self.layernorm_eps,
            self.start_id, self.end_id,
            self.tensor_para_size, self.pipeline_para_size,
            self.max_seq_len,
            self.use_gptj_residual,
            self.memopt_mode,
            self.quant_data_type,
            model_path,
            self.weights_data_type,
            self.inference_data_type)
        self.build_model = True
        torch.cuda.empty_cache()

    def forward(self,
                start_ids: torch.Tensor,
                start_lengths: torch.Tensor,
                output_len,
                beam_width=1,
                top_k: typing.Optional[torch.Tensor] = None,
                top_p: typing.Optional[torch.Tensor] = None,
                beam_search_diversity_rate: typing.Optional[torch.Tensor] = None,
                temperature: typing.Optional[torch.Tensor] = None,
                len_penalty: typing.Optional[torch.Tensor] = None,
                repetition_penalty: typing.Optional[torch.Tensor] = None,
                random_seed: typing.Optional[torch.Tensor] = None,
                return_output_length=False,
                return_cum_log_probs=0):
        input_len = start_ids.size(1)
        assert input_len > 0, "input len must be larger than zero. For an unconditional case, use start_id as the first token."

        # Move inputs to the model's device
        input_ids = start_ids.cuda(self.device)
        input_lengths = start_lengths.cuda(self.device)

        # outputs: output_ids, output_lengths, output_cum_log_probs (optional)
        outputs = self.model.forward(input_ids,
                                     input_lengths,
                                     output_len,
                                     beam_width,                  # optional, can be None
                                     top_k,                       # optional, can be None
                                     top_p,                       # optional, can be None
                                     beam_search_diversity_rate,  # optional, can be None
                                     temperature,                 # optional, can be None
                                     len_penalty,                 # optional, can be None
                                     repetition_penalty,          # optional, can be None
                                     random_seed,                 # optional, can be None
                                     return_cum_log_probs)        # optional, can be None

        if return_cum_log_probs == 0:
            output_ids, output_lengths = outputs
        else:
            output_ids, output_lengths, output_cum_log_probs = outputs

        if return_output_length:
            if return_cum_log_probs > 0:
                return output_ids, output_lengths, output_cum_log_probs
            else:
                return output_ids, output_lengths
        else:
            return output_ids

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism, the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def _forward_callback(self, output_ids, seq_lengths, ctx):
        # Callback registered via self.model.registerCallback(); pushes partial
        # outputs into the queue consumed by stream_forward().
        self.que.put((False, (list(output_ids), list(seq_lengths))))

    def _tensormap_dict_to_py_dict(self, tensormap_dict: lyraOp.TensorMap):
        """Map a lyraOp TensorMap to a plain Python dict."""
        ret = dict()
        for k, v in tensormap_dict.items():
            ret[k] = v
        return ret

    def stream_forward(self,
                       start_ids: torch.Tensor,
                       start_lengths: torch.Tensor,
                       output_len,
                       beam_width=1,
                       top_k: typing.Optional[torch.Tensor] = None,
                       top_p: typing.Optional[torch.Tensor] = None,
                       beam_search_diversity_rate: typing.Optional[torch.Tensor] = None,
                       temperature: typing.Optional[torch.Tensor] = None,
                       len_penalty: typing.Optional[torch.Tensor] = None,
                       repetition_penalty: typing.Optional[torch.Tensor] = None,
                       random_seed: typing.Optional[torch.Tensor] = None,
                       return_output_length=False,
                       return_cum_log_probs=0):
        # Register the streaming callback with the model
        self.model.registerCallback(self._forward_callback)

        batch_size = start_ids.size(0)
        input_len = start_ids.size(1)
        assert input_len > 0, "input len must be larger than zero. For an unconditional case, use start_id as the first token."

        # Move inputs to the model's device
        input_ids = start_ids.cuda(self.device)
        input_lengths = start_lengths.cuda(self.device)

        # outputs: output_ids, output_lengths, output_cum_log_probs (optional)
        def _func(enque_output):
            outputs = self.model.forward(input_ids,
                                         input_lengths,
                                         output_len,
                                         beam_width,                  # optional, can be None
                                         top_k,                       # optional, can be None
                                         top_p,                       # optional, can be None
                                         beam_search_diversity_rate,  # optional, can be None
                                         temperature,                 # optional, can be None
                                         len_penalty,                 # optional, can be None
                                         repetition_penalty,          # optional, can be None
                                         random_seed,                 # optional, can be None
                                         return_cum_log_probs)        # optional, can be None
            if enque_output:
                self.que.put((True, (outputs[0].tolist(), outputs[1].tolist())))

        # Run model inference in a background thread; intermediate results
        # arrive via the callback, the final result is enqueued by _func.
        t = Thread(target=_func, args=(True,), daemon=True)
        t.start()
        self.threads[0] = t
        # Generate streaming output: consume the queue and yield partial
        # results until the background thread enqueues the final output.
        while True:
            finish, outputs = self.que.get()
            output_ids, sequence_length = outputs
            output_ids = torch.tensor(output_ids).view(batch_size, beam_width, -1)
            sequence_length = torch.tensor(sequence_length).view(batch_size, beam_width)

            # Cumulative log-probs are not returned in streaming mode.
            if return_output_length:
                yield finish, output_ids, sequence_length, None
            else:
                yield finish, output_ids, None, None

            if finish:
                for t in self.threads:
                    t.join()
                # Drain any leftover callback results
                while self.que.qsize() > 0:
                    self.que.get()
                break

        self.model.unRegisterCallback()
        return finish, output_ids, None, None
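
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library). The hyperparameters, paths
# and token ids below are placeholders chosen for illustration only; take the
# real values from the checkpoint's config and tokenizer. Assumes lyraOp is
# installed under /usr/lib/lyralib and the script is launched under MPI,
# e.g. `mpirun -n 1 python <this_file>.py`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Placeholder Baichuan-7B-like hyperparameters (assumption, not verified).
    model = BaichuanModel(head_num=32,
                          size_per_head=128,
                          inter_size=11008,
                          vocab_size=64000,
                          rotary_embedding_dim=128,
                          start_id=1, end_id=2, layer_num=32,
                          max_seq_len=2048,
                          layernorm_eps=1e-6,
                          tensor_para_size=1,
                          pipeline_para_size=1,
                          use_gptj_residual=False,
                          lib_path='/usr/lib/lyralib',            # placeholder path
                          model_path='/path/to/converted/model',  # placeholder path
                          inference_data_type="fp16")

    # Placeholder prompt token ids; in practice these come from the tokenizer.
    start_ids = torch.IntTensor([[195, 7176, 196]])
    start_lengths = torch.IntTensor([start_ids.size(1)])

    # Batch generation: returns output token ids, one row per batch item
    # and beam (input tokens followed by generated tokens).
    output_ids = model.forward(start_ids, start_lengths, output_len=64)
    print(output_ids.shape)

    # Streaming generation: yields (finish, output_ids, seq_lengths, None)
    # after each callback until finish is True.
    for finish, ids, seq_len, _ in model.stream_forward(start_ids, start_lengths,
                                                        output_len=64,
                                                        return_output_length=True):
        print("finish:", finish, "generated so far:", ids.shape)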