"""FastAPI service that renders a video from an identity frame and an audio file with a diffusion model."""

import logging
import os
import shutil
import tempfile
import threading
from functools import lru_cache
from pathlib import Path

import torch
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse

from diffusion import Diffusion  # Make sure you import your own modules correctly
from utils import get_id_frame, get_audio_emb, save_video  # Make sure you import your own modules correctly

logger = logging.getLogger(__name__)

# Root directory holding the checkpoints and the default output location.
# Override with the TALKING_HEADS_DIR environment variable instead of editing code.
BASE_DIR = Path(
    os.environ.get("TALKING_HEADS_DIR", "/Users/a/Documents/Automations/git talking heads")
)
UNET_CHECKPOINT = BASE_DIR / "checkpoints" / "crema_script.pt"
AUDIO_ENCODER_CHECKPOINT = BASE_DIR / "checkpoints" / "audio_encoder.pt"
DEFAULT_OUTPUT = str(BASE_DIR / "output_video.mp4")

# Keyword arguments for the Diffusion wrapper around the UNet.
DIFFUSION_ARGS = {
    "in_channels": 3,
    "image_size": 128,
    "out_channels": 6,
    "n_timesteps": 1000,
}
# Keyword arguments forwarded to Diffusion.sample.
UNET_ARGS = {
    "n_audio_motion_embs": 2,
    "n_motion_frames": 2,
    "motion_channels": 3,
}

# Keeps one generation running at a time per process, so concurrent requests
# do not share the cached UNet mid-run or compete for GPU memory.
_inference_lock = threading.Lock()

app = FastAPI()


@lru_cache(maxsize=2)  # keys are only "cpu" / "cuda"
def _load_unet(device: str) -> torch.jit.ScriptModule:
    """Load the TorchScript UNet once per device and reuse it for later requests."""
    logger.info("Loading model on %s...", device)
    return torch.jit.load(str(UNET_CHECKPOINT), map_location=device)


def _save_upload(upload: UploadFile, directory: Path, stem: str, default_suffix: str) -> Path:
    """Copy *upload* into *directory* as ``<stem><suffix>`` and return the new path.

    The client's file extension is kept when it is a plain alphanumeric suffix
    (so e.g. a video identity source is not mislabelled ``.jpg``); otherwise
    *default_suffix* is used. Only the suffix of the client-supplied filename is
    used, never its directory part.
    """
    suffix = Path(upload.filename or "").suffix
    if not suffix[1:].isalnum():
        suffix = default_suffix
    destination = directory / f"{stem}{suffix}"
    with destination.open("wb") as buffer:
        shutil.copyfileobj(upload.file, buffer)
    return destination


def _run_pipeline(
    id_frame_file: UploadFile,
    audio_file: UploadFile,
    device: str,
    id_frame_random: bool,
    inference_steps: int,
    output: str,
) -> None:
    """Run the blocking model pipeline and write the rendered video to *output*."""
    with _inference_lock:
        # NOTE(review): assumes Diffusion/sample do not mutate the shared UNet in a
        # way that leaks between requests; a fresh Diffusion is built per request
        # because space() reconfigures it for this request's step count.
        diffusion = Diffusion(_load_unet(device), device, **DIFFUSION_ARGS).to(device)
        diffusion.space(inference_steps)

        # Unique, self-cleaning scratch directory instead of fixed names in the CWD.
        with tempfile.TemporaryDirectory(prefix="talking_heads_") as tmp:
            tmp_dir = Path(tmp)
            id_frame_path = _save_upload(id_frame_file, tmp_dir, "id_frame", ".jpg")
            audio_path = _save_upload(audio_file, tmp_dir, "audio", ".mp3")

            id_frame = get_id_frame(
                str(id_frame_path),
                random=id_frame_random,
                resize=DIFFUSION_ARGS["image_size"],
            ).to(device)
            audio, audio_emb = get_audio_emb(
                str(audio_path), str(AUDIO_ENCODER_CHECKPOINT), device
            )

            samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **UNET_ARGS)
            save_video(output, samples, audio=audio, fps=25, audio_rate=16000)


@app.post("/generate_video/")
async def generate_video(
    id_frame_file: UploadFile = File(...),
    audio_file: UploadFile = File(...),
    gpu: bool = Query(True, description="Use GPU if available"),
    id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
    inference_steps: int = Query(100, ge=1, description="Number of inference diffusion steps"),
    output: str = Query(DEFAULT_OUTPUT, description="Path to save the output video"),
):
    """Generate a video from an uploaded identity frame and audio file.

    Args:
        id_frame_file: Upload passed to ``get_id_frame`` as the identity source.
        audio_file: Upload passed to ``get_audio_emb``.
        gpu: Use CUDA when requested and ``torch.cuda.is_available()``; else CPU.
        id_frame_random: Forwarded to ``get_id_frame`` as ``random``.
        inference_steps: Step count passed to ``Diffusion.space`` (must be >= 1).
        output: Server-side path where ``save_video`` writes the result.

    Returns:
        A ``FileResponse`` streaming the file written at *output*.
    """
    device = "cuda" if gpu and torch.cuda.is_available() else "cpu"

    # NOTE(review): `output` is a client-supplied server filesystem path, so any
    # caller can make this service write a file anywhere the process has write
    # access. Consider restricting it to a fixed output directory.

    # The model work is CPU/GPU bound; run it in a worker thread so the event
    # loop keeps serving other requests meanwhile.
    await run_in_threadpool(
        _run_pipeline,
        id_frame_file,
        audio_file,
        device,
        id_frame_random,
        inference_steps,
        output,
    )
    logger.info("Results saved at %s", output)
    return FileResponse(output)