File size: 3,230 Bytes
1f53a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
809872d
 
76261ed
dd2f551
a94c67d
 
76261ed
809872d
b9d9d5a
809872d
 
 
 
 
 
 
 
1f53a4c
25c58fc
 
 
dd2f551
809872d
c567f94
25c58fc
 
 
 
 
 
 
 
f239e66
 
25c58fc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import gradio as gr
from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin


# Paths resolved relative to the working directory the script is started from.
# `parameter` is handed to load_parameter(); `test_weight` is handed to
# model.load_weight() once the model has been created.
parameter = "./parameters.json"
test_weight = "./weight_epoch-200_best.pt"

class ImageHandler(ImageMixin):
    """Prepare a single image as the dict input the model is called with.

    The transform pipeline comes from ``_make_transforms``, which is not
    defined in this class (presumably provided by ``ImageMixin`` -- confirm
    in ``lib.dataloader``).
    """

    def __init__(self, params):
        """Keep *params* on the instance and build the transform once.

        ``self.params`` is assigned before ``_make_transforms()`` is called;
        presumably the mixin reads it while building the pipeline.
        """
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Transform *image* and wrap it as ``{'image': tensor}``.

        ``unsqueeze(0)`` inserts a new leading dimension of size 1
        (presumably the batch axis expected by the model).
        """
        transformed = self.transform(image)
        return {'image': transformed.unsqueeze(0)}

def load_parameter(parameter):
    """Build the model and dataloader argument groups for inference.

    The values retrieved from *parameter* are copied onto a fresh
    ``ParamSet``; a handful of attributes are then overridden for
    inference before the set is split by group.

    Args:
        parameter: Passed unchanged to ``_retrieve_parameter``; the module
            calls this with the path to a JSON parameter file.

    Returns:
        tuple: ``(args_model, args_dataloader)`` as produced by
        ``_dispatch_by_group`` for the ``'model'`` and ``'dataloader'``
        groups.
    """
    settings = ParamSet()
    for name, value in _retrieve_parameter(parameter).items():
        setattr(settings, name, value)

    # Fixed overrides applied on top of whatever the parameter file held.
    fixed_overrides = (
        ('augmentation', 'no'),
        ('sampler', 'no'),
        ('pretrained', False),
        ('mlp', None),
    )
    for name, value in fixed_overrides:
        setattr(settings, name, value)

    # `net` mirrors the `model` entry loaded from the parameter file.
    settings.net = settings.model
    # Inference is pinned to the CPU.
    settings.device = torch.device('cpu')

    return (
        _dispatch_by_group(settings, 'model'),
        _dispatch_by_group(settings, 'dataloader'),
    )

# Build the model once at import time; main() reuses this module-level
# `model` and `args_dataloader` on every call.
args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
# Presumably restores the trained weights from the checkpoint path --
# confirm against the `load_weight` implementation in `lib`.
model.load_weight(test_weight)

def main(image):
    """Run the model on one uploaded image and return the formatted result.

    Args:
        image: The value delivered by the Gradio image input, or ``None``
            when the button is pressed without an image.

    Returns:
        str: The scalar model output formatted with two decimals; it is
        shown in the "Estimated age" label.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # An empty image component yields None; fail with a readable
        # message instead of an opaque error from inside the transform.
        raise gr.Error("Please provide an image before running inference.")

    model.eval()
    batch = ImageHandler(args_dataloader).set_image(image)

    with torch.no_grad():
        outputs = model(batch)

    # The output mapping is keyed by label name; only the first entry is used.
    label_name = next(iter(outputs))
    value = outputs[label_name].detach().numpy().item()
    return f"{value:.2f}"

# Static HTML panel rendered via gr.HTML in the UI below: preprocessing
# notes for the input image and a link to the related publication.
# This is user-facing runtime text; keep it byte-for-byte.
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
    <h3>Image preprocess</h3>
    <p>Only grayscale 320×320 resolution works appropriately.</p>
    <p>The longest side of the Xp should be downscaled to 320 pixels while maintaining the aspect ratio, 
        and the width along the shorter side should be padded black to 320 pixels.
    </p>
    <h3>Publication Details</h3>
    <p>See details in our publication, titled 
        "Chest radiography as a biomarker of ageing: artificial intelligence-based, 
        multi-institutional model development and validation in Japan"
    </p>
    <p><strong>Link:</strong> <a href="https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext" target="_blank">
        https://www.thelancet.com/journals/lanhl/article/PIIS2666-7568(23)00133-2/fulltext
    </a></p>
</div>
"""
# Gradio UI: components are created inside the Blocks context in the order
# they should appear on the page, then the app is launched.
with gr.Blocks(
    title="Aging Biomarker from CXR",
    css=".gradio-container {background:mintcream;}",
) as demo:
    # Page heading followed by the static usage/publication panel.
    gr.HTML("""<div style="text-align:center"><h2>Aging Biomarker from CXR</h2></div>""")
    gr.HTML(html_content)

    # Image input and result label.
    with gr.Row():
        input_image = gr.Image(type="pil", image_mode="L", shape=(320, 320))
        output_label = gr.Label(label="Estimated age")

    # Clicking the button feeds the image to main() and shows its return value.
    send_btn = gr.Button("Inference")
    send_btn.click(fn=main, inputs=input_image, outputs=output_label)

    # Two example images that populate the image input.
    with gr.Row():
        gr.Examples(
            ['./samples/66_female_xp.png'],
            label='Sample CXR 1: 66 years old female',
            inputs=input_image,
        )
        gr.Examples(
            ['./samples/28_male_xp.png'],
            label='Sample CXR 2: 28 years old male',
            inputs=input_image,
        )

demo.launch(debug=True)