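# Waste-Detector / app.py
# Author: Hector Lopez
# Last commit: db41ce0 ("Fix")
# Page metadata captured from the hosting file viewer, kept only as comments:
# raw | history blame | No virus | 1.21 kB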
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import cv2
import PIL
from model import get_model, predict, prepare_prediction
def _load_model(checkpoint_path):
    """Build the detector from ``checkpoint_path`` via ``model.get_model``."""
    print('Creating the model')
    return get_model(checkpoint_path)


# Streamlit re-executes this whole script on every widget interaction, so
# without caching the checkpoint would be reloaded on every upload.
# st.cache_resource was added in Streamlit 1.18; older releases only have
# st.cache (allow_output_mutation avoids re-hashing the model on each hit).
if hasattr(st, "cache_resource"):
    _load_model = st.cache_resource(_load_model)
else:
    _load_model = st.cache(allow_output_mutation=True)(_load_model)

model = _load_model('checkpoint.ckpt')
def plot_img_no_mask(image, boxes, path="img.png", color=(220, 0, 0), thickness=5):
    """Draw bounding boxes on ``image`` and save the rendered figure to ``path``.

    Args:
        image: Anything ``np.array`` can convert into an array accepted by
            ``cv2.rectangle`` and ``Axes.imshow`` (presumably an H x W x 3
            image -- TODO confirm against ``prepare_prediction``).
        boxes: Tensor-like object exposing ``.cpu().detach().numpy()`` with one
            ``[x1, y1, x2, y2]`` row per detection (assumed pixel
            coordinates -- confirm).
        path: File the figure is written to. Defaults to ``"img.png"``,
            which the app reads back for display.
        color: Box outline colour passed to ``cv2.rectangle``.
        thickness: Box outline thickness passed to ``cv2.rectangle``.

    Returns:
        The path the figure was saved to.
    """
    boxes = boxes.cpu().detach().numpy().astype(np.int32)

    # Draw on a private copy: without the copy cv2.rectangle raises an error
    # (presumably because the converted array is read-only or not
    # contiguous). One copy is enough -- every box goes onto the same canvas.
    canvas = np.array(image).copy()
    for x1, y1, x2, y2 in boxes:
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)),
                      color, thickness=thickness)

    # Render and save once, after all boxes are drawn, so the output file is
    # written (with the bare image) even when there are no detections.
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    try:
        ax.axis('off')
        ax.imshow(canvas)
        fig.savefig(path, bbox_inches='tight')
    finally:
        # This runs on every upload; close the figure so pyplot does not keep
        # every figure ever created alive for the life of the process.
        plt.close(fig)
    return path
# Main page: upload an image, run the detector on it and show the boxes.
image_file = st.file_uploader("Upload Images", type=["png", "jpg", "jpeg"])
if image_file is not None:
    print(image_file)
    print('Getting predictions')
    # getvalue() returns the whole upload regardless of the buffer's current
    # position; read() would return b"" if the stream was already consumed.
    data = image_file.getvalue()
    pred_dict = predict(model, data)
    print('Fixing the preds')
    boxes, image = prepare_prediction(pred_dict)
    print('Plotting')
    plot_img_no_mask(image, boxes)
    # st.image accepts a local file path directly. This avoids PIL.Image.open,
    # which keeps the file handle open lazily and relies on PIL.Image having
    # been imported somewhere (a bare `import PIL` does not load it).
    st.image('img.png', width=750)