generation / app.py
acanivet's picture
Update app.py
85c68ce
raw
history blame contribute delete
No virus
2.24 kB
import streamlit as st
from model import generate
import io
import numpy as np
from scipy.io.wavfile import write
# -----
# Utils
# -----
@st.cache_data
def np_to_wav(waveform: np.ndarray, sample_rate: int) -> bytes:
    """Encode a waveform array as an in-memory WAV file.

    Args:
        waveform: Audio samples. The array is transposed before writing,
            because scipy expects multi-channel data as (n_samples, n_channels).
            A 2-D input is therefore presumably laid out channels-first, as
            (n_channels, n_samples); confirm against ``model.generate``. A 1-D
            array is unaffected by the transpose.
        sample_rate: Sample rate in Hz written into the WAV header.

    Returns:
        The complete WAV file contents as bytes.
    """
    buffer = io.BytesIO()
    write(buffer, sample_rate, waveform.T)
    # getvalue() returns the entire buffer no matter where the stream position
    # ended up. This means we do not rely on scipy rewinding the file object
    # after writing, which ``read()`` silently depended on.
    return buffer.getvalue()
# ------------------
# App initialization
# ------------------
# Make sure the per-session "result" slot exists on the first run. None means
# no audio has been generated yet; the display section checks for this.
st.session_state.setdefault("result", None)
# ---
# App
# ---
st.title("Sound Exploration")

# Two side-by-side pickers: the instrument family, and how the sound is produced.
# NOTE(review): several emoji in the option labels look mis-encoded (UTF-8 bytes
# decoded with a legacy code page, e.g. the Flute/Guitar/Organ entries). They are
# kept byte-for-byte because the selected labels are passed verbatim to
# ``generate`` and may have to match strings inside model.py. Fix both together.
col1, col2 = st.columns(2)
with col1:
    instrument = st.selectbox(
        "Which instrument do you want ?",
        (
            "🎸 Bass",
            "🎺 Brass",
            "πŸͺˆ Flute",
            "πŸͺ• Guitar",
            "🎹 Keyboard",
            "πŸ”¨ Mallet",
            "πŸͺ— Organ",
            "🎷 Reed",
            "🎻 String",
            "⚑ Synth lead",
            "🎀 Vocal",
        ),
    )
with col2:
    instrument_t = st.selectbox(
        "Which type of instrument do you want ?",
        ("πŸ“― Acoustic", "πŸŽ™οΈ Electronic", "πŸŽ›οΈ Synthetic"),
    )
with st.expander("Magical parameters πŸͺ„"):
    left, right = st.columns(2)
    # Every knob shares one configuration: range [0, 1], starting at 0,
    # step 0.001, with its "pN" label hidden in the UI.
    slider_opts = dict(
        min_value=0.0,
        max_value=1.0,
        step=0.001,
        label_visibility="collapsed",
    )
    with left:
        p1, p2, p3 = [st.slider(name, **slider_opts) for name in ("p1", "p2", "p3")]
    with right:
        p4, p5 = [st.slider(name, **slider_opts) for name in ("p4", "p5")]
    use_params = st.toggle("Use magical parameters ?")

# The slider values are forwarded to generation only when the toggle is on.
if use_params:
    params = (p1, p2, p3, p4, p5)
else:
    params = None
# Rate used both to play and to save the generated audio. The value is carried
# over from the previously hard-coded 16000. It presumably must match the
# model's output rate; confirm against model.py.
SAMPLE_RATE = 16000

if st.button("Generate ✨", type="primary"):
    # Keep the output in session state so it persists across Streamlit reruns,
    # e.g. when another widget is changed after generating.
    st.session_state["result"] = generate([instrument, instrument_t], params)

result = st.session_state["result"]
if result is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.audio(result, sample_rate=SAMPLE_RATE)
    with col2:
        st.download_button(
            label="Download ⬇️",
            data=np_to_wav(result, SAMPLE_RATE),
            file_name="result.wav",
            # Declare the payload as WAV instead of the generic binary default.
            mime="audio/wav",
        )