jpohhhh commited on
Commit
89b609f
1 Parent(s): 0f61114

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -5
handler.py CHANGED
@@ -15,6 +15,7 @@ class EndpointHandler():
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  self.model = ORTModelForCustomTasks.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
17
  self.tokenizer = AutoTokenizer.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
 
18
  # self.model.to(self.device)
19
  # print("model will run on ", self.device)
20
 
@@ -27,9 +28,9 @@ class EndpointHandler():
27
  A :obj:`list` | `dict`: will be serialized and returned
28
  """
29
  sentences = data.pop("inputs",data)
30
- inputs = tokenizer("I love burritos!", return_tensors="pt")
31
- pred = self.model(**encoded_input)
32
-
33
  # Perform pooling. In this case, max pooling.
34
- sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
35
- return sentence_embeddings.tolist()
 
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  self.model = ORTModelForCustomTasks.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
17
  self.tokenizer = AutoTokenizer.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
18
+ self.onnx_extractor = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
19
  # self.model.to(self.device)
20
  # print("model will run on ", self.device)
21
 
 
28
  A :obj:`list` | `dict`: will be serialized and returned
29
  """
30
  sentences = data.pop("inputs",data)
31
+ # inputs = tokenizer("I love burritos!", return_tensors="pt")
32
+ pred = onnx_extractor(sentences)
33
+ return pred
34
  # Perform pooling. In this case, max pooling.
35
+ # sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
36
+ # return sentence_embeddings.tolist()