Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import torch | |
import gradio as gr | |
from lib import create_model | |
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group | |
from lib.dataloader import ImageMixin | |
# Trained checkpoint, loaded into the model right after it is built.
test_weight = "./weight_epoch-200_best.pt"
# Saved option file used to rebuild the model and its preprocessing.
parameter = "./parameters.json"
class ImageHandler(ImageMixin):
    """Turn one uploaded image into the input dict expected by the model.

    The preprocessing pipeline itself comes from ``ImageMixin._make_transforms``
    (``lib.dataloader``, not shown here).
    """

    def __init__(self, params):
        # Store the dataloader options before building the transforms.
        # NOTE(review): presumably ``_make_transforms`` reads ``self.params``;
        # confirm in lib.dataloader.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Return ``{'image': x}``, where ``x`` is the transformed *image* with a
        new leading axis of size 1 (presumably the batch axis).
        """
        transformed = self.transform(image)
        return {'image': transformed.unsqueeze(0)}
def load_parameter(parameter):
    """Rebuild the saved options and split them into model / dataloader groups.

    Args:
        parameter: Path to the saved option file, passed to
            ``_retrieve_parameter`` (its result is iterated with ``.items()``,
            so it is expected to be a mapping of option name -> value).

    Returns:
        tuple: ``(args_model, args_dataloader)`` as produced by
        ``_dispatch_by_group`` for the ``'model'`` and ``'dataloader'`` groups.
    """
    args = ParamSet()
    for option_name, option_value in _retrieve_parameter(parameter).items():
        setattr(args, option_name, option_value)

    # Settings forced for this CPU inference app, overriding whatever was saved.
    # NOTE(review): 'pretrained=False' presumably avoids fetching pretrained
    # weights since the trained checkpoint is loaded separately; the meaning
    # of 'mlp=None' cannot be determined from this file.
    forced_options = (
        ('augmentation', 'no'),
        ('sampler', 'no'),
        ('pretrained', False),
        ('mlp', None),
    )
    for option_name, option_value in forced_options:
        setattr(args, option_name, option_value)
    # Mirror the saved 'model' option under the name 'net'.
    args.net = args.model
    args.device = torch.device('cpu')

    return (
        _dispatch_by_group(args, 'model'),
        _dispatch_by_group(args, 'dataloader'),
    )
# Build the network from the saved options and load the trained checkpoint.
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)
# This app only runs inference, so switch to eval mode once at start-up
# instead of on every request.
model.eval()

# The preprocessing depends only on the fixed dataloader options, so build it
# once rather than per request.
# NOTE(review): assumes the transforms from ImageMixin._make_transforms are
# stateless (augmentation is forced to 'no' in load_parameter).
image_handler = ImageHandler(args_dataloader)


def main(image):
    """Estimate age from a chest X-ray and return it formatted to two decimals.

    Args:
        image: Value from the ``gr.Image`` input (a PIL image, because the
            component uses ``type="pil"``), or ``None`` when nothing was
            uploaded.

    Returns:
        str: The first model output formatted as ``f"{value:.2f}"``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when "Inference" is clicked with an empty input;
        # show a readable message instead of crashing inside the transform.
        raise gr.Error("Please upload a chest X-ray image first.")

    inputs = image_handler.set_image(image)
    with torch.no_grad():
        outputs = model(inputs)
    # ``outputs`` is a mapping; only its first entry is displayed.
    # NOTE(review): presumably a single-element tensor (``.item()`` requires it).
    prediction = outputs[next(iter(outputs))]
    # Tensor.item() returns a Python number directly. It needs no detach under
    # no_grad and, unlike .numpy(), also works for dtypes such as bfloat16.
    return f"{prediction.item():.2f}"
# Gradio
with gr.Blocks(
    title="Aging Biomarker from CXR",
    css=".gradio-container {background:mintcream;}",
) as demo:
    # Page heading.
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-weight:bold; text-align:center; color:royalblue;">Aging Biomarker from CXR</div>""")

    # Grayscale ("L") image input next to the predicted-age display.
    with gr.Row():
        input_image = gr.Image(type="pil", image_mode="L", shape=(320, 320))
        output_label = gr.Label(label="Estimated age")

    # Run `main` on the current image and show its result in the label.
    send_btn = gr.Button("Inference")
    send_btn.click(fn=main, inputs=input_image, outputs=output_label)

    # Sample images that can be loaded into the input with one click.
    with gr.Row():
        for example_path, example_label in (
            ('./samples/66_female_xp.png', 'CXR: 66 years old female'),
            ('./samples/28_male_xp.png', 'CXR: 28 years old male'),
        ):
            gr.Examples([example_path], label=example_label, inputs=input_image)

demo.launch(debug=True)