# Source: waterface / app.py (Hugging Face Space by johnowhitaker)
# Last commit 425dec3: "Updated article with more info and link to sketchify space" (file size 8.94 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from imstack.core import ImStack
from tqdm.notebook import tqdm
import kornia.augmentation as K
from CLIP import clip
from torchvision import transforms
from PIL import Image
import numpy as np
import math
from matplotlib import pyplot as plt
from fastprogress.fastprogress import master_bar, progress_bar
from IPython.display import HTML
from base64 import b64encode
import warnings
# PyTorch repeats the same behaviour-change warnings over and over, so silence them all.
warnings.filterwarnings('ignore')

# Run on the first GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
def sinc(x):
    """Return the normalized sinc, sin(pi*x) / (pi*x), elementwise for tensor `x`.

    The removable singularity at x == 0 is filled with 1 (a scalar of x's dtype/device).
    """
    scaled = math.pi * x
    quotient = torch.sin(scaled) / scaled  # NaN where x == 0; replaced below
    one = x.new_ones([])
    return torch.where(x != 0, quotient, one)
def lanczos(x, a):
    """Return Lanczos-`a` window weights sampled at positions `x`, normalized to sum to 1.

    Weights are sinc(x) * sinc(x / a) strictly inside (-a, a) and zero elsewhere.
    """
    inside = (x > -a) & (x < a)
    raw = sinc(x) * sinc(x / a)
    windowed = torch.where(inside, raw, x.new_zeros([]))
    return windowed / windowed.sum()
def ramp(ratio, width):
    """Return symmetric sample positions spaced `ratio` apart, used as Lanczos kernel taps.

    Positions 0, ratio, 2*ratio, ... are generated until they reach at least `width`.
    The mirrored negative side is prepended, and the outermost sample at each end is
    dropped, so every returned position lies strictly inside (-width, width).
    """
    n_taps = math.ceil(width / ratio + 1)
    right = torch.empty([n_taps])
    offset = 0
    idx = 0
    # Accumulate by repeated addition, matching the original rounding behaviour.
    while idx < n_taps:
        right[idx] = offset
        offset += ratio
        idx += 1
    left = -right[1:].flip([0])
    return torch.cat([left, right])[1:-1]
class Prompt(nn.Module):
    """Weighted text-prompt loss measured as spherical distance between embeddings.

    Args:
        embed: prompt embedding(s); presumably shaped (n_prompts, dim) — TODO confirm.
        weight: loss multiplier. A negative weight flips the sign of the distance, so
            minimizing the loss pushes away from the prompt instead of towards it.
        stop: gradient cutoff. Entries whose signed distance is below `stop` still
            contribute to the loss value but receive no gradient (see `forward`).
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        # Buffers follow the module across .to(device) calls but are not trained.
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        """Return weight.abs() * mean signed spherical distance from `input` to `embed`.

        `input` is unsqueezed on dim 1 and `embed` on dim 0, giving pairwise distances
        between every input row and every prompt row (assumes both are 2-D — TODO confirm).
        """
        lhs = F.normalize(input.unsqueeze(1), dim=2)
        rhs = F.normalize(self.embed.unsqueeze(0), dim=2)
        chord = (lhs - rhs).norm(dim=2)
        # For unit vectors, 2 * arcsin(chord / 2)**2 equals half the squared angle.
        dists = 2 * torch.arcsin(chord / 2) ** 2
        signed = dists * self.weight.sign()
        floored = torch.maximum(signed, self.stop)
        # Forward value is the raw signed distance; gradient flows via the floored copy.
        return self.weight.abs() * replace_grad(signed, floored).mean()
class MakeCutouts(nn.Module):
    """Produce a batch of randomly sized, augmented square crops of an image tensor.

    Each call takes `cutn` square crops whose side length is sampled between
    min(sideX, sideY, cut_size) and min(sideX, sideY). Each crop is resampled to
    `cut_size` x `cut_size`, passed through the kornia augmentations below, and
    given per-cutout Gaussian noise with a random scale in [0, noise_fac).

    Args:
        cut_size: output side length of every cutout.
        cutn: number of cutouts taken per call.
        cut_pow: exponent applied to the uniform size sample. Values > 1 favour
            smaller crops; values < 1 favour larger ones.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3, p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2, p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
        self.noise_fac = 0.1

    def forward(self, input):
        """Return the augmented cutouts of `input`, concatenated along dim 0.

        Dims 2 and 3 of `input` are treated as height and width (assumes NCHW — TODO
        confirm). The result has cutn * input.shape[0] rows.
        """
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            # Use one noise scale per row of the batch. torch.cat above yields
            # cutn * input.shape[0] rows. Sizing this by cutn alone only matched when
            # a single image was passed in, and failed to broadcast otherwise.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
def resample(input, size, align_corners=True):
    """Resize a 4-D (n, c, h, w) tensor to `size` = (dh, dw), anti-aliasing when shrinking.

    Along each axis whose target is smaller than the source, a Lanczos (a=2) low-pass
    kernel is convolved first, using reflect padding that keeps the spatial size.
    A bicubic interpolation to `size` follows.
    """
    n, c, h, w = input.shape
    dh, dw = size
    # Fold channels into the batch so a single-channel 1-D kernel filters them all.
    x = input.view([n * c, 1, h, w])

    if dh < h:
        taps = lanczos(ramp(dh / h, 2), 2).to(x.device, x.dtype)
        half = (taps.shape[0] - 1) // 2
        padded = F.pad(x, (0, 0, half, half), 'reflect')
        x = F.conv2d(padded, taps[None, None, :, None])

    if dw < w:
        taps = lanczos(ramp(dw / w, 2), 2).to(x.device, x.dtype)
        half = (taps.shape[0] - 1) // 2
        padded = F.pad(x, (half, half, 0, 0), 'reflect')
        x = F.conv2d(padded, taps[None, None, None, :])

    x = x.view([n, c, h, w])
    return F.interpolate(x, size, mode='bicubic', align_corners=align_corners)
class ReplaceGrad(torch.autograd.Function):
    """Autograd op whose forward value comes from one tensor and gradient goes to another.

    forward(x_forward, x_backward) returns x_forward unchanged. In backward, the
    incoming gradient is summed down to x_backward's shape and handed entirely to
    x_backward, while x_forward receives no gradient.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape is needed later, to reduce the incoming gradient.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        routed = grad_in.sum_to_size(ctx.shape)
        no_grad_for_forward_input = None
        return no_grad_for_forward_input, routed


# Functional alias used by the loss code.
replace_grad = ReplaceGrad.apply
# Load the CLOOB model
import sys
sys.path.append('./cloob-training')
sys.path.append('./clip')
# git isn't pulling the submodules for cloob-training, so model_pt.py is patched in
# place to import clip from the local CLIP checkout instead. I hate this :D
# The patch is applied only once. Re-running the replace on an already-patched file
# would turn "from CLIP import clip" into "from CLIP from CLIP import clip" (a
# SyntaxError on the next start) and would stack another copy of the header on top.
_CLOOB_MODEL_PT = './cloob-training/cloob_training/model_pt.py'
_CLOOB_PATCH_HEADER = "import sys\n" + "sys.path.append('../../../clip')\n" + '\n'
with open(_CLOOB_MODEL_PT, 'r+') as f:
    content = f.read()
    if not content.startswith(_CLOOB_PATCH_HEADER):
        f.seek(0, 0)
        f.write(_CLOOB_PATCH_HEADER + content.replace("import clip", "from CLIP import clip"))
        # Make the file end exactly where the rewritten text ends.
        f.truncate()
from cloob_training import model_pt, pretrained
config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(config)
checkpoint = pretrained.download_checkpoint(config)
cloob.load_state_dict(model_pt.get_pt_params(config, checkpoint))
# Inference only: eval mode, no parameter gradients, moved to the selected device.
cloob.eval().requires_grad_(False).to(device)
print('done')
# Load the fastai sketch model
import gradio as gr
# NOTE(review): wildcard import presumably needed so load_learner and the exported
# Learner's dependencies resolve — confirm before narrowing it.
from fastai.vision.all import *
from os.path import exists
import requests
import os

model_fn = 'quick_224px'
url = 'https://huggingface.co/johnowhitaker/sketchy_unet_rn34/resolve/main/quick_224px'
if not exists(model_fn):
    print('starting download')
    # Stream into a temporary file and move it into place only after the download has
    # finished. Otherwise an interrupted download leaves a truncated model_fn, which
    # the exists() check above would then treat as complete on every later start.
    tmp_fn = model_fn + '.part'
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(tmp_fn, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    os.replace(tmp_fn, model_fn)
    print('done')
else:
    print('file exists')


# Stubs for names the exported Learner presumably references from its training
# script; they must exist here for load_learner to unpickle it — TODO confirm.
def get_x(item):
    return None


def get_y(item):
    return None


sketch_model = load_learner(model_fn)
# Cutout settings: 16 crops per optimization step, with sizes sampled uniformly
# (cut_pow = 1), each resampled to the CLOOB image encoder's input resolution.
cutn = 16
cut_pow = 1
make_cutouts = MakeCutouts(
    cut_size=cloob.config['image_encoder']['image_size'],
    cutn=cutn,
    cut_pow=cut_pow,
)
def _text_prompt(text, weight):
    """Encode `text` with CLOOB and wrap the embedding in a Prompt with the given weight."""
    embed = cloob.text_encoder(cloob.tokenize(text).to(device)).float()
    return Prompt(embed, weight, float('-inf')).to(device)


def process_im(image_path,
               sketchify_first=True,
               prompt='A watercolor painting of a face',
               lr=0.03,
               n_iter=10
               ):
    """Stylize an image by optimizing an ImStack towards a CLOOB text prompt.

    Args:
        image_path: path to the input image file.
        sketchify_first: if True, the fastai sketch model's prediction on the image is
            used as the starting image. Otherwise the image itself is used, converted
            to RGB and resized to out_size x out_size.
        prompt: text the image is optimized towards (weight 1). Two fixed negative
            prompts (weight 0.5) are optimized away from.
        lr: Adam learning rate for the ImStack layers.
        n_iter: number of optimization steps. It is coerced with int(), since numeric UI
            inputs may arrive as floats.

    Returns:
        The result of `ims.to_pil()` after optimization (presumably a PIL image).

    Raises:
        ValueError: if `image_path` is None (no image was provided).
    """
    if image_path is None:
        raise ValueError('No input image was provided.')
    n_iter = int(n_iter)

    weight_decay = 1e-4
    out_size = 540
    base_size = 8
    n_layers = 5
    scale = 3
    layer_decay = 0.3
    desat = 0.6  # factor applied to the initial layers; 1 leaves them unchanged

    if sketchify_first:
        pred = sketch_model.predict(image_path)
        # Assumes pred[0] is a (C, H, W) tensor with values in 0-255 — TODO confirm.
        np_im = pred[0].permute(1, 2, 0).numpy()
        pil_im = Image.fromarray(np_im.astype(np.uint8))
    else:
        # convert('RGB') so RGBA / greyscale / palette uploads still give 3 channels.
        with Image.open(image_path) as src:
            pil_im = src.convert('RGB').resize((out_size, out_size))

    p_prompts = [_text_prompt(prompt, 1)]
    n_prompts = [_text_prompt(text, 0.5)
                 for text in ["Random noise", 'saturated rainbow RGB deep dream']]

    ims = ImStack(base_size=base_size,
                  scale=scale,
                  n_layers=n_layers,
                  out_size=out_size,
                  decay=layer_decay,
                  init_image=pil_im)

    # Desaturate the starting image by scaling every layer, then make them trainable again.
    if desat != 1:
        for i in range(n_layers):
            ims.layers[i] = ims.layers[i].detach() * desat
            ims.layers[i].requires_grad = True

    optimizer = optim.Adam(ims.layers, lr=lr, weight_decay=weight_decay)
    for _ in tqdm(range(n_iter)):
        optimizer.zero_grad()
        im = ims()
        batch = cloob.normalize(make_cutouts(im))
        image_embeds = cloob.image_encoder(batch).float()
        loss = 0
        for pos in p_prompts:
            loss += pos(image_embeds)
        for neg in n_prompts:
            loss -= neg(image_embeds)
        loss.backward()
        optimizer.step()
    return ims.to_pil()
# NOTE(review): Checkbox is not referenced below (gr.inputs.Checkbox is used instead).
from gradio.inputs import Checkbox

# UI controls, in the same order as process_im's positional parameters.
_inputs = [
    gr.inputs.Image(label="Input Image", shape=(512, 512), type="filepath"),
    gr.inputs.Checkbox(label='Sketchify First', default=True),
    gr.inputs.Textbox(default="A charcoal and watercolor sketch of a person", label="Prompt"),
    gr.inputs.Number(default=0.03, label='LR'),
    gr.inputs.Number(default=10, label='num_steps'),
]
_outputs = [gr.outputs.Image(type="pil", label="Model Output")]

iface = gr.Interface(
    fn=process_im,
    inputs=_inputs,
    outputs=_outputs,
    title='Sketchy ImStack + CLOOB',
    description="Stylize an image with ImStack+CLOOB after a Sketchy Unet",
    article="An input image is sketchified with a unet - see https://huggingface.co/spaces/johnowhitaker/sketchy_unet_demo and links from there to training and blog post. It is then loaded into an imstack (https://johnowhitaker.github.io/imstack/) which is optimized towards a CLOOB prompt for n_steps. Feel free to reach me @johnowhitaker with questions :)",
)
# Queue requests, since each one runs a full optimization loop.
iface.launch(enable_queue=True)