jpohhhh committed on
Commit
96355e1
1 Parent(s): 89b609f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -5
handler.py CHANGED
@@ -1,6 +1,6 @@
1
  from typing import Dict, List, Any
2
  from transformers import AutoTokenizer, AutoModel
3
- from optimum.onnxruntime import ORTModelForCustomTasks
4
 
5
  import torch
6
 
@@ -12,10 +12,9 @@ def mean_pooling(model_output, attention_mask):
12
 
13
  class EndpointHandler():
14
  def __init__(self, path=""):
15
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- self.model = ORTModelForCustomTasks.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
17
- self.tokenizer = AutoTokenizer.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
18
- self.onnx_extractor = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
19
  # self.model.to(self.device)
20
  # print("model will run on ", self.device)
21
 
 
1
  from typing import Dict, List, Any
2
  from transformers import AutoTokenizer, AutoModel
3
+ from optimum.pipelines import pipeline
4
 
5
  import torch
6
 
 
12
 
13
  class EndpointHandler():
14
  def __init__(self, path=""):
15
+ # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ # self.tokenizer = AutoTokenizer.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
17
+ self.onnx_extractor = pipeline("feature-extraction", model="optimum/sbert-all-MiniLM-L6-with-pooler", accelerator="ort")
 
18
  # self.model.to(self.device)
19
  # print("model will run on ", self.device)
20