ChirathD's picture
Create handler.py
3115ce7 verified
raw
history blame contribute delete
No virus
1.14 kB
import json
import os
from typing import Any, Dict, List

import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GemmaTokenizer,
    pipeline,
)
class EndpointHandler:
    """Inference handler that serves an 8-bit quantized causal LM.

    The tokenizer and model are loaded once at construction time and wrapped
    in a ``transformers`` text-generation pipeline. Each request is served
    by ``__call__``.
    """

    # Hub repository the tokenizer and model weights are loaded from.
    MODEL_ID = "LexiconShiftInnovations/Gemma_Dental_it_07_merged"

    def __init__(self, path=""):
        """Load the tokenizer and 8-bit quantized model and build the pipeline.

        Args:
            path: Accepted to match the handler constructor signature, but
                not used; the model is always loaded from ``MODEL_ID``.
        """
        # NOTE(review): ``path`` is ignored and the repo id is hard-coded.
        # Confirm this is intentional (e.g. the handler repo holds no weights).

        # Only ``load_in_8bit`` has an effect here. The previous
        # ``bnb_8bit_quant_type="nf4"`` / ``bnb_8bit_compute_dtype`` keywords
        # are not BitsAndBytesConfig parameters. NF4 and the compute dtype
        # exist only for 4-bit loading, as ``bnb_4bit_*``. The old keywords
        # were swallowed as unused kwargs, so dropping them keeps the same
        # configuration.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_ID,
            quantization_config=bnb_config,
            device_map={"": 0},  # put the whole model on CUDA device 0
        )
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Any) -> List[Any]:
        """Run text generation for one request payload.

        Args:
            data: Either a dict of the form
                ``{"inputs": ..., "parameters": {...}}`` or the raw inputs
                themselves. When a dict has no ``"inputs"`` key, the dict
                (minus any ``"parameters"`` entry) is used as the inputs.

        Returns:
            Whatever the text-generation pipeline returns for ``inputs``.
            ``parameters`` are forwarded to it as keyword arguments.
        """
        if not isinstance(data, dict):
            # A bare string or list of strings: nothing to unpack.
            inputs, parameters = data, None
        else:
            # Pop from a shallow copy so the caller's payload is not mutated.
            payload = dict(data)
            inputs = payload.pop("inputs", payload)
            parameters = payload.pop("parameters", None)

        # None and {} both mean "use the pipeline defaults".
        return self.pipeline(inputs, **(parameters or {}))