# Factory-POC / app.py
from fastapi import FastAPI, UploadFile, File
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch
import io
import os
from typing import Union
# Patch to remove flash-attn dependency
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Workaround for flash-attn imports: drop the optional flash_attn
    dependency when transformers scans the model's remote code."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
device = "cuda" if torch.cuda.is_available() else "cpu"
# Apply the patch while loading, so the Florence-2-style remote code imports
# cleanly on machines without flash-attn installed
from unittest.mock import patch

with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained("numberPlate_model_2", trust_remote_code=True).to(device)
    processor = AutoProcessor.from_pretrained("numberPlate_model_2", trust_remote_code=True)
# Initialize FastAPI
app = FastAPI()
def process_image(image, task_token):
    inputs = processor(text=task_token, images=image, return_tensors="pt", padding=True).to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False,
    )
    # Keep special tokens: the task-specific post-processing parses them
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_result = processor.post_process_generation(generated_text, task=task_token, image_size=(image.width, image.height))
    return parsed_result
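
# For the default "<OD>" (object detection) task, post_process_generation
# typically returns a dict shaped like
#     {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": [...]}}
# (Florence-2 convention; the exact keys depend on the checkpoint's remote code).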
@app.post("/process-image/")
async def process_image_endpoint(file: UploadFile = File(...), task_token: str = "<OD>"):
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    result = process_image(image, task_token)
    return result
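
# Local run sketch, assuming uvicorn is installed; the port (7860, the usual
# Hugging Face Spaces default) and the file name car.jpg are illustrative
# assumptions, not requirements of this app:
#
#     curl -X POST "http://localhost:7860/process-image/?task_token=%3COD%3E" \
#          -F "file=@car.jpg"
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)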