acanivet commited on
Commit
f22a587
1 Parent(s): 333a80f
Files changed (1) hide show
  1. model.py +5 -4
model.py CHANGED
@@ -1,7 +1,8 @@
1
  from cvae import CVAE
2
  import torch
3
  from typing import Sequence
4
- import re
 
5
 
6
  instruments = ['bass_acoustic', 'brass_acoustic', 'flute_acoustic', 'guitar_acoustic', 'keyboard_acoustic', 'mallet_acoustic', 'organ_acoustic', 'reed_acoustic', 'string_acoustic', 'synth_lead_acoustic', 'vocal_acoustic', 'bass_synthetic', 'brass_synthetic', 'flute_synthetic', 'guitar_synthetic', 'keyboard_synthetic', 'mallet_synthetic', 'organ_synthetic', 'reed_synthetic', 'string_synthetic', 'synth_lead_synthetic', 'vocal_synthetic', 'bass_electronic', 'brass_electronic', 'flute_electronic', 'guitar_electronic', 'keyboard_electronic', 'mallet_electronic', 'organ_electronic', 'reed_electronic', 'string_electronic', 'synth_lead_electronic', 'vocal_electronic']
7
 
@@ -13,7 +14,7 @@ model = CVAE.load_from_checkpoint(
13
  channels=[32, 64, 128, 256, 512],
14
  num_classes=len(instruments),
15
  learning_rate=1e-5
16
- )
17
 
18
  def format(text):
19
  text = text.split(' ')[-1]
@@ -24,5 +25,5 @@ def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
24
  return torch.tensor(instruments.index(choice))
25
 
26
  def generate(choice: Sequence[str], params: Sequence[int]=None):
27
- noise = torch.tensor(params).unsqueeze(0).to('cuda') if params else torch.randn(1, 5).to('cuda')
28
- return model.sample(eps=noise, c = choice_to_tensor(choice).to('cuda')).cpu().numpy()[0]
 
1
  from cvae import CVAE
2
  import torch
3
  from typing import Sequence
4
+
5
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
 
7
  instruments = ['bass_acoustic', 'brass_acoustic', 'flute_acoustic', 'guitar_acoustic', 'keyboard_acoustic', 'mallet_acoustic', 'organ_acoustic', 'reed_acoustic', 'string_acoustic', 'synth_lead_acoustic', 'vocal_acoustic', 'bass_synthetic', 'brass_synthetic', 'flute_synthetic', 'guitar_synthetic', 'keyboard_synthetic', 'mallet_synthetic', 'organ_synthetic', 'reed_synthetic', 'string_synthetic', 'synth_lead_synthetic', 'vocal_synthetic', 'bass_electronic', 'brass_electronic', 'flute_electronic', 'guitar_electronic', 'keyboard_electronic', 'mallet_electronic', 'organ_electronic', 'reed_electronic', 'string_electronic', 'synth_lead_electronic', 'vocal_electronic']
8
 
 
14
  channels=[32, 64, 128, 256, 512],
15
  num_classes=len(instruments),
16
  learning_rate=1e-5
17
+ ).to(device)
18
 
19
  def format(text):
20
  text = text.split(' ')[-1]
 
25
  return torch.tensor(instruments.index(choice))
26
 
27
  def generate(choice: Sequence[str], params: Sequence[int]=None):
28
+ noise = torch.tensor(params).unsqueeze(0).to(device) if params else torch.randn(1, 5).to('cuda')
29
+ return model.sample(eps=noise, c = choice_to_tensor(choice).to(device)).cpu().numpy()[0]