CerealBoxMaker / app.py
AP123's picture
fp 16 to allow for T4 inference
1e875da
raw
history blame contribute delete
No virus
3.22 kB
import gradio as gr
import torch
import numpy as np
from PIL import Image
import random
from diffusers import DiffusionPipeline
# Load the SDXL base checkpoint in half precision (fp16 weight variant),
# attach the cereal-box LoRA weights, and move the pipeline to the first GPU.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
pipeline.to("cuda:0")

# Inclusive upper bound for randomly drawn seeds: the largest int32 value.
MAX_SEED = np.iinfo(np.int32).max
def text_to_image(prompt):
    """Render one image for ``prompt`` and return it as a NumPy array.

    A new seed is drawn from ``[0, MAX_SEED]`` on every call and used to seed
    the generator handed to the pipeline. The request is fixed at 1024x1024,
    guidance scale 7.0 and 25 inference steps, with a fixed negative prompt.
    """
    rng = torch.Generator().manual_seed(random.randint(0, MAX_SEED))
    result = pipeline(
        prompt=prompt,
        negative_prompt="ugly, blurry, nsfw, gore, blood",
        width=1024,
        height=1024,
        guidance_scale=7.0,
        num_inference_steps=25,
        generator=rng,
    )
    # Only the first returned image is used.
    return np.array(result.images[0])
def create_cereal_box(input_image, template_path="template.jpeg"):
    """Paste the cover art onto the cereal-box template at two positions.

    Args:
        input_image: NumPy image array. It is cast to ``uint8`` and converted
            to RGB before being resized.
        template_path: Path of the template image to draw on. Defaults to
            ``"template.jpeg"``, resolved against the current working
            directory.

    Returns:
        NumPy array of a copy of the template with the resized cover pasted
        at a left and a right position sharing the same vertical offset.
    """
    # Let Pillow infer the mode from the array's shape, then normalise to RGB
    # so grayscale or RGBA arrays are accepted in addition to 3-channel input.
    cover_img = Image.fromarray(input_image.astype("uint8")).convert("RGB")

    # The context manager releases the template file handle deterministically;
    # all drawing happens on an in-memory copy.
    with Image.open(template_path) as template_img:
        box = template_img.copy()

    # Cover size is derived from the template height: a base rectangle of
    # 32% of the height with a 0.70 width ratio, enlarged by the scale factor.
    scaling_factor = 1.5
    rect_height = int(box.height * 0.32)
    new_width = int(rect_height * 0.70)
    new_width_scaled = int(new_width * scaling_factor)
    new_height_scaled = int(rect_height * scaling_factor)

    # Resize once, straight to the final size. Resampling down to the base
    # rectangle first and then back up would discard detail for no benefit.
    cover_scaled = cover_img.resize(
        (new_width_scaled, new_height_scaled), Image.LANCZOS
    )

    left_x = int(box.width * 0.085)
    # Vertically centred, then nudged down by 1.2% of the template height.
    top_y = int((box.height - new_height_scaled) // 2 + box.height * 0.012)
    # The right copy is placed so its right edge sits at 82% of the width.
    right_x = int(box.width * 0.82) - new_width_scaled

    box.paste(cover_scaled, (left_x, top_y))
    box.paste(cover_scaled, (right_x, top_y))
    return np.array(box)
def combined_function(prompt):
    """Generate cover art for ``prompt`` and return it composited onto the box template."""
    return create_cereal_box(text_to_image(prompt))
# Gradio UI: static header HTML, a prompt box with a generate button, and an
# image output wired to combined_function.
# NOTE(review): the original indentation was lost; the Row nesting below
# (textbox and button sharing the first row) is reconstructed — confirm.
with gr.Blocks() as app:
    gr.HTML("<div style='text-align: center;'><h1>Cereal Box Maker 🥣</h1></div>")
    gr.HTML("<div style='text-align: center;'><p>This application uses StableDiffusion XL to create any cereal box you could ever imagine!</p></div>")
    gr.HTML("<div style='text-align: center;'><h3>Instructions:</h3><ol><li>Describe the cereal box you want to create and hit generate!</li><li>Print it out, cut the outside, fold the lines, and then tape!</li></ol></div>")
    gr.HTML("<div style='text-align: center;'><p>A space by AP 🐧, follow me on <a href='https://twitter.com/angrypenguinPNG'>Twitter</a>! H/T to <a href='https://twitter.com/ostrisai'>OstrisAI</a> for their Cereal Box LoRA!</p></div>")
    with gr.Row():
        textbox = gr.Textbox(label="Describe your cereal box: Ex: 'Avengers Cereal'")
        # The button text comes from its value; the former `label` keyword was
        # redundant and is not an accepted Button argument in newer Gradio
        # releases, so it is omitted.
        btn_generate = gr.Button("Generate")
    with gr.Row():
        output_img = gr.Image(label="Your Custom Cereal Box")
    btn_generate.click(
        combined_function,
        inputs=[textbox],
        outputs=[output_img],
    )

# Queue incoming requests (at most 20 waiting) with the queue API closed,
# then start the server.
app.queue(max_size=20, api_open=False)
app.launch()