# Spaces:
# Build error
# Build error
### 1. Imports and class names setup ### | |
import gradio as gr | |
import os | |
import requests | |
import torch | |
import numpy as np | |
from roboflow import Roboflow | |
import cv2 | |
# --- Roboflow hosted object-detection model setup ---
# NOTE(security): API key is hard-coded in source — move it to an
# environment variable / Spaces secret before publishing.
rf = Roboflow(api_key="PO54lH9XBJxPjmlAvQsW")
project = rf.workspace().project("no_glasses")
model = project.version(1).model  # hosted "no_glasses" detector, version 1
# Demo media fetched at startup for the Gradio examples.
file_urls = [
    'https://www.dropbox.com/s/7sjfwncffg8xej2/video_7.mp4?dl=1'
]
def download_file(url, save_name):
    """Download ``url`` to ``save_name`` unless the file already exists.

    Startup helper for fetching demo media.

    Args:
        url: direct-download URL of the file.
        save_name: local path to write the file to.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status
            (the original silently wrote the error page to disk).
    """
    if not os.path.exists(save_name):
        response = requests.get(url)
        response.raise_for_status()  # fail loudly instead of saving an error page
        # Context manager guarantees the handle is closed (original leaked it).
        with open(save_name, 'wb') as f:
            f.write(response.content)
# Fetch each example file once at startup: videos share a fixed name,
# other files get an indexed image name.
for i, url in enumerate(file_urls):
    if 'mp4' in url:  # use the loop variable directly, not file_urls[i]
        download_file(url, "video.mp4")
    else:
        download_file(url, f"image_{i}.jpg")
# Gradio `examples` payload for the video tab (list of per-example inputs).
video_path = [['video.mp4']]
from model import create_effnetb2_model | |
from timeit import default_timer as timer | |
from typing import Tuple, Dict | |
# Setup class names
# Order must match the training label order of glass_model.pth —
# presumably index 0 = "Yes" (wearing glasses), 1 = "No"; TODO confirm.
class_names = ["Yes","No"]
### 2. Model and transforms preparation ###
# Create EffNetB2 model
# create_effnetb2_model comes from the project-local `model` module and
# returns the network plus its matching torchvision transforms.
effnetb2, effnetb2_transforms = create_effnetb2_model(
    num_classes=2, # len(class_names) would also work
)
# Load saved weights from the checkpoint shipped alongside this app.
effnetb2.load_state_dict(
    torch.load(
        f="glass_model.pth",
        map_location=torch.device("cpu"), # load to CPU
    )
)
def detect(imagepath):
    """Run the Roboflow glasses detector on an image and draw red boxes.

    Args:
        imagepath: path to an image file on disk (Gradio passes a filepath).

    Returns:
        The annotated image as an RGB numpy array (for a Gradio Image output).
    """
    result = model.predict(imagepath, confidence=40, overlap=30).json()
    img = cv2.imread(imagepath)
    # Roboflow returns box CENTER coordinates plus width/height, so convert
    # to corner points (the original treated x,y as the top-left corner, and
    # also built the bottom y from x via a copy-paste slip: y2 = y2 + x1).
    # cv2.rectangle requires integer pixel coordinates.
    for pred in result.get("predictions", []):
        cx, cy = pred["x"], pred["y"]
        w, h = pred["width"], pred["height"]
        top_left = (int(cx - w / 2), int(cy - h / 2))
        bottom_right = (int(cx + w / 2), int(cy + h / 2))
        cv2.rectangle(
            img,
            top_left,
            bottom_right,
            color=(0, 0, 255),  # red in BGR
            thickness=2,
            lineType=cv2.LINE_AA
        )
    # OpenCV loads BGR; Gradio displays RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def show_preds_video(video_path):
    """Yield annotated RGB frames from a video, one frame at a time.

    Being a generator lets Gradio stream the frames into the output
    Image component.

    Args:
        video_path: path to a video file on disk.

    Yields:
        Each frame as an RGB numpy array with red detection boxes drawn.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # end of stream — the original spun forever here
            frame_copy = frame.copy()
            result = model.predict(frame, confidence=40, overlap=30).json()
            # Roboflow boxes are center-based; convert to integer corner
            # coords (original also had the y2 = y2 + x1 copy-paste slip).
            for pred in result.get("predictions", []):
                cx, cy = pred["x"], pred["y"]
                w, h = pred["width"], pred["height"]
                top_left = (int(cx - w / 2), int(cy - h / 2))
                bottom_right = (int(cx + w / 2), int(cy + h / 2))
                cv2.rectangle(
                    frame_copy,  # original drew on an undefined name `img` (NameError)
                    top_left,
                    bottom_right,
                    color=(0, 0, 255),  # red in BGR
                    thickness=2,
                    lineType=cv2.LINE_AA
                )
            yield cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()  # original leaked the capture handle
### 3. Predict function ###
def predict(img) -> Tuple[Dict, float]:
    """Classify a PIL image with the EffNetB2 model.

    Returns:
        A (label -> probability) dict in the format Gradio's Label
        component expects, and the inference time in seconds.
    """
    started = timer()
    # Apply the model's transforms and add a leading batch dimension.
    batch = effnetb2_transforms(img).unsqueeze(0)
    effnetb2.eval()
    with torch.inference_mode():
        # Logits -> probabilities for the single image in the batch.
        probabilities = torch.softmax(effnetb2(batch), dim=1)[0]
    labelled = {name: float(p) for name, p in zip(class_names, probabilities)}
    elapsed = round(timer() - started, 5)
    return labelled, elapsed
### 4. Gradio app ###
# Create title, description and article strings
title = "Safety Glasses Detector"
description = "An EfficientNetB2 feature extractor computer vision model to classify images of Safety glasses at construction sites"
article = "(https://www.learnpytorch.io/)."
# Create examples list from "examples/" directory
#example_list = [["examples/" + example] for example in os.listdir("examples")]
# Component lists for the detection (image) tab: filepath in, numpy image out.
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]
outputs_image = [
    gr.components.Image(type="numpy", label="Output Image"),
]
# Components for the video tab; output is an Image fed frame-by-frame by
# the show_preds_video generator.
inputs_video = [
    # NOTE(review): newer Gradio releases dropped Video's `type` kwarg — confirm the pinned version accepts it.
    gr.components.Video(type="filepath", label="Input Video"),
]
outputs_video = [
    gr.components.Image(type="numpy", label="Output Image"),
]
# Create the Gradio demo
# Tab 1: whole-image classification (EffNetB2) with probabilities + latency.
app1 = gr.Interface(fn=predict, # mapping function from input to output
                    inputs=gr.Image(type="pil"), # what are the inputs?
                    outputs=[gr.Label(num_top_classes=2, label="Predictions"), # what are the outputs?
                             gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
                    title=title,
                    description=description,
                    article=article)
# Tab 2: single-image object detection via the hosted Roboflow model.
app2=gr.Interface(fn=detect,
                  inputs=inputs_image,
                  outputs=outputs_image,
                  title=title)
# Tab 3: video detection; fn is a generator, so frames stream to the UI.
app3=gr.Interface(
    fn=show_preds_video,
    inputs=inputs_video,
    outputs=outputs_video,
    examples=video_path,
    cache_examples=False,
)
demo = gr.TabbedInterface([app1, app2,app3], ["Classify", "Detect","Video Interface"])
# Launch the demo!
demo.launch()