Kaori1707's picture
add segment model
76daa54
raw
history blame
No virus
4.38 kB
import gradio as gr
import numpy as np
import torch
from torchvision.transforms import Compose
import cv2
from dpt.models import DPTDepthModel, DPTSegmentationModel
from dpt.transforms import Resize, NormalizeImage, PrepareForNet
import os
# Inference device: the CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

# Checkpoint files (paths relative to the working directory) for the two
# networks built below; the keys are looked up when constructing each model.
default_models = {
    "dpt_hybrid": "weights/dpt_hybrid-midas-501f0c75.pt",
    "segment_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt",
}

# cuDNN switches: benchmark=True lets cuDNN time candidate convolution
# algorithms per input shape and reuse the fastest one.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Depth-estimation network, put into inference mode on `device`.
depth_model = DPTDepthModel(
    path=default_models["dpt_hybrid"],
    backbone="vitb_rn50_384",
    non_negative=True,
    enable_attention_hooks=False,
)
depth_model.to(device).eval()

# Segmentation network. The first positional argument (150) is presumably the
# number of output classes; the checkpoint name mentions ADE20K, which has 150
# classes -- TODO confirm against dpt.models.DPTSegmentationModel.
seg_model = DPTSegmentationModel(
    150,
    path=default_models["segment_hybrid"],
    backbone="vitb_rn50_384",
)
seg_model.to(device).eval()
# Preprocessing pipeline shared by DPT() and Segment(): resize, normalize with
# mean=std=0.5 per channel, then convert to the network's input layout via
# PrepareForNet. NOTE(review): like upstream DPT (which divides by 255 when
# reading images), this presumably expects float RGB in [0, 1] -- confirm.
net_w, net_h = 384, 384
normalization = NormalizeImage(mean=[0.5] * 3, std=[0.5] * 3)
transform = Compose([
    Resize(
        net_w,
        net_h,
        resize_target=None,
        keep_aspect_ratio=True,
        ensure_multiple_of=32,
        resize_method="minimal",
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    normalization,
    PrepareForNet(),
])
def write_depth(depth, bits=1, absolute_depth=False):
    """Convert a depth map into an unsigned-integer image array.

    Despite its name, nothing is written to disk; the converted array is
    returned.

    Args:
        depth (np.ndarray): depth values of any shape.
        bits (int): bytes per output pixel: 1 -> uint8, 2 -> uint16.
        absolute_depth (bool): if True, use the values as-is; otherwise
            min-max normalise them to span the full output range.

    Returns:
        np.ndarray: array with the same shape as ``depth`` and dtype uint8
        (``bits=1``) or uint16 (``bits=2``). Values outside the output range
        are clipped.

    Raises:
        ValueError: if ``bits`` is not 1 or 2.
    """
    dtypes = {1: np.uint8, 2: np.uint16}
    if bits not in dtypes:
        # Previously any other value fell through and silently returned None.
        raise ValueError(f"bits must be 1 or 2, got {bits!r}")
    max_val = (2 ** (8 * bits)) - 1

    if absolute_depth:
        out = depth
    else:
        depth_min = depth.min()
        depth_max = depth.max()
        if depth_max - depth_min > np.finfo("float").eps:
            out = max_val * (depth - depth_min) / (depth_max - depth_min)
        else:
            # Constant map: there is no range to stretch, so emit all zeros.
            out = np.zeros(depth.shape, dtype=depth.dtype)

    # Clip before casting: converting out-of-range floats to unsigned ints is
    # undefined in NumPy (typically wraps around). In normalised mode the
    # values are already within [0, max_val], so this is a no-op there.
    return np.clip(out, 0, max_val).astype(dtypes[bits])
def DPT(image):
    """Estimate a relative depth map for an RGB image with ``depth_model``.

    Args:
        image: RGB image as a NumPy array of shape (H, W, 3), as delivered by
            the ``gr.Image`` input, or None. Integer images (e.g. uint8 in
            0..255) are rescaled to [0, 1] before preprocessing; float images
            are assumed to already be in [0, 1].

    Returns:
        An (H, W) uint16 array produced by ``write_depth(..., bits=2)`` at the
        input resolution, or None when no image was provided.
    """
    if image is None:
        # gr.Image may pass None when the button is clicked with no image.
        return None
    if np.issubdtype(image.dtype, np.integer):
        # The preprocessing normalises with mean=std=0.5, i.e. it expects
        # [0, 1] input (upstream DPT divides by 255 when reading images).
        image = image.astype(np.float32) / 255.0

    img_input = transform({"image": image})["image"]
    with torch.no_grad():
        sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
        prediction = depth_model(sample)
        # Upsample back to the original (H, W) of the input image.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
        # Index rather than squeeze() so 1-pixel-tall/wide images stay 2-D.
        depth = prediction[0, 0].cpu().numpy()
    return write_depth(depth, bits=2)
def Segment(image):
    """Predict a per-pixel class-index map for an RGB image with ``seg_model``.

    Args:
        image: RGB image as a NumPy array of shape (H, W, 3), as delivered by
            the ``gr.Image`` input, or None. Integer images (e.g. uint8 in
            0..255) are rescaled to [0, 1] before preprocessing; float images
            are assumed to already be in [0, 1].

    Returns:
        An (H, W) integer array of class indices (``argmax + 1``, so values
        start at 1) at the input resolution, or None when no image was given.
    """
    if image is None:
        # gr.Image may pass None when the button is clicked with no image.
        return None
    if np.issubdtype(image.dtype, np.integer):
        # The preprocessing normalises with mean=std=0.5, i.e. it expects
        # [0, 1] input (upstream DPT divides by 255 when reading images).
        image = image.astype(np.float32) / 255.0

    img_input = transform({"image": image})["image"]
    with torch.no_grad():
        sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
        logits = seg_model(sample)
        # Upsample the per-class scores to the input resolution, then pick the
        # best class per pixel.
        logits = torch.nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bicubic", align_corners=False
        )
        labels = torch.argmax(logits, dim=1) + 1
    # NOTE(review): this returns raw class indices, not a colorized image;
    # gr.Image may render small integer labels as a near-black picture --
    # consider mapping them through a color palette before display.
    # Index rather than squeeze() so 1-pixel-tall/wide images stay 2-D.
    return labels[0].cpu().numpy()
title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Depth Estimation"


def list_examples(directory="examples"):
    """Return ``[[path], ...]`` rows for ``gr.Examples``, one per visible file.

    Rows are sorted by file name because ``os.listdir`` returns entries in
    arbitrary order. Hidden files (e.g. ``.DS_Store``) and subdirectories are
    skipped. A missing directory yields an empty list instead of crashing the
    app at import time.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


example_list = list_examples("examples")
# Two-column UI: model outputs on the left; input image, action buttons and
# example gallery on the right. Event handlers are registered inside the
# Blocks context, as Gradio requires.
with gr.Blocks(title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            im_2 = gr.Image(label="Depth Image")
            im_3 = gr.Image(label="Segment Image")
        with gr.Column():
            im = gr.Image(label="Input Image")

            btn1 = gr.Button(value="Depth Estimator")
            btn1.click(fn=DPT, inputs=[im], outputs=[im_2])

            btn2 = gr.Button(value="Segment")
            btn2.click(fn=Segment, inputs=[im], outputs=[im_3])

            # Clicking an example only fills the input image (no fn given).
            gr.Examples(examples=example_list, inputs=[im], outputs=[im_2])

if __name__ == "__main__":
    demo.launch()