AeroPath / demo /src /gui.py
andreped's picture
Linted code
8291c90
raw
history blame
No virus
9.34 kB
import os
import gradio as gr
from .convert import nifti_to_obj
from .css_style import css
from .inference import run_model
from .logger import flush_logs
from .logger import read_logs
from .logger import setup_logger
from .utils import load_ct_to_numpy
from .utils import load_pred_volume_to_numpy
# Set up logging: the logger is created once, at import time, and shared by
# every WebUI method in this module through the module-level LOGGER name.
LOGGER = setup_logger()
class WebUI:
    """Gradio front end for running a CT segmentation task and viewing it.

    The UI offers a file upload, a task selector (airways / lungs), a log
    panel, a 2D slice viewer driven by a slider, and a 3D mesh viewer.
    Per-session results (loaded slices, output file name) are stored on the
    instance, so a single ``WebUI`` object holds the state of the last run.
    """

    # Markdown note shown next to the example file. Kept as an explicit
    # string so its content does not depend on source indentation.
    _INFERENCE_NOTE = (
        "\n"
        "**NOTE:** Inference might take several minutes "
        "(Airways: ~8 minutes), see logs to the left. \\\n"
        "The segmentation will be available in the 2D and 3D viewers "
        "below when finished.\n"
    )

    def __init__(
        self,
        model_name: str = None,
        cwd: str = "/home/user/app/",
        share: int = 1,
    ):
        """Initialise UI state and the widgets that are rendered later.

        Args:
            model_name: Stored on the instance; not otherwise used by the
                methods of this class.
            cwd: Base directory used to locate ``resources/models/`` and the
                bundled example volume.
            share: Forwarded unchanged to ``launch(share=...)``.
        """
        # global states
        self.images = []
        self.pred_images = []

        # Name parts of the latest prediction file. They stay None until
        # process() has run, so download_prediction() can detect that no
        # result exists instead of failing on a missing attribute.
        self.filename = None
        self.extension = None

        # @TODO: This should be dynamically set based on chosen volume size
        self.nb_slider_items = 820

        self.model_name = model_name
        self.cwd = cwd
        self.share = share

        self.class_name = "airways"  # default
        # UI task key -> task identifier passed to run_model()
        self.class_names = {
            "airways": "CT_Airways",
            "lungs": "CT_Lungs",
        }
        # UI task key -> result name passed to run_model()
        self.result_names = {
            "airways": "Airways",
            "lungs": "Lungs",
        }

        # define widgets not to be rendered immediately, but later on
        self.slider = gr.Slider(
            minimum=1,
            maximum=self.nb_slider_items,
            value=1,
            step=1,
            label="Which 2D slice to show",
        )
        self.volume_renderer = gr.Model3D(
            clear_color=[0.0, 0.0, 0.0, 0.0],
            label="3D Model",
            show_label=True,
            visible=True,
            elem_id="model-3d",
            camera_position=[90, 180, 768],
            height=512,
        )

    @staticmethod
    def _resolve_path(file):
        """Return the path for *file*.

        Accepts either an object exposing a ``name`` attribute (what the
        original code required) or a plain path string. ``None`` is passed
        through unchanged.
        """
        return getattr(file, "name", file)

    def set_class_name(self, value):
        """Select the segmentation task (a key of ``self.class_names``)."""
        LOGGER.info(f"Changed task to: {value}")
        self.class_name = value

    def combine_ct_and_seg(self, img, pred):
        """Pair an image with one labelled annotation.

        Returns:
            ``(img, [(pred, self.class_name)])``, the tuple handed to
            ``gr.AnnotatedImage`` by ``get_img_pred_pair``.
        """
        return (img, [(pred, self.class_name)])

    def upload_file(self, file):
        """Log the uploaded file and return its path."""
        out = self._resolve_path(file)
        LOGGER.info(f"File uploaded: {out}")
        return out

    def process(self, mesh_file_name):
        """Run the selected task on the uploaded volume and load the results.

        Args:
            mesh_file_name: The uploaded file (object with ``.name``) or a
                path string.

        Returns:
            ``"./prediction.obj"`` on success (presumably the file written by
            ``nifti_to_obj`` -- confirm against that helper), or ``None``
            when no input file was provided.
        """
        input_path = self._resolve_path(mesh_file_name)
        if not input_path:
            # "Run analysis" was clicked before anything was uploaded.
            LOGGER.error(
                "No input volume was provided. Upload a file before running the analysis."
            )
            return None

        # Split "<stem>.<everything after the first dot>" so multi-part
        # extensions such as ".nii.gz" are kept intact.
        stem, _, extension = os.path.basename(input_path).partition(".")
        task = self.class_names[self.class_name]

        self.extension = extension
        self.filename = stem + "-" + task
        prediction_path = self.filename + "." + self.extension

        run_model(
            input_path,
            model_path=os.path.join(self.cwd, "resources/models/"),
            task=task,
            name=self.result_names[self.class_name],
            output_filename=prediction_path,
        )

        LOGGER.info("Converting prediction NIfTI to OBJ...")
        nifti_to_obj(path=prediction_path)

        LOGGER.info("Loading CT to numpy...")
        self.images = load_ct_to_numpy(input_path)

        LOGGER.info("Loading prediction volume to numpy..")
        self.pred_images = load_pred_volume_to_numpy(prediction_path)

        return "./prediction.obj"

    def download_prediction(self):
        """Return the path of the latest prediction file.

        Returns:
            ``"<filename>.<extension>"`` once ``process`` has run, otherwise
            ``None`` (after logging an error) so the caller never receives a
            path that was never produced.
        """
        if not self.filename or self.extension is None:
            LOGGER.error(
                "The prediction is not available or ready to download. Wait until the result is available in the 3D viewer."
            )
            return None
        return self.filename + "." + self.extension

    def get_img_pred_pair(self, k):
        """Build the 2D viewer content for slider position *k*.

        The index is clamped to the loaded volume because the slider range is
        fixed (``nb_slider_items``) and may exceed the number of slices. When
        nothing has been loaded yet an empty viewer is returned.
        """
        viewer_kwargs = dict(
            visible=True,
            elem_id="model-2d",
            color_map={self.class_name: "#ffae00"},
            height=512,
            width=512,
        )

        n_slices = min(len(self.images), len(self.pred_images))
        if n_slices == 0:
            LOGGER.warning(
                "No volume has been processed yet; there is no slice to show."
            )
            return gr.AnnotatedImage(**viewer_kwargs)

        # NOTE(review): the slider starts at 1 while the slices are indexed
        # from 0, so slice 0 is never reachable -- confirm whether intended.
        idx = min(max(int(k), 0), n_slices - 1)
        return gr.AnnotatedImage(
            self.combine_ct_and_seg(self.images[idx], self.pred_images[idx]),
            **viewer_kwargs,
        )

    def toggle_sidebar(self, state):
        """Flip the sidebar visibility flag.

        Returns:
            A ``gr.update`` applying the new visibility, and the new flag.
        """
        state = not state
        return gr.update(visible=state), state

    def run(self):
        """Build the Blocks layout, wire up the events and launch the app."""
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                # Left sidebar: live log output, refreshed every second.
                with gr.Column(visible=True, scale=0.2) as sidebar_left:
                    logs = gr.Textbox(
                        placeholder="\n" * 16,
                        label="Logs",
                        info="Verbose from inference will be displayed below.",
                        lines=36,
                        max_lines=36,
                        autoscroll=True,
                        elem_id="logs",
                        show_copy_button=True,
                        container=True,
                    )
                    demo.load(read_logs, None, logs, every=1)

                with gr.Column():
                    # Top row: sidebar/log controls, upload, task, actions.
                    with gr.Row():
                        with gr.Column(scale=1, min_width=150):
                            sidebar_state = gr.State(True)

                            btn_toggle_sidebar = gr.Button(
                                "Toggle Sidebar",
                                elem_id="toggle-button",
                            )
                            btn_toggle_sidebar.click(
                                self.toggle_sidebar,
                                [sidebar_state],
                                [sidebar_left, sidebar_state],
                            )

                            btn_clear_logs = gr.Button(
                                "Clear logs", elem_id="logs-button"
                            )
                            btn_clear_logs.click(flush_logs, [], [])

                        file_output = gr.File(
                            file_count="single",
                            elem_id="upload",
                            scale=3,
                        )
                        file_output.upload(
                            self.upload_file, file_output, file_output
                        )

                        model_selector = gr.Dropdown(
                            list(self.class_names.keys()),
                            label="Task",
                            info="Which structure to segment.",
                            multiselect=False,
                            scale=1,
                        )
                        model_selector.input(
                            fn=self.set_class_name,
                            inputs=model_selector,
                            outputs=None,
                        )

                        with gr.Column(scale=1, min_width=150):
                            run_btn = gr.Button(
                                "Run analysis",
                                variant="primary",
                                elem_id="run-button",
                            )
                            run_btn.click(
                                fn=self.process,
                                inputs=file_output,
                                outputs=self.volume_renderer,
                            )

                            download_btn = gr.DownloadButton(
                                "Download prediction",
                                visible=True,
                                variant="secondary",
                                elem_id="download",
                            )
                            download_btn.click(
                                fn=self.download_prediction,
                                inputs=None,
                                outputs=download_btn,
                            )

                    # Example volume and usage note.
                    with gr.Row():
                        gr.Examples(
                            examples=[
                                os.path.join(self.cwd, "test_thorax_CT.nii.gz"),
                            ],
                            inputs=file_output,
                            outputs=file_output,
                            fn=self.upload_file,
                            cache_examples=True,
                        )

                        gr.Markdown(self._INFERENCE_NOTE)

                    # Viewers: 2D slice viewer with slider, then 3D mesh.
                    with gr.Row():
                        with gr.Group():
                            with gr.Column():
                                # create dummy image to be replaced by loaded images
                                t = gr.AnnotatedImage(
                                    visible=True,
                                    elem_id="model-2d",
                                    color_map={self.class_name: "#ffae00"},
                                    # height=512,
                                    # width=512,
                                )

                                self.slider.input(
                                    self.get_img_pred_pair,
                                    self.slider,
                                    t,
                                )

                            self.slider.render()

                        with gr.Group():  # gr.Box():
                            self.volume_renderer.render()

        # sharing app publicly -> share=True:
        # https://gradio.app/sharing-your-app/
        # inference times > 60 seconds -> need queue():
        # https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
        demo.queue().launch(
            server_name="0.0.0.0", server_port=7860, share=self.share
        )