import glob
import warnings

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from model import DocGeoNet
from seg import U2NETP

warnings.filterwarnings('ignore')


class Net(nn.Module):
    """Document rectification pipeline: U2NETP foreground segmentation followed by DocGeoNet."""

    def __init__(self):
        super(Net, self).__init__()
        self.msk = U2NETP(3, 1)
        self.DocTr = DocGeoNet()

    def forward(self, x):
        # U2NETP returns the fused mask plus six side outputs; keep only the fused mask.
        msk, *_ = self.msk(x)
        msk = (msk > 0.5).float()
        x = msk * x  # suppress the background before rectification

        _, _, bm = self.DocTr(x)
        # Normalize the backward map from pixel coordinates to [-1, 1] for grid_sample.
        bm = (2 * (bm / 255.) - 1) * 0.99

        return bm


def reload_seg_model(model, path=""):
    if not bool(path):
        return model
    model_dict = model.state_dict()
    pretrained_dict = torch.load(path, map_location='cpu')
    # Strip the prefix from the checkpoint keys and keep only parameters that match the model.
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model


def reload_rec_model(model, path=""):
    if not bool(path):
        return model
    model_dict = model.state_dict()
    pretrained_dict = torch.load(path, map_location='cpu')
    # Strip the prefix from the checkpoint keys and keep only parameters that match the model.
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model


def rec(input_image):
    seg_model_path = './model_pretrained/preprocess.pth'
    rec_model_path = './model_pretrained/DocGeoNet.pth'

    net = Net()
    reload_rec_model(net.DocTr, rec_model_path)
    reload_seg_model(net.msk, seg_model_path)
    net.eval()

    im_ori = np.array(input_image)[:, :, :3] / 255.  # scale image from 0-255 to 0-1
    h, w, _ = im_ori.shape
    im = cv2.resize(im_ori, (256, 256))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = net(im)
        bm = bm.cpu()
        bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
        bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
        bm0 = cv2.blur(bm0, (3, 3))
        bm1 = cv2.blur(bm1, (3, 3))
        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2

        # Warp the original-resolution image with the predicted backward map.
        out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(),
                            lbl, align_corners=True)
        # The sampled tensor is already in RGB channel order; just rescale to uint8.
        img_rec = (out[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8)

    return Image.fromarray(img_rec)


demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorted/*.[pP][nN][gG]')

# Gradio interface (gr.Image replaces the gr.inputs/gr.outputs API deprecated after Gradio 2.x)
input_image = gr.Image()
output_image = gr.Image(type='pil')
iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image,
                     title="DocGeoNet", examples=demo_img_files)

# iface.launch(server_port=8821, server_name="0.0.0.0")
iface.launch()