derek-thomas (HF staff) committed
Commit: d29dd9f
Parent: 19c1b44

Fixing return structure

Files changed (1)
  1. handler.py +11 -8
handler.py CHANGED
@@ -27,11 +27,13 @@ class EndpointHandler:
 
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto",
-                                                          offload_folder='offload',
-                                                          trust_remote_code=True,
-                                                          load_in_8bit=True)
+        # self.tokenizer = AutoTokenizer.from_pretrained(path)
+        # self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto",
+        #                                                   offload_folder='offload',
+        #                                                   trust_remote_code=True,
+        #                                                   load_in_8bit=True)
+        self.tokenizer = tokenizer
+        self.model = model
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
@@ -39,6 +41,7 @@ class EndpointHandler:
         if 'prompt' in data.keys():
             text = data['prompt']
         else:
+            print(data.keys())
             user_data = data.pop('query',data)
             text = self.prompt_ar.format_map({'Question':user_data})
         inputs = data.pop("inputs", data)
@@ -71,10 +74,10 @@ class EndpointHandler:
         response = self.tokenizer.batch_decode(generate_ids,
                                                skip_special_tokens=True,
                                                clean_up_tokenization_spaces=True)[0]
-        final_response = response.split("### Response: [|AI|]")
-        turn = [f'[|Human|] {query}', f'[|AI|] {final_response[-1]}']
-        chat_history.extend(turn)
         if 'prompt' in data.keys():
             return response
         else:
+            final_response = response.split("### Response: [|AI|]")
+            turn = [f'[|Human|] {query}', f'[|AI|] {final_response[-1]}']
+            chat_history.extend(turn)
            return {"response": final_response, "chat_history": chat_history}
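For reference, a minimal sketch of how the two return paths behave after this fix. The payload keys ('prompt', 'query') and the returned keys ('response', 'chat_history') come from the diff above; the import path, the constructor argument, and the example inputs are assumptions for illustration, not part of the commit:

    # Hypothetical smoke test for the handler's two call paths.
    # Assumes handler.py is importable and that the module-level
    # `tokenizer` and `model` referenced in the new __init__ are loaded.
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")  # `path` argument assumed

    # 'prompt' path: the handler returns the raw decoded string as-is.
    raw = handler({"prompt": "### Instruction: ...\n### Response: [|AI|]"})

    # 'query' path: after this commit, splitting the response and extending
    # the chat history happen only on this branch, so the structured dict
    # is returned here and nowhere else.
    out = handler({"query": "What is the capital of France?"})
    print(out["response"])
    print(out["chat_history"])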