#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Gradio demo that runs a trained model on a single grayscale image.

Model and dataloader options are restored from ``parameters.json`` and the
weights from ``weight_epoch-200_best.pt``; inference runs on the CPU.
"""

import torch
import gradio as gr

from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin


test_weight = './weight_epoch-200_best.pt'
parameter = './parameters.json'


class ImageHandler(ImageMixin):
    """Apply the dataloader's transforms to a single input image."""

    def __init__(self, params):
        # params: the dataloader-group options; presumably consumed by
        # ImageMixin._make_transforms via self.params — TODO confirm.
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        """Transform *image* and wrap it as a one-sample batch.

        Returns:
            dict: ``{'image': tensor}`` where ``unsqueeze(0)`` has added a
            leading batch dimension of size 1.
        """
        image = self.transform(image)
        return {'image': image.unsqueeze(0)}


def load_parameter(parameter):
    """Build model and dataloader option sets from a saved parameter file.

    Every key/value returned by ``_retrieve_parameter`` is copied onto a
    fresh ``ParamSet``, then a few options are overridden for inference.

    Args:
        parameter: Path to the saved parameter file (e.g. ``parameters.json``).

    Returns:
        tuple: ``(args_model, args_dataloader)`` as produced by
        ``_dispatch_by_group`` for the ``'model'`` and ``'dataloader'`` groups.
    """
    _args = ParamSet()
    params = _retrieve_parameter(parameter)
    for _param, _arg in params.items():
        setattr(_args, _param, _arg)

    # Inference-time overrides: no augmentation or sampling, do not fetch
    # pretrained weights (the trained checkpoint is loaded separately),
    # and run on the CPU.
    _args.augmentation = 'no'
    _args.sampler = 'no'
    _args.pretrained = False
    _args.mlp = None
    _args.net = _args.model
    _args.device = torch.device('cpu')

    args_model = _dispatch_by_group(_args, 'model')
    args_dataloader = _dispatch_by_group(_args, 'dataloader')
    return args_model, args_dataloader


args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)
# Inference only: switch to eval mode once rather than on every request.
model.eval()
# The transforms depend only on the fixed dataloader options, so build once.
image_handler = ImageHandler(args_dataloader)


def main(image):
    """Run the model on one uploaded image and return its prediction.

    Args:
        image: PIL image supplied by Gradio (``type='pil'``, mode ``'L'``),
            or ``None`` when the user submits without uploading.

    Returns:
        str: The first model output formatted with two decimals.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this, None flows into the transforms and fails obscurely.
        raise gr.Error('Please upload an image.')

    batch = image_handler.set_image(image)
    with torch.no_grad():
        outputs = model(batch)

    # Only the first output entry is reported; presumably the model has a
    # single label head — TODO confirm for multi-label checkpoints.
    label_name = next(iter(outputs))
    # .item() requires a single-element tensor, as the original
    # .numpy().item() did, and also works for non-CPU tensors.
    result = outputs[label_name].item()
    return f"{result:.2f}"


# Gradio
iface = gr.Interface(fn=main, inputs=[gr.Image(type='pil', image_mode='L')], outputs='text')

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not
    # start (and block on) the web server.
    iface.launch()