# tb-ocr / app.py
import os
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import gradio as gr
import spaces

# Install flash-attn at runtime; the CUDA extension build is skipped because no
# GPU is available at startup (a ZeroGPU device is only attached inside @spaces.GPU calls).
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

model_id = "yifeihu/TB-OCR-preview-0.1"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model 4-bit quantized with flash-attention 2; the Space assumes CUDA.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation='flash_attention_2',
    load_in_4bit=True
)

# num_crops sets the maximum number of crops the processor splits each image into.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=16
)
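
# Fallback loading sketch for environments without flash-attn or a GPU (an
# assumption for local experimentation, not used by this Space): default
# "eager" attention and no 4-bit quantization.
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id,
#       trust_remote_code=True,
#       torch_dtype="auto",
#       attn_implementation="eager",
#   )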

def phi_ocr(image):
    # Single-turn chat prompt with one image placeholder.
    question = "Convert the text to markdown format."
    prompt_message = [{
        'role': 'user',
        'content': f'<|image_1|>\n{question}',
    }]
    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(prompt, [image], return_tensors="pt").to("cuda")

    # Greedy decoding; temperature is ignored because do_sample is False.
    generation_args = {
        "max_new_tokens": 1024,
        "temperature": 0.1,
        "do_sample": False
    }
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args
    )

    # Drop the prompt tokens and decode only the newly generated text.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # The model marks the end of the page text with "<image_end>"; keep what precedes it.
    response = response.split("<image_end>")[0]
    return response
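
# Minimal local usage sketch (an assumption, not part of the original Space):
# run OCR on a page image from disk; "sample_page.png" is a hypothetical path.
#
#   markdown_text = phi_ocr(Image.open("sample_page.png"))
#   print(markdown_text)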

# spaces.GPU requests a ZeroGPU device for the duration of this call.
@spaces.GPU
def process_image(input_image):
    return phi_ocr(input_image)

with gr.Blocks() as demo:
    gr.Markdown("# OCR with TB-OCR-preview-0.1")
    gr.Markdown("Upload an image to extract and convert text to markdown format.")
    gr.Markdown("[Check out the model here](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)")
    with gr.Row():
        input_image = gr.Image(type="pil")
        output_text = gr.Textbox()
    # Run OCR whenever the uploaded image changes.
    input_image.change(fn=process_image, inputs=input_image, outputs=output_text)

demo.launch()