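# Scraped from the Hugging Face Space page for city-dreamer / app.py:
# author: hzxie · last commit: a52e5e8 (verified) "Fix CUDA to 12.2." · size: 6.94 kB
# (page chrome: raw | history | blame | contribute | delete | No virus)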
# -*- coding: utf-8 -*-
#
# @File: app.py
# @Author: Haozhe Xie
# @Date: 2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-09-22 21:11:51
# @Email: root@haozhexie.com
import gradio as gr
import logging
import numpy as np
import os
import ssl
import subprocess
import sys
import urllib.request
from PIL import Image
# Reinstall PyTorch with CUDA 11.8 (Default version is 12.1)
# subprocess.call(
# [
# "pip",
# "install",
# "torch==2.2.2",
# "torchvision==0.17.2",
# "--index-url",
# "https://download.pytorch.org/whl/cu118",
# ]
# )
import torch
# Create a dummy decorator for Non-ZeroGPU environments
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Minimal stand-in for the ``spaces`` package outside ZeroGPU.

        ``spaces.GPU`` is a pass-through decorator that works both bare
        (``@spaces.GPU``) and with arguments (``@spaces.GPU(duration=60)``).
        Any keyword arguments are accepted and ignored.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            def decorator(fn):
                # functools.wraps keeps __name__/__doc__ and sets __wrapped__,
                # so inspect.signature() reports fn's real parameters. Gradio
                # presumably relies on that to spot the gr.Progress default.
                @functools.wraps(fn)
                def wrapper(*args, **kw):
                    # This is a dummy wrapper that just calls the function.
                    return fn(*args, **kw)

                return wrapper

            if func is None:
                # Called as @spaces.GPU(...): return the actual decorator.
                return decorator
            return decorator(func)


# Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
# NOTE(security): this disables TLS certificate verification for every HTTPS
# request that uses the default ssl context. That includes the urllib
# checkpoint downloads in get_models(), so a man-in-the-middle could substitute
# files. It is presumably kept to work around missing CA certificates on the
# host. Remove it once the host trust store is fixed.
ssl._create_default_https_context = ssl._create_unverified_context

# Import CityDreamer modules
sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
def _get_output(cmd):
    """Run ``cmd`` and return its standard output decoded as UTF-8.

    Args:
        cmd: Argument list passed straight to ``subprocess.check_output``.

    Returns:
        The decoded stdout on success. Returns ``None`` if the command cannot
        be started, exits non-zero or its output fails to decode; the
        exception is logged with its traceback.
    """
    output = None
    try:
        raw = subprocess.check_output(cmd)
        output = raw.decode("utf-8")
    except Exception as ex:
        logging.exception(ex)
    return output
def install_cuda_toolkit(
    cuda_toolkit_url="https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run",
):
    """Download and silently install the CUDA toolkit, then export its env vars.

    The runfile is downloaded into ``/tmp`` and run with
    ``--silent --toolkit``. Afterwards ``CUDA_HOME``, ``PATH``,
    ``LD_LIBRARY_PATH`` and ``TORCH_CUDA_ARCH_LIST`` are set in ``os.environ``
    so that the CUDA extension builds started later in this process can find
    the toolkit.

    Args:
        cuda_toolkit_url: URL of the NVIDIA CUDA runfile installer. The
            default is CUDA 12.2.0. The previously used 11.8 installer was
            https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
    """
    toolkit_file = "/tmp/%s" % os.path.basename(cuda_toolkit_url)
    steps = (
        ["wget", "-q", cuda_toolkit_url, "-O", toolkit_file],
        ["chmod", "+x", toolkit_file],
        [toolkit_file, "--silent", "--toolkit"],
    )
    for cmd in steps:
        ret = subprocess.call(cmd)
        if ret != 0:
            # Keep going, as before, but make the failure visible in the logs.
            logging.warning("Command %s exited with status %d", cmd, ret)

    # Assumes the runfile installs under /usr/local/cuda (its default prefix).
    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    # filter(None, ...) drops empty entries. A trailing ":" would otherwise
    # add the current working directory to the search path.
    os.environ["PATH"] = os.pathsep.join(
        filter(None, [os.path.join(cuda_home, "bin"), os.environ.get("PATH", "")])
    )
    # On x86_64 Linux the toolkit's libraries live in lib64. lib is kept for
    # backward compatibility with the previous setting.
    os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(
        filter(
            None,
            [
                os.path.join(cuda_home, "lib64"),
                os.path.join(cuda_home, "lib"),
                os.environ.get("LD_LIBRARY_PATH", ""),
            ],
        )
    )
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    # (presumably raised when no GPU is visible at build time, so the
    # architectures must be given explicitly).
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
def setup_runtime_env():
    """Log toolchain versions and build/install the bundled CUDA extensions.

    Every sub-directory of ``citydreamer/extensions`` (next to this file) is
    installed with ``pip install .``. pip is run through the current
    interpreter, so the extensions land in the environment that will import
    them.
    """
    # Use the running interpreter rather than whatever "python"/"pip" happen
    # to be first on PATH.
    python = sys.executable
    logging.info("Python Version: %s" % _get_output([python, "--version"]))
    logging.info("CUDA Version: %s" % _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s" % _get_output(["gcc", "--version"]))
    cuda_available = torch.cuda.is_available()
    logging.info("CUDA is available: %s" % cuda_available)
    # get_device_capability() raises when no CUDA device is usable, so only
    # query it when CUDA is available.
    capability = torch.cuda.get_device_capability() if cuda_available else None
    logging.info("CUDA Device Capability: %s" % (capability,))

    # NOTE: Installing pre-compiled wheels from ./wheels did not work.
    # Ref: https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/110
    # The extensions are therefore compiled from source here.
    ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
    # sorted() gives a deterministic build order (os.listdir order is arbitrary).
    for name in sorted(os.listdir(ext_dir)):
        ext_path = os.path.join(ext_dir, name)
        if not os.path.isdir(ext_path):
            continue
        ret = subprocess.call([python, "-m", "pip", "install", "."], cwd=ext_path)
        if ret != 0:
            logging.warning(
                "Failed to install the CUDA extension %s (exit status %d)", name, ret
            )
    logging.info(
        "Installed Python Packages: %s" % _get_output([python, "-m", "pip", "list"])
    )
def get_models(file_name):
    """Load a pretrained CityDreamer generator from a checkpoint file.

    If ``file_name`` does not exist in the current working directory, it is
    first downloaded from the ``hzxie/city-dreamer`` Hugging Face repository.

    Args:
        file_name: Checkpoint file name, e.g. ``"CityDreamer-Fgnd.pth"``.

    Returns:
        A ``GanCraftGenerator`` in eval mode. When CUDA is available it is
        wrapped in ``torch.nn.DataParallel`` and moved to the GPU.
    """
    # Deferred import: like citydreamer.inference, this presumably depends on
    # the CUDA extensions compiled in setup_runtime_env().
    import citydreamer.model

    if not os.path.exists(file_name):
        # Download to a temporary name and rename on success. Otherwise an
        # interrupted download would leave a truncated file that the exists()
        # check accepts on every later start.
        tmp_name = file_name + ".part"
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
            tmp_name,
        )
        os.replace(tmp_name, file_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. weights_only is not enabled because the
    # checkpoint also carries a "cfg" object. TODO: confirm whether it loads
    # with weights_only=True.
    ckpt = torch.load(file_name, map_location=torch.device(device))
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    state_dict = ckpt["gancraft_g"]
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
    elif state_dict and all(k.startswith("module.") for k in state_dict):
        # The checkpoint has DataParallel-style keys but the model is
        # unwrapped here. strict=False would silently skip every mismatched
        # key, so strip the prefix to actually load the weights.
        state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    # Switch to eval mode on both the CUDA and the CPU path.
    return model.eval()
def get_city_layout(
    hf_path="assets/NYC-HghtFld.png", seg_path="assets/NYC-SegMap.png"
):
    """Load the city layout rasters as ``np.int32`` arrays.

    Args:
        hf_path: Height-field image path. Defaults to the bundled New York map.
        seg_path: Segmentation-map image path. The image is converted to
            palette ("P") mode first, so the array holds palette indices.
            Defaults to the bundled New York map.

    Returns:
        A ``(hf, seg)`` tuple of ``np.int32`` arrays.
    """
    # Context managers close the image files deterministically instead of
    # leaving it to Pillow's lazy loading and garbage collection.
    with Image.open(hf_path) as img:
        hf = np.array(img).astype(np.int32)
    with Image.open(seg_path) as img:
        seg = np.array(img.convert("P")).astype(np.int32)
    return hf, seg
@spaces.GPU
def get_generated_city(
    radius, altitude, azimuth, map_center, progress=gr.Progress(track_tqdm=True)
):
    """Render a city view with the preloaded models and layout.

    The models and layout arrays are read from attributes set on this function
    object (``fgm``, ``bgm``, ``hf``, ``seg``) by the ``__main__`` block.
    ``map_center`` is passed to ``generate_city`` twice, for both of its
    center arguments. ``progress`` is not used in the body. Presumably Gradio
    uses its default to track tqdm progress.
    """
    logging.info("CUDA is available: %s" % torch.cuda.is_available())
    logging.info("PyTorch is built with CUDA: %s" % torch.version.cuda)
    # The import must be done after CUDA extension compilation
    import citydreamer.inference

    state = get_generated_city
    fg_model = state.fgm.to("cuda")
    bg_model = state.bgm.to("cuda")
    # Pass copies so the cached layout arrays are not modified by the
    # renderer.
    height_field = state.hf.copy()
    seg_map = state.seg.copy()
    return citydreamer.inference.generate_city(
        fg_model,
        bg_model,
        height_field,
        seg_map,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )
def main(debug):
    """Build and launch the CityDreamer Gradio demo.

    ``README.md`` and ``ARTICLE.md`` are read from the current working
    directory to provide the page description and article.

    Args:
        debug: Forwarded to ``app.launch(debug=...)``.
    """
    title = "CityDreamer Demo 🏙️"
    # Decode explicitly as UTF-8; the default encoding depends on the locale.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    # The description is the text after the last "---" (presumably the end of
    # the Hugging Face YAML front matter). Without any "---", rfind() returns
    # -1 and slicing from -1 + 3 would drop the first two characters, so the
    # whole README is used instead.
    sep = markdown.rfind("---")
    desc = markdown[sep + 3 :] if sep != -1 else markdown
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(128, 512, value=343, step=5, label="Camera Radius (m)"),
            gr.Slider(256, 512, value=296, step=5, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1440, 6752, value=3970, step=5, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    app.queue(api_open=False)
    app.launch(debug=debug)
if __name__ == "__main__":
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
    )
    # Mask likely credentials (e.g. HF tokens, API keys) before writing the
    # environment to the log.
    _sensitive_markers = ("TOKEN", "SECRET", "KEY", "PASSWORD")
    _safe_env = {
        k: "***" if any(m in k.upper() for m in _sensitive_markers) else v
        for k, v in os.environ.items()
    }
    logging.info("Environment Variables: %s" % _safe_env)

    # Probe nvcc once and reuse the result.
    _nvcc_version = _get_output(["nvcc", "--version"])
    if _nvcc_version is None:
        logging.info("Installing CUDA toolkit...")
        install_cuda_toolkit()
    else:
        logging.info("Detected CUDA: %s" % _nvcc_version)

    logging.info("Compiling CUDA extensions...")
    setup_runtime_env()

    logging.info("Downloading pretrained models...")
    fgm = get_models("CityDreamer-Fgnd.pth")
    bgm = get_models("CityDreamer-Bgnd.pth")
    # get_generated_city reads its models and layout from attributes on the
    # function object.
    get_generated_city.fgm = fgm
    get_generated_city.bgm = bgm

    logging.info("Loading New York city layout to RAM...")
    hf, seg = get_city_layout()
    get_generated_city.hf = hf
    get_generated_city.seg = seg

    logging.info("Starting the main application...")
    main(os.getenv("DEBUG") == "1")