E-Hospital committed on
Commit
df17865
1 Parent(s): f4a9a7f

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -0
handler.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
3
+ import torch
4
+
5
class EndpointHandler:
    """Serve a seq2seq text-generation model for a Hugging Face Inference Endpoint.

    Loads the model and tokenizer from ``path`` and answers request payloads of
    the form ``{"inputs": <prompt string>, "parameters": {<generate kwargs>}}``.
    """

    def __init__(self, path: str = ""):
        # Load model and tokenizer from the checkpoint path.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Kept so existing callers that reach for ``self.pipeline`` still work.
        # Fixes vs. original: a seq2seq model needs the "text2text-generation"
        # task (not "text-generation"), and ``device=0`` must not be combined
        # with ``device_map="auto"`` — transformers raises ValueError when both
        # are given.
        self.pipeline = pipeline(
            task="text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            framework="pt",
            max_length=512,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Run one generation request.

        Args:
            data: deserialized request payload. ``data["inputs"]`` is the
                prompt string; optional ``data["parameters"]`` holds keyword
                arguments forwarded to ``generate`` (e.g. ``max_new_tokens``).

        Returns:
            A one-element list ``[{"generated_text": <decoded output>}]``.
        """
        # Follow the Inference Endpoints convention: fall back to the raw
        # payload when no "inputs" key is present.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        # Tokenize the prompt. The original code built ``input_ids`` but then
        # decoded an undefined name ``outputs`` (NameError); generate from the
        # ids so the decode step has real model output to work on.
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids

        outputs = self.model.generate(input_ids, **parameters)

        # Decode the first (and only) generated sequence.
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return [{"generated_text": prediction}]