# tb-ocr / app.py
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import gradio as gr
import spaces
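# Install flash-attn at startup (the common Spaces recipe); FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
# skips compiling the CUDA extension during pip install.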
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
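# TB-OCR-preview-0.1 is an OCR-focused fine-tune of Phi-3.5-vision-instruct that outputs markdown.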
model_id = "yifeihu/TB-OCR-preview-0.1"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
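# 4-bit loading requires bitsandbytes; attn_implementation="flash_attention_2" requires the flash-attn package installed above.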
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    load_in_4bit=True,
)
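# num_crops controls Phi-3.5-vision's multi-crop image preprocessing; 16 is the single-image setting, suited to dense pages.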
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=16,
)
def phi_ocr(image: Image.Image) -> str:
    """Run OCR on a PIL image and return the recognized text as markdown."""
    question = "Convert the text to markdown format."
    prompt_message = [{
        "role": "user",
        "content": f"<|image_1|>\n{question}",
    }]
    # Build the Phi-3.5-vision chat prompt, then pack the prompt and image into model inputs.
    prompt = processor.tokenizer.apply_chat_template(prompt_message, tokenize=False, add_generation_prompt=True)
    inputs = processor(prompt, [image], return_tensors="pt").to("cuda")
    generation_args = {
        "max_new_tokens": 1024,
        "temperature": 0.1,  # ignored under greedy decoding (do_sample=False)
        "do_sample": False,
    }
    generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
    # Decode only the newly generated tokens and drop the trailing <image_end> marker.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    response = response.split("<image_end>")[0]
    return response
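
# Example (hypothetical local test, not part of the Space UI; "page.png" is an assumed sample path):
#   from PIL import Image
#   print(phi_ocr(Image.open("page.png")))
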
@spaces.GPU  # requests GPU time for this call when running on ZeroGPU hardware
def process_image(input_image):
    return phi_ocr(input_image)
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="OCR with TB-OCR-preview-0.1 (Phi-3.5-vision)",
    description="Upload an image to extract its text and convert it to markdown format.",
)
iface.launch()