arnocandel commited on
Commit
d5d0b9d
1 Parent(s): 98a65ab

Upload 10 files

Browse files
h2oai_pipeline.py CHANGED
@@ -1,6 +1,9 @@
1
  from transformers import TextGenerationPipeline
2
  from transformers.pipelines.text_generation import ReturnType
3
 
 
 
 
4
  human = "<human>:"
5
  bot = "<bot>:"
6
 
@@ -28,3 +31,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
28
  for rec in records:
29
  rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
30
  return records
 
 
 
 
 
 
1
  from transformers import TextGenerationPipeline
2
  from transformers.pipelines.text_generation import ReturnType
3
 
4
+ from stopping import get_stopping
5
+
6
+ prompt_type = "human_bot"
7
  human = "<human>:"
8
  bot = "<bot>:"
9
 
 
31
  for rec in records:
32
  rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
33
  return records
34
+
35
+ def _forward(self, model_inputs, **generate_kwargs):
36
+ stopping_criteria = get_stopping(prompt_type, self.tokenizer, self.device, human=human, bot=bot)
37
+ generate_kwargs['stopping_criteria'] = stopping_criteria
38
+ return super()._forward(model_inputs, **generate_kwargs)
pytorch_model-00001-of-00003.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2431df521ef2282165f9264ccde0544981efa1906849f1e030b89ce7c544f307
3
  size 5028171302
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d0678fa21071e8428e77f6cf089e44d4aa6999bc81968b6bf7e211013ff39c7
3
  size 5028171302
pytorch_model-00002-of-00003.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:303d7d7c7bdc0f4f493d372226e70177252a00f33e9769880809a63878586f95
3
  size 5017761129
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adb145b24d55109b0973dd9f10b4f2d6d90c33c7f27f80367217a0e59ed5af50
3
  size 5017761129
pytorch_model-00003-of-00003.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:320cf6f19aba9e0179ffb282952402038edfd27389c5468e40365b37d48259a7
3
  size 3803055858
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc650723bc83675298d5451e40bd0ea0c18c9b16b87a81e44b20641695ee9900
3
  size 3803055858