File size: 6,943 Bytes
c2164fe
 
 
 
 
 
a52e5e8
c2164fe
 
79df973
8b8b671
79df973
8b8b671
28f074f
8b8b671
79df973
8b8b671
79df973
 
 
c2164fe
a52e5e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b8b671
 
 
79df973
c2164fe
 
28f074f
f40b380
28f074f
f40b380
 
 
28f074f
 
 
1a8fbd0
 
 
 
 
 
 
a52e5e8
1a8fbd0
 
 
 
 
 
 
 
 
 
28f074f
 
 
 
1a8fbd0
 
28f074f
a52e5e8
 
 
1a8fbd0
f40b380
1a8fbd0
 
 
 
 
 
 
 
 
c2164fe
28f074f
 
79df973
 
 
 
 
8b8b671
79df973
 
8b8b671
 
28f074f
 
79df973
8b8b671
79df973
 
e60fd27
79df973
8b8b671
79df973
 
28f074f
 
79df973
c2164fe
 
5d31697
 
 
 
28f074f
 
79df973
 
 
 
28f074f
 
e60fd27
 
 
 
79df973
 
 
 
c2164fe
 
8b8b671
 
c2164fe
 
79df973
8b8b671
 
c2164fe
8b8b671
 
 
3d1b36b
 
 
 
8b8b671
 
 
 
 
 
 
 
 
c2164fe
 
 
8b8b671
 
 
a52e5e8
1a8fbd0
 
 
 
 
 
79df973
f8ec29a
79df973
8b8b671
79df973
 
 
 
 
 
 
 
 
 
8b8b671
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# -*- coding: utf-8 -*-
#
# @File:   app.py
# @Author: Haozhe Xie
# @Date:   2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-09-22 21:11:51
# @Email:  root@haozhexie.com

import gradio as gr
import logging
import numpy as np
import os

import ssl
import subprocess
import sys
import urllib.request

from PIL import Image

# Reinstall PyTorch with CUDA 11.8 (Default version is 12.1)
# subprocess.call(
#     [
#         "pip",
#         "install",
#         "torch==2.2.2",
#         "torchvision==0.17.2",
#         "--index-url",
#         "https://download.pytorch.org/whl/cu118",
#     ]
# )
import torch

# Create a dummy decorator for Non-ZeroGPU environments
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Minimal stand-in for the ``spaces`` package outside ZeroGPU.

        Only ``spaces.GPU`` is provided; it runs the decorated function as-is.
        """

        @staticmethod
        def GPU(func=None, **_options):
            """No-op replacement for ``spaces.GPU``.

            Supports both the bare ``@spaces.GPU`` form and the called form
            ``@spaces.GPU(duration=...)``; any keyword options are ignored.
            """
            if func is None:
                # Called with options, e.g. @spaces.GPU(duration=120).
                return lambda f: spaces.GPU(f)

            # functools.wraps copies __name__/__doc__/__dict__ and sets
            # __wrapped__, so inspect.signature() (used by Gradio to find
            # parameters such as a gr.Progress default) sees the real signature.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

# Work around "ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED]"
# by making the default HTTPS context skip certificate verification.
# NOTE(review): this disables TLS verification process-wide (it also covers the
# urllib checkpoint downloads in get_models); consider a proper CA bundle.
ssl._create_default_https_context = ssl._create_unverified_context

# Make the bundled CityDreamer sources importable; the directory is added at
# the end of sys.path, so it is searched last.
sys.path.insert(
    len(sys.path),
    os.path.join(os.path.split(__file__)[0], "citydreamer"),
)


def _get_output(cmd):
    try:
        return subprocess.check_output(cmd).decode("utf-8")
    except Exception as ex:
        logging.exception(ex)

    return None


def install_cuda_toolkit():
    """Download and silently install the CUDA 12.2 toolkit, then expose it.

    The runfile is saved under ``/tmp`` and executed with ``--silent --toolkit``.
    Afterwards ``CUDA_HOME``, ``PATH``, ``LD_LIBRARY_PATH`` and
    ``TORCH_CUDA_ARCH_LIST`` are updated in ``os.environ`` so that subprocesses
    started later by this process see the toolkit. Non-zero exit codes of the
    download/install commands are logged as warnings; they do not abort.
    """
    # CUDA 11.8 alternative:
    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    for cmd in (
        ["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE],
        ["chmod", "+x", CUDA_TOOLKIT_FILE],
        [CUDA_TOOLKIT_FILE, "--silent", "--toolkit"],
    ):
        ret_code = subprocess.call(cmd)
        if ret_code != 0:
            logging.warning("Command %s exited with code %d" % (cmd, ret_code))

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    # Tolerate an unset/empty PATH instead of raising KeyError.
    os.environ["PATH"] = os.pathsep.join(
        p for p in ("%s/bin" % cuda_home, os.environ.get("PATH")) if p
    )
    # NVIDIA's post-install guide places the runtime libraries in lib64; "lib"
    # is kept for compatibility with the previous setting. The existing value
    # is appended only when non-empty: an empty element (e.g. a trailing ":")
    # makes the dynamic loader search the current working directory.
    ld_paths = ["%s/lib64" % cuda_home, "%s/lib" % cuda_home]
    if os.environ.get("LD_LIBRARY_PATH"):
        ld_paths.append(os.environ["LD_LIBRARY_PATH"])
    os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(ld_paths)
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"


def setup_runtime_env():
    """Log toolchain diagnostics and build the bundled CUDA extensions.

    Every sub-directory of ``citydreamer/extensions`` (next to this file) is
    installed with ``pip install .`` using the running interpreter. A failed
    build is logged as a warning and start-up continues.
    """
    logging.info("Python Version: %s" % _get_output([sys.executable, "--version"]))
    logging.info("CUDA Version: %s" % _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s" % _get_output(["gcc", "--version"]))
    cuda_available = torch.cuda.is_available()
    logging.info("CUDA is available: %s" % cuda_available)
    # get_device_capability() raises when no CUDA device is usable; a
    # diagnostic log line must not abort start-up.
    if cuda_available:
        logging.info(
            "CUDA Device Capability: %s" % (torch.cuda.get_device_capability(),)
        )
    else:
        logging.info("CUDA Device Capability: N/A")

    # Installing pre-compiled extension wheels was tried and did not work
    # (ref: https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/110),
    # so the extensions are compiled from source instead.
    ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
    # sorted() gives a deterministic build order; os.listdir order is arbitrary.
    for e in sorted(os.listdir(ext_dir)):
        e_dir = os.path.join(ext_dir, e)
        if not os.path.isdir(e_dir):
            continue
        # "python -m pip" installs into this interpreter's environment rather
        # than whichever "pip" happens to be first on PATH.
        ret_code = subprocess.call(
            [sys.executable, "-m", "pip", "install", "."], cwd=e_dir
        )
        if ret_code != 0:
            logging.warning(
                "Failed to install extension %s (exit code %d)" % (e, ret_code)
            )

    logging.info(
        "Installed Python Packages: %s"
        % _get_output([sys.executable, "-m", "pip", "list"])
    )


def get_models(file_name):
    """Download (if missing) and load a CityDreamer generator checkpoint.

    Args:
        file_name: Checkpoint file name, e.g. ``"CityDreamer-Fgnd.pth"``. It
            is used both as the local path (relative to the current working
            directory) and as the file name in the ``hzxie/city-dreamer``
            Hugging Face repository.

    Returns:
        A ``citydreamer.model.GanCraftGenerator`` in evaluation mode. When
        CUDA is available it is wrapped in ``torch.nn.DataParallel`` and
        moved to the GPU.
    """
    import citydreamer.model

    if not os.path.exists(file_name):
        # Download under a temporary name and rename afterwards, so an
        # interrupted download does not leave a truncated checkpoint that
        # would be mistaken for a complete one on the next start.
        tmp_file_name = "%s.part" % file_name
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
            tmp_file_name,
        )
        os.replace(tmp_file_name, file_name)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    ckpt = torch.load(file_name, map_location=device)
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
    # Inference only: evaluation mode is applied on both the CPU and CUDA paths.
    model.eval()

    # NOTE(review): strict=False silently ignores missing/unexpected keys. If
    # the checkpoint was saved from a DataParallel model ("module." prefixes),
    # the CPU path (no DataParallel wrapper) may load nothing — confirm.
    model.load_state_dict(ckpt["gancraft_g"], strict=False)
    return model


def get_city_layout(
    hf_path="assets/NYC-HghtFld.png", seg_path="assets/NYC-SegMap.png"
):
    """Load a city height field and segmentation map as ``np.int32`` arrays.

    Args:
        hf_path: Height-field image path (defaults to the bundled New York
            layout, relative to the current working directory).
        seg_path: Segmentation-map image path. The image is converted to
            palette ("P") mode first, so the array holds palette indices.

    Returns:
        ``(hf, seg)`` tuple of ``np.int32`` arrays.
    """
    # Context managers close the underlying image files deterministically.
    with Image.open(hf_path) as hf_img:
        hf = np.array(hf_img).astype(np.int32)
    with Image.open(seg_path) as seg_img:
        seg = np.array(seg_img.convert("P")).astype(np.int32)
    return hf, seg


@spaces.GPU
def get_generated_city(
    radius, altitude, azimuth, map_center, progress=gr.Progress(track_tqdm=True)
):
    """Gradio callback: render a CityDreamer view of the preloaded city layout.

    The generators and layout arrays are read from attributes of the
    module-level ``get_generated_city`` object (``fgm``, ``bgm``, ``hf``,
    ``seg``), which the ``__main__`` block sets at start-up. The layout arrays
    are copied before use, and both generators are moved to ``"cuda"`` on each
    call. ``progress`` is not used in the body; presumably its
    ``gr.Progress(track_tqdm=True)`` default exists so Gradio can report tqdm
    progress.

    Args:
        radius: Camera radius (from the "Camera Radius (m)" slider).
        altitude: Camera altitude (from the "Camera Altitude (m)" slider).
        azimuth: Camera azimuth (from the "Camera Azimuth (°)" slider).
        map_center: Map center (from the "Map Center (px)" slider); passed
            twice to ``generate_city``.

    Returns:
        The result of ``citydreamer.inference.generate_city``.
    """
    logging.info("CUDA is available: %s" % torch.cuda.is_available())
    logging.info("PyTorch is built with CUDA: %s" % torch.version.cuda)
    # Deferred on purpose: this import must happen only after the CUDA
    # extensions have been compiled.
    import citydreamer.inference

    state = get_generated_city
    fg_model = state.fgm.to("cuda")
    bg_model = state.bgm.to("cuda")
    height_field = state.hf.copy()
    seg_map = state.seg.copy()
    return citydreamer.inference.generate_city(
        fg_model,
        bg_model,
        height_field,
        seg_map,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )


def main(debug):
    """Build and launch the CityDreamer Gradio demo.

    The description is taken from ``README.md`` and the article from
    ``ARTICLE.md`` (both relative to the current working directory).

    Args:
        debug: Forwarded to ``app.launch(debug=...)``.
    """
    title = "CityDreamer Demo 🏙️"
    # The Markdown files may contain non-ASCII text (e.g. emoji), so decode
    # them explicitly rather than with the platform's default encoding.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    # Keep only the text after the last "---" (presumably the end of the
    # Space's YAML front matter). If there is no "---", rfind() returns -1 and
    # slicing at -1 + 3 would drop the first two characters; use it all.
    sep_idx = markdown.rfind("---")
    desc = markdown if sep_idx == -1 else markdown[sep_idx + 3 :]
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(128, 512, value=343, step=5, label="Camera Radius (m)"),
            gr.Slider(256, 512, value=296, step=5, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1440, 6752, value=3970, step=5, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    app.queue(api_open=False)
    app.launch(debug=debug)


if __name__ == "__main__":
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
    )
    # Mask values of variables that look like credentials (e.g. HF_TOKEN) so
    # secrets are not written to the logs.
    secret_markers = ("TOKEN", "SECRET", "PASSWORD", "KEY", "CREDENTIAL")
    safe_env = {
        k: ("***" if any(m in k.upper() for m in secret_markers) else v)
        for k, v in os.environ.items()
    }
    logging.info("Environment Variables: %s" % safe_env)

    # Run "nvcc --version" once and reuse the result.
    nvcc_version = _get_output(["nvcc", "--version"])
    if nvcc_version is None:
        logging.info("Installing CUDA toolkit...")
        install_cuda_toolkit()
    else:
        logging.info("Detected CUDA: %s" % nvcc_version)

    logging.info("Compiling CUDA extensions...")
    setup_runtime_env()

    logging.info("Downloading pretrained models...")
    fgm = get_models("CityDreamer-Fgnd.pth")
    bgm = get_models("CityDreamer-Bgnd.pth")
    # The Gradio callback reads its models and layout from attributes set on
    # the module-level get_generated_city object.
    get_generated_city.fgm = fgm
    get_generated_city.bgm = bgm

    logging.info("Loading New York city layout to RAM...")
    hf, seg = get_city_layout()
    get_generated_city.hf = hf
    get_generated_city.seg = seg

    logging.info("Starting the main application...")
    main(os.getenv("DEBUG") == "1")