pseudotensor committed on
Commit a6d8676
1 Parent(s): d9fa842

Upload h2oai_pipeline.py

Files changed (1)
  1. h2oai_pipeline.py +344 -162
h2oai_pipeline.py CHANGED
@@ -71,8 +71,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
             # unknown
             model_max_length = None
 
+        num_prompt_tokens = None
         if model_max_length is not None:
-            num_prompt_tokens = None
             # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
             # For https://github.com/h2oai/h2ogpt/issues/192
             for trial in range(0, 3):
@@ -108,10 +108,10 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
                     print("Reduced max_new_tokens from %s -> %s" % (
                         generate_kwargs['max_new_tokens'], max_new_tokens))
                     generate_kwargs['max_new_tokens'] = max_new_tokens
-        return prompt_text
+        return prompt_text, num_prompt_tokens
 
     def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
-        prompt_text = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
+        prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
 
         data_point = dict(context='', instruction=prompt_text, input='')
         if self.prompter is not None:
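The two hunks above change `limit_prompt` to also return the number of prompt tokens it counted. A minimal sketch of the new calling convention; the model name is an assumption for illustration, not part of this commit:

# Hypothetical usage of the new limit_prompt contract (model name is an assumption):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('h2oai/h2ogpt-oig-oasst1-512-6.9b')
prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(
    "Why is the sky blue?", tokenizer)
# num_prompt_tokens stays None when the tokenizer's model_max_length is unknown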
@@ -132,7 +132,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
                 outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
                                                      sanitize_bot_response=self.sanitize_bot_response)
             elif self.bot and self.human:
-                outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
+                outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
             else:
                 outputs = rec['generated_text']
             rec['generated_text'] = outputs
@@ -195,83 +195,6 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
         else:
             raise ValueError("TF not available.")
         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
-import torch
-from transformers import StoppingCriteria, StoppingCriteriaList
-
-
-
-class StoppingCriteriaSub(StoppingCriteria):
-
-    def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
-        super().__init__()
-        assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
-        self.encounters = encounters
-        self.stops = [stop.to(device) for stop in stops]
-        self.num_stops = [0] * len(stops)
-        self.model_max_length = model_max_length
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        for stopi, stop in enumerate(self.stops):
-            if torch.all((stop == input_ids[0][-len(stop):])).item():
-                self.num_stops[stopi] += 1
-                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
-                    # print("Stopped", flush=True)
-                    return True
-        if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
-            # critical limit
-            return True
-        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
-        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
-        return False
-
-
-def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
-    # FIXME: prompt_dict unused currently
-    if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
-        if prompt_type == PromptType.human_bot.name:
-            # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
-            # stopping only starts once output is beyond prompt
-            # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
-            stop_words = [human, bot, '\n' + human, '\n' + bot]
-            encounters = [1, 2]
-        elif prompt_type == PromptType.instruct_vicuna.name:
-            # even below is not enough, generic strings and many ways to encode
-            stop_words = [
-                '### Human:',
-                """
-### Human:""",
-                """
-### Human:
-""",
-                '### Assistant:',
-                """
-### Assistant:""",
-                """
-### Assistant:
-""",
-            ]
-            encounters = [1, 2]
-        else:
-            # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
-            stop_words = ['### End']
-            encounters = [1]
-        stop_words_ids = [
-            tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
-        # handle single token case
-        stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
-        stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
-        # avoid padding in front of tokens
-        if tokenizer._pad_token:  # use hidden variable to avoid annoying property logger bug
-            stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
-        # handle fake \n added
-        stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
-        # build stopper
-        stopping_criteria = StoppingCriteriaList(
-            [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
-                                 model_max_length=model_max_length)])
-    else:
-        stopping_criteria = StoppingCriteriaList()
-    return stopping_criteria
 from enum import Enum
 
 
@@ -296,6 +219,12 @@ class PromptType(Enum):
     wizard2 = 16
     wizard3 = 17
     instruct_simple = 18
+    wizard_vicuna = 19
+    openai = 20
+    openai_chat = 21
+    gptj = 22
+    prompt_answer_openllama = 23
+    vicuna11 = 24
 
 
 class DocumentChoices(Enum):
@@ -318,9 +247,41 @@ class LangChainMode(Enum):
     MY_DATA = "MyData"
     GITHUB_H2OGPT = "github h2oGPT"
     H2O_DAI_DOCS = "DriverlessAI docs"
+
+
+no_server_str = no_lora_str = no_model_str = '[None/Remove]'
+
+
+# from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
+model_token_mapping = {
+    "gpt-4": 8192,
+    "gpt-4-0314": 8192,
+    "gpt-4-32k": 32768,
+    "gpt-4-32k-0314": 32768,
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-16k": 16*1024,
+    "gpt-3.5-turbo-0301": 4096,
+    "text-ada-001": 2049,
+    "ada": 2049,
+    "text-babbage-001": 2040,
+    "babbage": 2049,
+    "text-curie-001": 2049,
+    "curie": 2049,
+    "davinci": 2049,
+    "text-davinci-003": 4097,
+    "text-davinci-002": 4097,
+    "code-davinci-002": 8001,
+    "code-davinci-001": 8001,
+    "code-cushman-002": 2048,
+    "code-cushman-001": 2048,
+}
+
+
+source_prefix = "Sources [Score | Link]:"
+source_postfix = "End Sources<p>"
+import os
 import ast
 import time
-from enums import PromptType  # also supports imports from this file from other files
 
 non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
 
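The new `model_token_mapping` gives per-model context sizes for the OpenAI endpoints. A minimal sketch of how such a table can be used to clamp generation length; the helper below is hypothetical, not part of this commit:

def clamp_max_new_tokens(model_name, num_prompt_tokens, requested=256):
    # Hypothetical helper: keep prompt + generation within the model's context window.
    context_len = model_token_mapping.get(model_name)
    if context_len is None:
        return requested  # unknown model: leave the request unchanged
    return max(0, min(requested, context_len - num_prompt_tokens))

assert clamp_max_new_tokens("gpt-3.5-turbo", 4000) == 96  # 4096 - 4000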
@@ -344,23 +305,29 @@ prompt_type_to_model_name = {
         'mosaicml/mpt-7b-storywriter',
         'mosaicml/mpt-7b-instruct',  # internal code handles instruct
         'mosaicml/mpt-7b-chat',  # NC, internal code handles instruct
-        'gptj',  # internally handles prompting
-        'llama',  # plain, or need to choose prompt_type for given TheBloke model
-        'gpt4all_llama',  # internally handles prompting
+        'mosaicml/mpt-30b-instruct',  # internal code handles instruct
     ],
+    'gptj': ['gptj', 'gpt4all_llama'],
     'prompt_answer': [
         'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
         'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
         'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
-        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
-        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
-        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
-        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
         'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
         'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
         'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
         'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
         'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
+        'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
+        'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
+    ],
+    'prompt_answer_openllama': [
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
     ],
     'instruct': [],
     'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -373,6 +340,7 @@ prompt_type_to_model_name = {
         'h2oai/h2ogpt-oig-oasst1-256-6.9b',  # legacy
         'h2oai/h2ogpt-oig-oasst1-512-6.9b',  # legacy
         'h2oai/h2ogpt-research-oasst1-512-30b',
+        'h2oai/h2ogpt-research-oasst1-llama-65b',
         'h2oai/h2ogpt-oasst1-falcon-40b',
         'h2oai/h2ogpt-oig-oasst1-falcon-40b',
     ],
@@ -385,7 +353,16 @@ prompt_type_to_model_name = {
     "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
     "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
     "instruct_simple": ['JosephusCheung/Guanaco'],
+    "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
+    "wizard2": ['llama', 'mosaicml/mpt-30b-instruct'],
+    "vicuna11": ['lmsys/vicuna-33b-v1.3'],
+    # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
 }
+if os.getenv('OPENAI_API_KEY'):
+    prompt_type_to_model_name.update({
+        "openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
+        "openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
+    })
 
 inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
 inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
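The two inverted maps above let callers resolve a default prompt type from a model name, case-sensitively or not. For example (hypothetical lookups, not part of this commit):

# Hypothetical reverse lookups against the tables defined above:
assert inv_prompt_type_to_model_name['databricks/dolly-v2-12b'] == 'instruct_with_end'
assert inv_prompt_type_to_model_lower['josephuscheung/guanaco'] == 'instruct_simple'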
@@ -399,18 +376,29 @@ for p in PromptType:
     prompt_types.extend([p.name, p.value, str(p.value)])
 
 
-def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
+def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
     prompt_dict_error = ''
+    generates_leading_space = False
+
     if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
         try:
             prompt_dict = ast.literal_eval(prompt_dict)
         except BaseException as e:
            prompt_dict_error = str(e)
-    if prompt_dict_error:
-        return dict(), prompt_dict_error
-
-    if prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
-                       PromptType.custom.name]:
+    if prompt_dict_error:
+        promptA = None
+        promptB = None
+        PreInstruct = None
+        PreInput = ''
+        PreResponse = ''
+        terminate_response = None
+        chat_sep = ''
+        chat_turn_sep = ''
+        humanstr = ''
+        botstr = ''
+        generates_leading_space = False
+    elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
+                         PromptType.custom.name]:
         promptA = prompt_dict.get('promptA', '')
         promptB = prompt_dict.get('promptB', '')
         PreInstruct = prompt_dict.get('PreInstruct', '')
@@ -418,21 +406,23 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
         PreResponse = prompt_dict.get('PreResponse', '')
         terminate_response = prompt_dict.get('terminate_response', None)
         chat_sep = prompt_dict.get('chat_sep', '\n')
+        chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
         humanstr = prompt_dict.get('humanstr', '')
         botstr = prompt_dict.get('botstr', '')
     elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
                          PromptType.plain.name]:
-        promptA = promptB = PreInstruct = PreInput = PreResponse = ''
+        promptA = promptB = PreInstruct = PreInput = PreResponse = None
         terminate_response = []
-        chat_sep = ''
-        humanstr = ''
-        botstr = ''
+        chat_turn_sep = chat_sep = ''
+        # plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
+        humanstr = None
+        botstr = None
     elif prompt_type == 'simple_instruct':
         promptA = promptB = PreInstruct = PreInput = PreResponse = None
         terminate_response = []
-        chat_sep = '\n'
-        humanstr = ''
-        botstr = ''
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = None
+        botstr = None
     elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
                          PromptType.instruct.name] + [PromptType.instruct_with_end.value,
                                                       str(PromptType.instruct_with_end.value),
@@ -458,7 +448,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
             terminate_response = ['### End']
         else:
             terminate_response = None
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
@@ -480,7 +470,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
 ### Response:
 """
         terminate_response = None
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct  # first thing human says
         botstr = PreResponse  # first thing bot says
     elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
@@ -502,14 +492,14 @@ Current Time: {}
 
 """
         preprompt = PRE_PROMPT.format(cur_date, cur_time)
-        start = human
-        promptB = promptA = '%s%s ' % (preprompt, start)
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
 
-        PreInstruct = ""
+        PreInstruct = human + ' '
 
         PreInput = None
 
-        if reduced:
+        if making_context:
             # when making context, want it to appear as-if LLM generated, which starts with space after :
             PreResponse = bot + ' '
         else:
@@ -517,10 +507,11 @@ Current Time: {}
             # if add space here, non-unique tokenization will often make LLM produce wrong output
             PreResponse = bot
 
-        terminate_response = [start, PreResponse]
-        chat_sep = '\n'
+        terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
+        chat_turn_sep = chat_sep = '\n'
         humanstr = human  # tag before human talks
         botstr = bot  # tag before bot talks
+        generates_leading_space = True
     elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
                          PromptType.dai_faq.name]:
         promptA = ''
@@ -536,7 +527,7 @@ Current Time: {}
 ### Driverless AI documentation answer:
 """
         terminate_response = ['\n\n']
-        chat_sep = terminate_response
+        chat_turn_sep = chat_sep = terminate_response
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
@@ -545,7 +536,7 @@ Current Time: {}
         PreInstruct = '## Main Text\n\n'
         PreResponse = '\n\n## Summary\n\n'
         terminate_response = None
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
@@ -565,7 +556,7 @@ Current Time: {}
 """
         terminate_response = [
             '### Human:']  # but only allow terminate after prompt is found correctly, else can't terminate
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
@@ -573,33 +564,50 @@ Current Time: {}
         preprompt = ''
         prompt_tokens = "<|prompt|>"
         answer_tokens = "<|answer|>"
-        start = prompt_tokens
+        start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
-        PreInstruct = ""
+        PreInstruct = prompt_tokens
         PreInput = None
         PreResponse = answer_tokens
         eos = '<|endoftext|>'  # neox eos
-        terminate_response = [start, PreResponse, eos]
-        chat_sep = eos
         humanstr = prompt_tokens
         botstr = answer_tokens
+        terminate_response = [humanstr, PreResponse, eos]
+        chat_sep = ''
+        chat_turn_sep = eos
+    elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
+                         PromptType.prompt_answer_openllama.name]:
+        preprompt = ''
+        prompt_tokens = "<|prompt|>"
+        answer_tokens = "<|answer|>"
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = prompt_tokens
+        PreInput = None
+        PreResponse = answer_tokens
+        eos = '</s>'  # llama eos
+        humanstr = prompt_tokens
+        botstr = answer_tokens
+        terminate_response = [humanstr, PreResponse, eos]
+        chat_sep = ''
+        chat_turn_sep = eos
     elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
                          PromptType.open_assistant.name]:
         # From added_tokens.json
         preprompt = ''
         prompt_tokens = "<|prompter|>"
         answer_tokens = "<|assistant|>"
-        start = prompt_tokens
+        start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
-        PreInstruct = ""
+        PreInstruct = prompt_tokens
         PreInput = None
         PreResponse = answer_tokens
         pend = "<|prefix_end|>"
         eos = "</s>"
-        terminate_response = [start, PreResponse, pend, eos]
-        chat_sep = eos
         humanstr = prompt_tokens
         botstr = answer_tokens
+        terminate_response = [humanstr, PreResponse, pend, eos]
+        chat_turn_sep = chat_sep = eos
     elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
                          PromptType.wizard_lm.name]:
         # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
611
  PreResponse = "\n\n### Response\n"
612
  eos = "</s>"
613
  terminate_response = [PreResponse, eos]
614
- chat_sep = eos
615
  humanstr = promptA
616
  botstr = PreResponse
617
  elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
@@ -627,13 +635,12 @@ Current Time: {}
627
  ### Assistant:
628
  """
629
  terminate_response = [PreResponse]
630
- chat_sep = '\n'
631
  humanstr = PreInstruct
632
  botstr = PreResponse
633
  elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
634
  PromptType.instruct_vicuna2.name]:
635
- promptA = promptB = "" if not (
636
- chat and reduced) else ''
637
 
638
  PreInstruct = """
639
  HUMAN:
@@ -646,13 +653,12 @@ ASSISTANT:
646
  """
647
  terminate_response = [
648
  'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
649
- chat_sep = '\n'
650
  humanstr = PreInstruct
651
  botstr = PreResponse
652
  elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
653
  PromptType.instruct_vicuna3.name]:
654
- promptA = promptB = "" if not (
655
- chat and reduced) else ''
656
 
657
  PreInstruct = """
658
  ### User:
@@ -665,13 +671,14 @@ ASSISTANT:
665
  """
666
  terminate_response = [
667
  '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
668
- chat_sep = '\n'
669
  humanstr = PreInstruct
670
  botstr = PreResponse
671
  elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
672
  PromptType.wizard2.name]:
673
  # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
674
- preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
 
675
  start = ''
676
  promptB = promptA = '%s%s' % (preprompt, start)
677
  PreInstruct = """
@@ -682,27 +689,39 @@ ASSISTANT:
 ### Response:
 """
         terminate_response = [PreResponse]
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
                          PromptType.wizard3.name]:
         # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
-        preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
+        preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
+            chat and reduced) else ''
         start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
         PreInstruct = """USER: """
         PreInput = None
         PreResponse = """ASSISTANT: """
         terminate_response = [PreResponse]
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
+                         PromptType.wizard_vicuna.name]:
+        preprompt = ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = """USER: """
+        PreInput = None
+        PreResponse = """ASSISTANT: """
+        terminate_response = [PreResponse]
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
 
     elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
                          PromptType.instruct_simple.name]:
-        promptA = '' if not (chat and reduced) else ''
-        promptB = '' if not (chat and reduced) else ''
+        promptB = promptA = '' if not (chat and reduced) else ''
 
         PreInstruct = """
 ### Instruction:
@@ -716,21 +735,90 @@ ASSISTANT:
 ### Response:
 """
         terminate_response = None
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
+                         PromptType.openai.name]:
+        preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
+            chat and reduced) else ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = "\nHuman: "
+        PreInput = None
+        PreResponse = "\nAI:"
+        terminate_response = [PreResponse] + [" Human:", " AI:"]
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
+                         PromptType.gptj.name]:
+        preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
+            chat and reduced) else ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = "\n### Prompt: "
+        PreInput = None
+        PreResponse = "\n### Response: "
+        terminate_response = [PreResponse] + ["Prompt:", "Response:"]
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
+    elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
+                         PromptType.openai_chat.name]:
+        # prompting and termination all handled by endpoint
+        preprompt = """"""
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = ""
+        PreInput = None
+        PreResponse = ""
+        terminate_response = []
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = None
+        botstr = None
+    elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
+                         PromptType.vicuna11.name]:
+        preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
+            chat and reduced) else ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        eos = '</s>'
+        PreInstruct = """USER: """
+        PreInput = None
+        PreResponse = """ASSISTANT:"""
+        terminate_response = [PreResponse]
+        chat_sep = ' '
+        chat_turn_sep = eos
+        humanstr = PreInstruct
+        botstr = PreResponse
+
+        if making_context:
+            # when making context, want it to appear as-if LLM generated, which starts with space after :
+            PreResponse = PreResponse + ' '
+        else:
+            # normally LLM adds space after this, because was how trained.
+            # if add space here, non-unique tokenization will often make LLM produce wrong output
+            PreResponse = PreResponse
     else:
         raise RuntimeError("No such prompt_type=%s" % prompt_type)
 
-    if return_dict:
-        return dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
+    if isinstance(terminate_response, (tuple, list)):
+        assert '' not in terminate_response, "Bad terminate_response"
+
+    ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
                     PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
-                    humanstr=humanstr, botstr=botstr), ''
+                    chat_turn_sep=chat_turn_sep,
+                    humanstr=humanstr, botstr=botstr,
+                    generates_leading_space=generates_leading_space)
+
+    if return_dict:
+        return ret_dict, prompt_dict_error
     else:
-        return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
+        return tuple(list(ret_dict.values()))
 
 
-def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
+def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
     context = data_point.get('context')
     if context is None:
         context = ''
@@ -741,9 +829,12 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
     prompt_dict = data_point.get('prompt_dict', prompt_dict)
     assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
     promptA, promptB, PreInstruct, PreInput, PreResponse, \
-        terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, prompt_dict, chat, context, reduced)
+        terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
+        generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
+                                             context, reduced, making_context)
 
-    prompt = context if not reduced else ''
+    # could avoid if reduce=True, but too complex for parent functions to handle
+    prompt = context
 
     if input and promptA:
         prompt += f"""{promptA}"""
@@ -793,7 +884,7 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
     if output:
         prompt += f"""{output}"""
 
-    return prompt, pre_response, terminate_response, chat_sep
+    return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep
 
 
 def inject_chatsep(prompt_type, prompt, chat_sep=None):
@@ -808,9 +899,6 @@ class Prompter(object):
                  allowed_repeat_line_length=10):
         self.prompt_type = prompt_type
         self.prompt_dict = prompt_dict
-        data_point = dict(instruction='', input='', output='')
-        _, self.pre_response, self.terminate_response, self.chat_sep = \
-            generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
         self.debug = debug
         self.chat = chat
         self.stream_output = stream_output
@@ -819,15 +907,33 @@ class Prompter(object):
         self.prompt = None
         context = ""  # not for chat context
         reduced = False  # not for chat context
+        making_context = False  # not for chat context
         self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
-            self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
-            get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced)
+            self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
+            self.generates_leading_space = \
+            get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
+        self.pre_response = self.PreResponse
 
-    def generate_prompt(self, data_point):
-        reduced = False
-        prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced)
+    def generate_prompt(self, data_point, reduced=None):
+        """
+        data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
+        :param data_point:
+        :param reduced:
+        :return:
+        """
+        reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
+        making_context = False  # whether really making final prompt or just generating context
+        prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
+                                             making_context)
         if self.debug:
             print("prompt: %s" % prompt, flush=True)
+        # if have context, should have always reduced and only prepend promptA/B here
+        if data_point.get('context'):
+            if data_point.get('input') and self.promptA:
+                prompt = self.promptA + prompt
+            elif self.promptB:
+                prompt = self.promptB + prompt
+
         self.prompt = prompt
         return prompt
 
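With the new signature, `Prompter.generate_prompt` infers `reduced` from whether a context is present. A minimal usage sketch; the constructor arguments shown are assumptions about defaults, not taken from this commit:

# Hypothetical usage of the updated Prompter API:
prompter = Prompter('human_bot', prompt_dict='', chat=True)
data_point = dict(context='', instruction='Why is the sky blue?', input='')
prompt = prompter.generate_prompt(data_point)  # empty context => reduced inferred False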
@@ -846,7 +952,8 @@ class Prompter(object):
             if sanitize_bot_response:
                 from better_profanity import profanity
                 response = profanity.censor(response)
-            response = response.strip("\n")
+            if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
+                response = response[1:]
             return response
 
         def clean_repeats(response):
@@ -868,12 +975,12 @@ class Prompter(object):
                 # then use most basic parsing like pipeline
                 if self.botstr in output:
                     if self.humanstr:
-                        output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
+                        output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
                     else:
                         # i.e. use after bot but only up to next bot
-                        output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
+                        output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
                 else:
-                    # output = clean_response(output.strip())
+                    # output = clean_response(output)
                     # assume just not printed yet
                     output = ""
             else:
@@ -900,9 +1007,9 @@ class Prompter(object):
                     allow_terminate = True
                     output = output[len(prompt):]
                 # clean after subtract prompt out, so correct removal of pre_response
-                output = clean_response(output).strip()
+                output = clean_response(output)
                 if self.repeat_penalty:
-                    output = clean_repeats(output).strip()
+                    output = clean_repeats(output)
                 if self.terminate_response and allow_terminate:
                     finds = []
                     for term in self.terminate_response:
@@ -910,11 +1017,9 @@ class Prompter(object):
                     finds = [x for x in finds if x >= 0]
                     if len(finds) > 0:
                         termi = finds[0]
-                        output = output[:termi].strip()
+                        output = output[:termi]
                     else:
-                        output = output.strip()
-                else:
-                    output = output.strip()
+                        output = output
                 if multi_output:
                     # prefix with output counter
                     output = "\n=========== Output %d\n\n" % (1 + oi) + output
@@ -927,3 +1032,80 @@ class Prompter(object):
         if self.debug:
             print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
         return output
+import torch
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+
+
+class StoppingCriteriaSub(StoppingCriteria):
+
+    def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
+        super().__init__()
+        assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
+        self.encounters = encounters
+        self.stops = [stop.to(device) for stop in stops]
+        self.num_stops = [0] * len(stops)
+        self.model_max_length = model_max_length
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        for stopi, stop in enumerate(self.stops):
+            if torch.all((stop == input_ids[0][-len(stop):])).item():
+                self.num_stops[stopi] += 1
+                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
+                    # print("Stopped", flush=True)
+                    return True
+        if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
+            # critical limit
+            return True
+        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
+        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
+        return False
+
+
+def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
+    # FIXME: prompt_dict unused currently
+    if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
+        if prompt_type == PromptType.human_bot.name:
+            # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
+            # stopping only starts once output is beyond prompt
+            # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
+            stop_words = [human, bot, '\n' + human, '\n' + bot]
+            encounters = [1, 2]
+        elif prompt_type == PromptType.instruct_vicuna.name:
+            # even below is not enough, generic strings and many ways to encode
+            stop_words = [
+                '### Human:',
+                """
+### Human:""",
+                """
+### Human:
+""",
+                '### Assistant:',
+                """
+### Assistant:""",
+                """
+### Assistant:
+""",
+            ]
+            encounters = [1, 2]
+        else:
+            # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
+            stop_words = ['### End']
+            encounters = [1]
+        stop_words_ids = [
+            tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
+        # handle single token case
+        stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
+        stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
+        # avoid padding in front of tokens
+        if tokenizer._pad_token:  # use hidden variable to avoid annoying property logger bug
+            stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
+        # handle fake \n added
+        stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
+        # build stopper
+        stopping_criteria = StoppingCriteriaList(
+            [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
+                                 model_max_length=model_max_length)])
+    else:
+        stopping_criteria = StoppingCriteriaList()
+    return stopping_criteria
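The stopping-criteria block moved to the end of the file plugs straight into Hugging Face generation. A minimal sketch of wiring `get_stopping` into `model.generate`; the model choice and device are assumptions, not part of this commit:

# Hypothetical wiring of get_stopping into generation (model/device are assumptions):
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

stopping_criteria = get_stopping('human_bot', '', tokenizer, 'cpu', model_max_length=2048)
inputs = tokenizer('<human>: Why is the sky blue?\n<bot>:', return_tensors='pt')
outputs = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))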