jpohhhh's picture
Fix handler.py import
a61e58e
raw
history blame
1.82 kB
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModel
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForFeatureExtraction
import torch
# Mean pooling: average the token embeddings while letting the attention mask
# decide which positions count, so padding does not dilute the average.
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions enabled by ``attention_mask``.

    Args:
        model_output: indexable model output; element 0 holds the per-token
            embeddings. Presumably shaped (batch, seq, hidden) -- confirm
            against the model in use.
        attention_mask: 0/1 mask with one fewer dimension than the embeddings
            (presumably (batch, seq)).

    Returns:
        Tensor of masked means, with dimension 1 (the token axis) reduced away.
    """
    embeddings = model_output[0]
    # Add a trailing axis and broadcast the mask over it so it weights every
    # component of each token's embedding; cast to float for the arithmetic.
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * weights).sum(dim=1)
    # Clamp the count so a row with no enabled tokens divides by 1e-9, not 0.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
class EndpointHandler:
    """Endpoint handler that runs feature extraction with an ONNX Runtime model.

    The ONNX model and its tokenizer are loaded once at construction time and
    wrapped in an Optimum pipeline; each request is forwarded to that pipeline.
    """

    def __init__(self, path=""):
        """Load ``model.onnx`` and the tokenizer from ``path`` and build the pipeline.

        Args:
            path: directory (or hub id) containing ``model.onnx`` and the
                tokenizer files. Defaults to ``""``, the location the model
                was previously always loaded from.
        """
        # Load the already-exported ONNX graph rather than converting a
        # PyTorch checkpoint. NOTE(review): newer Optimum releases rename
        # `from_transformers` to `export` -- confirm against the pinned version.
        model = ORTModelForFeatureExtraction.from_pretrained(
            path, file_name="model.onnx", from_transformers=False
        )
        # The pipeline previously referenced undefined names `task` and
        # `tokenizer`, so construction always raised NameError. Load the
        # tokenizer from the same location as the model and name the task
        # explicitly.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.onnx_extractor = pipeline(
            "feature-extraction", model=model, tokenizer=self.tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the feature-extraction pipeline on a request payload.

        Args:
            data: request body. The value stored under ``"inputs"`` is sent
                to the pipeline; if that key is absent, ``data`` itself is
                sent. ``"inputs"`` is popped, so the caller's dict is mutated.

        Returns:
            The pipeline output for the inputs, unmodified (no pooling is
            applied here); it will be serialized and returned by the endpoint.
        """
        sentences = data.pop("inputs", data)
        return self.onnx_extractor(sentences)