acanivet committed
Commit 6ed901c • 1 parent: f22a587

cache + dl

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +32 -6
  3. model.py +14 -9
.gitignore ADDED
@@ -0,0 +1 @@
+*.pyc
app.py CHANGED
@@ -1,9 +1,18 @@
 import streamlit as st
 from model import generate
+import io
 import numpy as np
+from scipy.io.wavfile import write
+
+@st.cache_data
+def np_to_wav(waveform, sample_rate) -> bytes:
+    bytes_wav = bytes()
+    byte_io = io.BytesIO(bytes_wav)
+    write(byte_io, sample_rate, waveform.T)
+    return byte_io.read()
 
 if "result" not in st.session_state:
-    st.session_state["result"] = np.empty(16000*4)
+    st.session_state["result"] = None
 
 st.title("Sound Exploration")
 
@@ -12,7 +21,7 @@ col1, col2 = st.columns(2)
 with col1:
     instrument = st.selectbox(
         'Which intrument do you want ?',
-        ('🎸 Bass', '🎺 Brass', '🪈 Flute', '🪕 Guitar', '🎹 Keyboard', '🔨 Mallet', 'Organ', 'Reed', '🎻 String', 'Synth lead', '🎙️ Vocal')
+        ('🎸 Bass', '🎺 Brass', '🪈 Flute', '🪕 Guitar', '🎹 Keyboard', '🔨 Mallet', 'Organ', '🎷 Reed', '🎻 String', '⚡ Synth lead', '🎤 Vocal')
     )
 
 with col2:
@@ -22,11 +31,28 @@ with col2:
     )
 
 with st.expander("Magical parameters 🪄"):
-    p1 = st.slider('p1', 0., 1., step=0.001)
+    col1, col2 = st.columns(2)
+    with col1:
+        p1 = st.slider('p1', 0., 1., step=0.001, label_visibility='collapsed')
+        p2 = st.slider('p2', 0., 1., step=0.001, label_visibility='collapsed')
+        p3 = st.slider('p3', 0., 1., step=0.001, label_visibility='collapsed')
+    with col2:
+        p4 = st.slider('p4', 0., 1., step=0.001, label_visibility='collapsed')
+        p5 = st.slider('p5', 0., 1., step=0.001, label_visibility='collapsed')
+    use_params = st.toggle('Use magical parameters ?')
+    params = (p1, p2, p3, p4, p5) if use_params else None
 
 if st.button("Generate ✨", type="primary"):
-    st.session_state["result"] = generate([instrument, instrument_t])
+    st.session_state["result"] = generate([instrument, instrument_t], params)
 
-if st.session_state["result"].any():
-    st.audio(st.session_state["result"], sample_rate=16000)
+if st.session_state["result"] is not None:
+    col1, col2 = st.columns(2)
+    with col1:
+        st.audio(st.session_state["result"], sample_rate=16000)
+    with col2:
+        st.download_button(
+            label="Download ⬇️",
+            data=np_to_wav(st.session_state["result"], 16000),
+            file_name='result.wav',
+        )
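A note on the two app.py changes the commit title names: @st.cache_data memoizes the WAV conversion so reruns with an unchanged result array are not re-encoded, and st.download_button wants raw bytes, which is why np_to_wav renders the waveform into an in-memory buffer rather than a temp file. A minimal standalone sketch of that conversion, assuming a float32 mono waveform like the app's (the sine-wave input is invented for illustration; scipy's wavfile.write rewinds file-like targets when it finishes, so the read() below returns the whole payload):

import io
import numpy as np
from scipy.io.wavfile import write

def np_to_wav(waveform, sample_rate) -> bytes:
    byte_io = io.BytesIO()
    write(byte_io, sample_rate, waveform.T)  # float32 data is written as an IEEE-float WAV
    return byte_io.read()  # write() leaves file-like objects rewound to position 0

# Invented test signal: 4 seconds of A440 at the app's 16 kHz rate
sr = 16000
t = np.linspace(0., 4., sr * 4, endpoint=False)
wav_bytes = np_to_wav((0.5 * np.sin(2 * np.pi * 440. * t)).astype(np.float32), sr)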
model.py CHANGED
@@ -1,20 +1,25 @@
 from cvae import CVAE
 import torch
 from typing import Sequence
+import streamlit as st
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 instruments = ['bass_acoustic', 'brass_acoustic', 'flute_acoustic', 'guitar_acoustic', 'keyboard_acoustic', 'mallet_acoustic', 'organ_acoustic', 'reed_acoustic', 'string_acoustic', 'synth_lead_acoustic', 'vocal_acoustic', 'bass_synthetic', 'brass_synthetic', 'flute_synthetic', 'guitar_synthetic', 'keyboard_synthetic', 'mallet_synthetic', 'organ_synthetic', 'reed_synthetic', 'string_synthetic', 'synth_lead_synthetic', 'vocal_synthetic', 'bass_electronic', 'brass_electronic', 'flute_electronic', 'guitar_electronic', 'keyboard_electronic', 'mallet_electronic', 'organ_electronic', 'reed_electronic', 'string_electronic', 'synth_lead_electronic', 'vocal_electronic']
 
-model = CVAE.load_from_checkpoint(
-    'epoch=17-step=650718.ckpt',
-    io_channels=1,
-    io_features=16000*4,
-    latent_features=5,
-    channels=[32, 64, 128, 256, 512],
-    num_classes=len(instruments),
-    learning_rate=1e-5
-).to(device)
+@st.cache_resource
+def load_model(device):
+    return CVAE.load_from_checkpoint(
+        'epoch=17-step=650718.ckpt',
+        io_channels=1,
+        io_features=16000*4,
+        latent_features=5,
+        channels=[32, 64, 128, 256, 512],
+        num_classes=len(instruments),
+        learning_rate=1e-5
+    ).to(device)
+
+model = load_model(device)
 
 def format(text):
     text = text.split(' ')[-1]
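The model.py half of the commit wraps the checkpoint load in @st.cache_resource, the decorator meant for unserializable singletons: the function runs once per process, and every session and rerun afterwards receives the very same object (unlike @st.cache_data, which pickles and copies its return value). The arguments form the cache key, so a second device string would trigger a separate load. A small sketch of that behavior, with a tiny torch module as a hypothetical stand-in for the real CVAE checkpoint:

import streamlit as st
import torch

@st.cache_resource
def load_model(device):
    # Stand-in for CVAE.load_from_checkpoint(...); body runs only on a cache miss
    return torch.nn.Linear(4, 4).to(device)

m1 = load_model('cpu')  # first call: builds and caches the model
m2 = load_model('cpu')  # same args: cache hit, identical object (m1 is m2)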