File size: 2,134 Bytes
04ef268
7fd88e0
04ef268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fd88e0
04ef268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import spaces
import gradio as gr
from src.util.base import *
from src.util.params import *
from PIL import Image, ImageDraw


def visualize_poke(
    pokeX, pokeY, pokeHeight, pokeWidth, imageHeight=imageHeight, imageWidth=imageWidth
):
    """Outline the poke region on the saved original and poked images.

    The poke centre and size are expressed on a grid 8x coarser than the
    image: they are bounds-checked against ``imageWidth // 8`` and
    ``imageHeight // 8`` and multiplied by 8 to obtain pixel coordinates
    (presumably latent-space cells -- confirm against the caller).

    Args:
        pokeX: Horizontal centre of the poke region, in grid cells.
        pokeY: Vertical centre of the poke region, in grid cells.
        pokeHeight: Height of the poke region, in grid cells.
        pokeWidth: Width of the poke region, in grid cells.
        imageHeight: Image height in pixels.
        imageWidth: Image width in pixels.

    Returns:
        A tuple ``(original, poked)`` of PIL images with the poke rectangle
        outlined in white. Images are read from ``outputs/original.png`` and
        ``outputs/poked.png``; a blank image of size
        ``(imageWidth, imageHeight)`` is used when they are not available.
    """
    # Warn (but still draw) when the region extends past the image bounds.
    if (
        (pokeX - pokeWidth // 2 < 0)
        or (pokeX + pokeWidth // 2 > imageWidth // 8)
        or (pokeY - pokeHeight // 2 < 0)
        or (pokeY + pokeHeight // 2 > imageHeight // 8)
    ):
        gr.Warning("Modification outside image")

    # Top-left and bottom-right corners of the rectangle, in pixels.
    shape = [
        (pokeX * 8 - pokeWidth * 8 // 2, pokeY * 8 - pokeHeight * 8 // 2),
        (pokeX * 8 + pokeWidth * 8 // 2, pokeY * 8 + pokeHeight * 8 // 2),
    ]

    size = (imageWidth, imageHeight)

    if os.path.exists("outputs/original.png"):
        oImg = Image.open("outputs/original.png")
        # display_poke_images only writes poked.png when poking is enabled,
        # so it may be missing even though original.png exists.
        if os.path.exists("outputs/poked.png"):
            pImg = Image.open("outputs/poked.png")
        else:
            pImg = Image.new("RGB", size)
    else:
        # Separate blank images so the two return values never alias.
        oImg = Image.new("RGB", size)
        pImg = Image.new("RGB", size)

    ImageDraw.Draw(oImg).rectangle(shape, outline="white")
    ImageDraw.Draw(pImg).rectangle(shape, outline="white")

    return oImg, pImg

@spaces.GPU(enable_queue=True)
def display_poke_images(
    prompt,
    seed,
    num_inference_steps,
    poke=False,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    intermediate=False,
    progress=gr.Progress(),
):
    """Generate images from the original latents and, optionally, poked latents.

    Args:
        prompt: Text passed to ``get_text_embeddings``.
        seed: Seed passed to ``generate_modified_latents``.
        num_inference_steps: Step count passed to ``generate_images``.
        poke: When truthy, a second generation is run from the modified
            latents.
        pokeX: Poke centre X, forwarded to ``generate_modified_latents``.
        pokeY: Poke centre Y, forwarded to ``generate_modified_latents``.
        pokeHeight: Poke height, forwarded to ``generate_modified_latents``.
        pokeWidth: Poke width, forwarded to ``generate_modified_latents``.
        intermediate: Forwarded to ``generate_images``. When falsy, the
            results are saved under ``outputs/`` (the return value is then
            presumably a single PIL image -- confirm in ``generate_images``).
        progress: Gradio progress tracker; set to 0 before the first
            generation and to 0.5 before the poked one.

    Returns:
        A tuple ``(images, modImages)``; ``modImages`` is ``None`` when
        ``poke`` is falsy.
    """
    text_embeddings = get_text_embeddings(prompt)
    latents, modified_latents = generate_modified_latents(
        poke, seed, pokeX, pokeY, pokeHeight, pokeWidth
    )

    progress(0)
    images = generate_images(
        latents, text_embeddings, num_inference_steps, intermediate=intermediate
    )

    if not intermediate:
        # Saving fails if the output directory has not been created yet.
        os.makedirs("outputs", exist_ok=True)
        images.save("outputs/original.png")

    if poke:
        progress(0.5)
        modImages = generate_images(
            modified_latents,
            text_embeddings,
            num_inference_steps,
            intermediate=intermediate,
        )

        if not intermediate:
            modImages.save("outputs/poked.png")
    else:
        modImages = None

    return images, modImages


# Names exported by ``from <this module> import *``.
__all__ = ["display_poke_images", "visualize_poke"]