from llama_cpp import Llama
import gc
import threading
import logging
import sys

log = logging.getLogger('llm_api.backend')

class LlmBackend:
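    """Singleton wrapper around a llama_cpp.Llama model.

    Exposes two generation paths: the library's built-in chat format via
    create_chat_completion(), and a token-level Saiga prompt built manually in
    create_chat_generator_for_saiga() / generate_tokens(). A single re-entrant
    lock serialises model (re)loading and generation across threads.
    """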
    
    # Default Saiga system prompt. In English: "You are a Russian-language automatic
    # assistant. You answer user requests as accurately and informatively as possible,
    # using Russian."
    SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно и информативно отвечаешь на запросы пользователя, используя русский язык."
    SYSTEM_TOKEN = 1788
    USER_TOKEN = 1404
    BOT_TOKEN = 9225
    LINEBREAK_TOKEN = 13
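    # These four token IDs appear to be hard-coded for the Saiga tokenizer;
    # a different base model would need different values.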

    ROLE_TOKENS = {
        "user": USER_TOKEN,
        "bot": BOT_TOKEN,
        "system": SYSTEM_TOKEN
    }

    _instance = None
    _model = None
    _model_params = None
    # RLock: methods that hold the lock call ensure_model_is_loaded(), which can
    # re-enter it via load_model(); a plain Lock would deadlock there.
    _lock = threading.RLock()
    
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(LlmBackend, cls).__new__(cls)
        return cls._instance
    
    
    def is_model_loaded(self):
        return self._model is not None
    
    def load_model(self, model_path, context_size=2000, enable_gpu=True, gpu_layer_number=35, chat_format='llama-2'):
        log.info('load_model - started')
        self._model_params = {
            'model_path': model_path,
            'context_size': context_size,
            'enable_gpu': enable_gpu,
            'gpu_layer_number': gpu_layer_number,
            'chat_format': chat_format
        }

        if self._model is not None:
            self.unload_model()

        with self._lock:
            # Shared constructor arguments; GPU offloading is added only when requested.
            llama_kwargs = {
                'model_path': model_path,
                'chat_format': chat_format,
                'n_ctx': context_size,
                'n_parts': 1,
                #'n_batch': 100,
                'logits_all': True,
                #'n_threads': 12,
                'verbose': True
            }
            if enable_gpu:
                llama_kwargs['n_gpu_layers'] = gpu_layer_number

            self._model = Llama(**llama_kwargs)
            log.info('load_model - finished')
            return self._model
        
    def set_system_prompt(self, prompt):
        with self._lock:
            self.SYSTEM_PROMPT = prompt
        
    def unload_model(self):
        log.info('unload_model - started')
        with self._lock:
            if self._model is not None:
                # Drop the reference and trigger collection so the weights are freed before a reload.
                self._model = None
                gc.collect()
        log.info('unload_model - finished')
    
    def ensure_model_is_loaded(self):
        log.info('ensure_model_is_loaded - started')
        if not self.is_model_loaded():
            log.info('ensure_model_is_loaded - model reloading')
            if self._model_params is not None:
                self.load_model(**self._model_params)
            else:
                log.warning('ensure_model_is_loaded - no model config found, the model cannot be reloaded')
        log.info('ensure_model_is_loaded - finished')
                    
    def generate_tokens(self, generator):
        log.info('generate_tokens - started')
        with self._lock:
            self.ensure_model_is_loaded()
                
            try:
                for token in generator:            
                    if token == self._model.token_eos():
                        log.info('generate_tokens - finished')
                        yield b''  # End of chunk
                        break
                        
                    token_str = self._model.detokenize([token])#.decode("utf-8", errors="ignore")
                    yield token_str 
            except Exception as e:
                log.error('generate_tokens - error')
                log.error(e)
                yield b''  # End of chunk
                
    def create_chat_completion(self, messages, stream=True):
        log.info('create_chat_completion - called')
        with self._lock:
            log.info('create_chat_completion - started')
            # Reload the model from the saved parameters if it was unloaded in the meantime.
            self.ensure_model_is_loaded()
            try:
                return self._model.create_chat_completion(messages=messages, stream=stream)
            except Exception as e:
                log.error('create_chat_completion - error')
                log.error(e)
                return None
                
    
    def get_message_tokens(self, role, content):
        log.info('get_message_tokens - started')
        self.ensure_model_is_loaded()
        # Build a Saiga-style message: [BOS, role token, linebreak, content tokens..., EOS].
        message_tokens = self._model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, self.ROLE_TOKENS[role])
        message_tokens.insert(2, self.LINEBREAK_TOKEN)
        message_tokens.append(self._model.token_eos())
        log.info('get_message_tokens - finished')
        return message_tokens

    def get_system_tokens(self):
        return self.get_message_tokens(role="system", content=self.SYSTEM_PROMPT)
    
    def create_chat_generator_for_saiga(self, messages, parameters, use_system_prompt=True):
        log.info('create_chat_generator_for_saiga - started')
        with self._lock:
            self.ensure_model_is_loaded()
            tokens = self.get_system_tokens() if use_system_prompt else []
            # Each message is expected to look like {"from": "user"|"bot"|"system", "content": "..."}.
            for message in messages:
                message_tokens = self.get_message_tokens(role=message.get("from"), content=message.get("content", ""))
                tokens.extend(message_tokens)
            
            tokens.extend([self._model.token_bos(), self.BOT_TOKEN, self.LINEBREAK_TOKEN])
            generator = self._model.generate(
                tokens,
                top_k=parameters['top_k'],
                top_p=parameters['top_p'],
                temp=parameters['temperature'],
                repeat_penalty=parameters['repetition_penalty']
            )
            log.info('create_chat_generator_for_saiga - finished')
            return generator
        
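
# Minimal usage sketch (illustrative, not part of the original service): the model
# path and sampling parameters below are placeholders, and the output handling
# assumes generate_tokens yields raw UTF-8 bytes as it does above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    backend = LlmBackend()
    backend.load_model(
        model_path="/path/to/saiga-model.gguf",  # placeholder path
        context_size=2000,
        enable_gpu=False
    )

    # Path 1: llama-cpp-python's built-in chat format ('llama-2' by default).
    completion = backend.create_chat_completion(
        messages=[{"role": "user", "content": "Привет!"}],  # "Hello!"
        stream=False
    )
    if completion is not None:
        print(completion["choices"][0]["message"]["content"])

    # Path 2: the token-level Saiga prompt format.
    generator = backend.create_chat_generator_for_saiga(
        messages=[{"from": "user", "content": "Привет!"}],
        parameters={"top_k": 30, "top_p": 0.9, "temperature": 0.2, "repetition_penalty": 1.1}
    )
    for chunk in backend.generate_tokens(generator):
        if not chunk:
            break
        sys.stdout.buffer.write(chunk)
    sys.stdout.buffer.flush()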