Nanobit committed
Commit: fac4600
Parents: dcb03d6, 33d4017

Merge pull request #119 from NanoCode012/feat/update-inference

Files changed (1): scripts/finetune.py (+15 −7)
scripts/finetune.py CHANGED

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
+from transformers import GenerationConfig
 
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.dict import DictDefault
@@ -73,26 +74,33 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt: str = next(prompter_module().build_prompt(instruction=instruction))
+        prompt: str = next(
+            prompter_module().build_prompt(instruction=instruction.strip("\n"))
+        )
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
         model.eval()
         with torch.no_grad():
-            # gc = GenerationConfig() # TODO swap out and use this
-            generated = model.generate(
-                inputs=batch["input_ids"].to(cfg.device),
-                do_sample=True,
-                use_cache=True,
+            generation_config = GenerationConfig(
                 repetition_penalty=1.1,
-                max_new_tokens=100,
+                max_new_tokens=1024,
                 temperature=0.9,
                 top_p=0.95,
                 top_k=40,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=True,
+                use_cache=True,
                 return_dict_in_generate=True,
                 output_attentions=False,
                 output_hidden_states=False,
                 output_scores=False,
             )
+            generated = model.generate(
+                inputs=batch["input_ids"].to(cfg.device),
+                generation_config=generation_config,
+            )
         print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 
 
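For readers outside the repo: the pattern this merge lands is the standard transformers one of collecting sampling parameters in a GenerationConfig object and handing it to model.generate(), instead of spreading keyword arguments across call sites. Below is a minimal self-contained sketch of the same flow under stated assumptions: the checkpoint name "sshleifer/tiny-gpt2" and the inline Alpaca-style prompt are illustrative stand-ins for axolotl's model loading and prompter, not part of this PR.

# Sketch only: mirrors the GenerationConfig-based inference flow merged above.
# The checkpoint name and hard-coded prompt are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

prompt = "### Instruction:\nSay hello.\n\n### Response:\n"
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

# All sampling knobs live in one reusable object rather than generate() kwargs.
generation_config = GenerationConfig(
    repetition_penalty=1.1,
    max_new_tokens=64,
    temperature=0.9,
    top_p=0.95,
    top_k=40,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=True,
    use_cache=True,
    return_dict_in_generate=True,
)

model.eval()
with torch.no_grad():
    # With return_dict_in_generate=True, generate() returns a dict-like
    # output whose "sequences" field holds the generated token ids.
    generated = model.generate(
        inputs=batch["input_ids"],
        generation_config=generation_config,
    )
print(tokenizer.decode(generated["sequences"][0]))

One motivation for pinning the bos/eos/pad token ids on the config, as the merged code does, is that generation then terminates and pads correctly for whichever tokenizer the run loaded, rather than relying on model defaults.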