ppt_owl_vit / app.py
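# Streamlit demo: zero-shot, text-conditioned object detection with OWL-ViT.
# Upload an image and a comma-separated vocabulary; the app draws a labeled box
# for every query that scores above the chosen confidence threshold.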
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import warnings
import torch
import os
import io
# settings
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
warnings.filterwarnings('ignore')
st.set_page_config()
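# small wrapper around one detection request: it holds the image path, the list
# of text queries, and the confidence threshold, then runs OWL-ViT and renders
# the matching boxes onto the image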
class owl_vit:
    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text
        self.threshold = threshold
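    # run the processor/model pair on the stored image, then post-process the
    # raw outputs into per-query boxes, scores, and labels in pixel coordinates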
    def process(self, processor, model):
        image = Image.open(self.image_path)
        if image.mode != "RGB":  # normalize grayscale, palette, or RGBA uploads
            image = image.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        outputs = model(**inputs)
        target_sizes = torch.tensor([[image.height, image.width]])
        self.results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
        self.image = image
        return self.result_image()
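    # draw every detection that clears the threshold onto the image with
    # matplotlib, then round-trip the figure through a PNG buffer into a PIL image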
    def result_image(self):
        boxes, scores, labels = self.results[0]["boxes"], self.results[0]["scores"], self.results[0]["labels"]
        plt.figure()  # start a fresh figure so reruns do not draw over old results
        plt.imshow(self.image)
        ax = plt.gca()
        for box, score, label in zip(boxes, scores, labels):
            if score >= self.threshold:
                box = box.detach().numpy()
                color = list(mcolors.CSS4_COLORS.keys())[label]
                ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1], fill=False, color=color, linewidth=3))
                ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}", fontsize=15, color=color)
        plt.tight_layout()
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png')
        plt.close()  # release the figure now that it is rendered to the buffer
        img_buf.seek(0)
        image = Image.open(img_buf)
        return image
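# load the OWL-ViT base (patch16) checkpoint; note that Streamlit re-executes
# the script on every interaction, so a cached loader (e.g. st.cache_resource
# in newer Streamlit releases) would avoid reloading the weights each run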
def load_model():
    with st.spinner('Getting Neurons in Order ...'):
        processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
        return processor, model
def show_detects(image):
    st.title("Results")
    st.image(image, use_column_width=True, caption="Object Detection Results", clamp=True)
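# full request pipeline: persist the upload to disk, run detection, render the
# results, and prune the images folder so disk usage stays bounded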
def process(upload, text, threshold):
    # save upload to file
    os.makedirs("images", exist_ok=True)
    filetype = upload.name.split('.')[-1]
    name = len(os.listdir("images")) + 1
    file_path = os.path.join('images', f'{name}.{filetype}')
    with open(file_path, "wb") as f:
        f.write(upload.getbuffer())
    # predict detections and show results
    detector = owl_vit(file_path, text, threshold)
    results = detector.process(processor, model)
    show_detects(results)
    # clean up - if over 1000 images in the folder, delete the oldest one
    if len(os.listdir("images")) > 1000:
        oldest = min(os.listdir("images"), key=lambda f: os.path.getctime(os.path.join("images", f)))
        os.remove(os.path.join("images", oldest))
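# page layout: splash image, model description, a canned example, and the
# upload / query / threshold controls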
def main(processor, model):
    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)
    # title and project description
    st.title("OWL-ViT")
    st.markdown("**OWL-ViT** is a zero-shot, text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal \
        backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. \
        To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a \
        lightweight classification and box head to each transformer output token. Open-vocabulary classification \
        is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained \
        from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification \
        and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image \
        can be used to perform zero-shot text-conditioned object detection.", unsafe_allow_html=True)
    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = owl_vit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
    if st.button("Clear Example"):
        st.markdown("")
    # upload
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma, separated)', "batter, umpire, catcher")
        text = [x.strip() for x in text.split(',')]
    # process
    if upload is not None and text is not None:
        filetype = upload.name.split('.')[-1]
        if filetype in ['jpg', 'jpeg', 'png']:
            with st.spinner('Detecting and Counting Single Image...'):
                process(upload, text, threshold)
        else:
            st.warning('Unsupported file type.')
if __name__ == '__main__':
    processor, model = load_model()
    main(processor, model)