#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Gradio demo that estimates age from a chest X-ray (CXR) image.

Model and preprocessing options are read from a parameter file, the model
weights are loaded from a checkpoint, and inference runs on the CPU.
"""

import torch
import gradio as gr

from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin


# Model checkpoint and the parameter file describing model/dataloader options.
test_weight = './weight_epoch-200_best.pt'
parameter = './parameters.json'


class ImageHandler(ImageMixin):
    """Preprocess a single input image into the dict the model is called with.

    The transforms come from ``ImageMixin._make_transforms()``, which
    presumably reads its settings from ``self.params`` (TODO confirm).
    """

    def __init__(self, params):
        # Set before building transforms, which presumably depend on it.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Transform *image* and wrap it as ``{'image': tensor}``.

        A leading dimension of size 1 is added so the single image is
        passed to the model as a batch of one.
        """
        image = self.transform(image)
        return {'image': image.unsqueeze(0)}


def load_parameter(parameter):
    """Load saved options from *parameter* and adapt them for CPU inference.

    Args:
        parameter: Path to the parameter file read by ``_retrieve_parameter``.

    Returns:
        A ``(args_model, args_dataloader)`` tuple holding the options that
        ``_dispatch_by_group`` assigns to the ``'model'`` and ``'dataloader'``
        groups respectively.
    """
    _args = ParamSet()
    params = _retrieve_parameter(parameter)
    for _param, _arg in params.items():
        setattr(_args, _param, _arg)

    # Inference-time overrides: no augmentation/sampler, and presumably no
    # pretrained initialisation since the checkpoint is loaded afterwards.
    _args.augmentation = 'no'
    _args.sampler = 'no'
    _args.pretrained = False
    _args.mlp = None
    # NOTE(review): create_model presumably reads the architecture from `net`.
    _args.net = _args.model
    _args.device = torch.device('cpu')

    args_model = _dispatch_by_group(_args, 'model')
    args_dataloader = _dispatch_by_group(_args, 'dataloader')
    return args_model, args_dataloader


args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)
# The model is only used for inference, so switch to eval mode once here
# instead of on every request.
model.eval()
# Build the preprocessing pipeline once and share it across requests.
# Augmentation is disabled in load_parameter, so the transforms are
# presumably stateless/deterministic (TODO confirm in ImageMixin).
image_handler = ImageHandler(args_dataloader)


def main(image):
    """Estimate age from *image* and return it formatted to two decimals.

    Args:
        image: Value delivered by the ``gr.Image`` input (``type="pil"``,
            ``image_mode="L"``), or ``None`` when nothing was uploaded.

    Returns:
        The model's first output formatted with two decimals, e.g. ``'66.12'``.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        # Without this, the transforms fail on None with an opaque traceback.
        raise gr.Error("Please upload a chest X-ray image first.")

    inputs = image_handler.set_image(image)
    with torch.no_grad():
        outputs = model(inputs)

    # The model returns a dict of outputs; only the first entry is reported.
    label_name = next(iter(outputs))
    result = outputs[label_name].item()
    return f"{result:.2f}"


# Gradio UI
with gr.Blocks(title="Aging Biomarker from CXR",
               css=".gradio-container {background:mintcream;}") as demo:
    gr.HTML("""
Aging Biomarker from CXR
""")
    with gr.Row():
        input_image = gr.Image(type="pil", image_mode="L", shape=(320, 320))
        output_label = gr.Label(label="Estimated age")
    send_btn = gr.Button("Inference")
    send_btn.click(fn=main, inputs=input_image, outputs=output_label)

    with gr.Row():
        gr.Examples(['./samples/66_female_xp.png'],
                    label='Sample CXR 1: 66 years old female',
                    inputs=input_image)
        gr.Examples(['./samples/28_male_xp.png'],
                    label='Sample CXR 2: 28 years old male',
                    inputs=input_image)


if __name__ == "__main__":
    demo.launch(debug=True)