"""Load a CVAE checkpoint and draw class-conditioned samples from it."""

from typing import Optional, Sequence

import streamlit as st
import torch
from lightning import LightningModule

from cvae import CVAE

# Size of the latent vector. Used both as the `latent_features` argument when
# the checkpoint is loaded and as the width of the noise passed to
# `model.sample`, so the two can never drift apart.
LATENT_FEATURES = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# Class labels. The position of a label in this list is the class index that
# `choice_to_tensor` hands to the model, so the order must not change.
instruments = [
    "bass_acoustic",
    "brass_acoustic",
    "flute_acoustic",
    "guitar_acoustic",
    "keyboard_acoustic",
    "mallet_acoustic",
    "organ_acoustic",
    "reed_acoustic",
    "string_acoustic",
    "synth_lead_acoustic",
    "vocal_acoustic",
    "bass_synthetic",
    "brass_synthetic",
    "flute_synthetic",
    "guitar_synthetic",
    "keyboard_synthetic",
    "mallet_synthetic",
    "organ_synthetic",
    "reed_synthetic",
    "string_synthetic",
    "synth_lead_synthetic",
    "vocal_synthetic",
    "bass_electronic",
    "brass_electronic",
    "flute_electronic",
    "guitar_electronic",
    "keyboard_electronic",
    "mallet_electronic",
    "organ_electronic",
    "reed_electronic",
    "string_electronic",
    "synth_lead_electronic",
    "vocal_electronic",
]


def format_instruments(text: str) -> str:
    """Turn a display label into a lower-case, underscore-joined key fragment.

    The text is split on single spaces, the first token is discarded
    (presumably an icon/emoji prefix on the UI label -- confirm against the
    caller), and the remaining tokens are lower-cased and joined with ``_``.

    Example: ``"X Synth Lead"`` -> ``"synth_lead"``.
    """
    # Tokens produced by split(" ") can never contain a space, so no further
    # space stripping is needed.
    stems = text.split(" ")[1:]
    return "_".join(stem.lower() for stem in stems)


def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
    """Map a sequence of display labels to the class index of their combination.

    Each label is normalised with :func:`format_instruments`; the fragments are
    joined with ``_`` and looked up in ``instruments``.

    Returns:
        A zero-dimensional integer tensor holding the index in ``instruments``.

    Raises:
        ValueError: If the combined key is not one of ``instruments``.
    """
    key = "_".join(format_instruments(label) for label in choice)
    try:
        index = instruments.index(key)
    except ValueError:
        # Same exception type as list.index, but with a message that tells the
        # caller which key was built and from what.
        raise ValueError(
            f"Unknown instrument combination {key!r} built from {list(choice)!r}"
        ) from None
    return torch.tensor(index)


@st.cache_resource
def load_model(device: str) -> LightningModule:
    """Load the CVAE checkpoint onto ``device`` and put it in eval mode.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded model
    instead of reloading the checkpoint.
    """
    net = CVAE.load_from_checkpoint(
        "epoch=77-step=2819778.ckpt",
        io_channels=1,
        io_features=16000 * 4,
        latent_features=LATENT_FEATURES,
        channels=[32, 64, 128, 256, 512],
        num_classes=len(instruments),
        learning_rate=1e-5,
    ).to(device)
    # A freshly constructed nn.Module is in training mode; Lightning's
    # checkpoint-loading docs call eval() before inference so that layers such
    # as dropout / batch-norm behave deterministically. Harmless if the model
    # has no such layers. eval() returns the module itself.
    return net.eval()


model = load_model(device)


def generate(choice: Sequence[str], params: Optional[Sequence[float]] = None):
    """Sample once from the model, conditioned on the chosen instrument class.

    Args:
        choice: Display labels that together identify one entry of
            ``instruments`` (see :func:`choice_to_tensor`).
        params: Optional latent vector to decode. When ``None`` or empty, a
            standard-normal vector of width ``LATENT_FEATURES`` is drawn.
            NOTE(review): a non-empty ``params`` is presumably expected to have
            ``LATENT_FEATURES`` entries -- this is not checked here.

    Returns:
        The first element (along the leading axis) of ``model.sample(...)``
        converted to a NumPy array on the CPU.

    Raises:
        ValueError: If ``choice`` does not map to a known instrument class.
    """
    # Explicit None/length test instead of truthiness: `if params` raises for
    # NumPy arrays and multi-element tensors.
    if params is not None and len(params) > 0:
        # Force the default float dtype so integer inputs do not produce an
        # int64 tensor that differs from the torch.randn branch below.
        noise = torch.tensor(params, dtype=torch.get_default_dtype()).unsqueeze(0)
    else:
        noise = torch.randn(1, LATENT_FEATURES)
    noise = noise.to(device)

    label = choice_to_tensor(choice).to(device)

    # Inference only: no autograd graph is needed, and this guarantees the
    # result can be converted with .numpy().
    with torch.no_grad():
        sample = model.sample(eps=noise, c=label)
    return sample.cpu().numpy()[0]