Simonlob committed on
Commit
f645dbd
1 Parent(s): 98ed597

Upload app.py

Files changed (1)
  1. app.py +309 -0
app.py ADDED
@@ -0,0 +1,309 @@
+ import tempfile
+ from argparse import Namespace
+ from pathlib import Path
+
+ import gradio as gr
+ import soundfile as sf
+ import torch
+
+ from matcha.cli import (
+     MATCHA_URLS,
+     VOCODER_URLS,
+     assert_model_downloaded,
+     get_device,
+     load_matcha,
+     load_vocoder,
+     process_text,
+     to_waveform,
+ )
+ from matcha.utils.utils import get_user_data_dir, plot_tensor
+
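+ # Model and vocoder checkpoints are cached in Matcha-TTS's per-user data
+ # directory (as returned by get_user_data_dir).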
+ LOCATION = Path(get_user_data_dir())
+
+ args = Namespace(
+     cpu=True,
+     model="akyl_ai",
+     vocoder="hifigan_T2_v1",
+     spk=0,
+ )
+
+ CURRENTLY_LOADED_MODEL = args.model
+
+
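+ # Map a model or vocoder name to its expected checkpoint path on disk.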
+ def MATCHA_TTS_LOC(x):
+     return LOCATION / f"{x}.ckpt"
+
+
+ def VOCODER_LOC(x):
+     return LOCATION / f"{x}"
+
+
+ LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
+ RADIO_OPTIONS = {
+     "Multi Speaker (VCTK)": {
+         "model": "matcha_vctk",
+         "vocoder": "hifigan_univ_v1",
+     },
+     "Single Speaker (Akyl-AI)": {
+         "model": "akyl_ai",
+         "vocoder": "hifigan_T2_v1",
+     },
+ }
+
+ # Ensure all the required models are downloaded
+ assert_model_downloaded(MATCHA_TTS_LOC("akyl_ai"), MATCHA_URLS["akyl_ai"])
+ assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
+ assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
+ assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
+
+ device = get_device(args)
+
+ # Load default model
+ model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
+ vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
+
+
+ def load_model(model_name, vocoder_name):
+     model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
+     vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
+     return model, vocoder, denoiser
+
+
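+ # Radio-button callback: swap in the selected model/vocoder (if it is not
+ # already loaded) and adjust the speaker slider, example rows, and default
+ # length scale to match.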
+ def load_model_ui(model_type, textbox):
+     model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"]
+
+     global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+     if CURRENTLY_LOADED_MODEL != model_name:
+         model, vocoder, denoiser = load_model(model_name, vocoder_name)
+         CURRENTLY_LOADED_MODEL = model_name
+
+     if model_name == "akyl_ai":
+         spk_slider = gr.update(visible=False, value=-1)
+         single_speaker_examples = gr.update(visible=True)
+         multi_speaker_examples = gr.update(visible=False)
+         length_scale = gr.update(value=0.95)
+     else:
+         spk_slider = gr.update(visible=True, value=0)
+         single_speaker_examples = gr.update(visible=False)
+         multi_speaker_examples = gr.update(visible=True)
+         length_scale = gr.update(value=0.85)
+
+     return (
+         textbox,
+         gr.update(interactive=True),
+         spk_slider,
+         single_speaker_examples,
+         multi_speaker_examples,
+         length_scale,
+     )
+
+
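+ # Phonemise the input text. x_phones interleaves phones with blank tokens,
+ # so [1::2] keeps only the phones for display.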
+ @torch.inference_mode()
+ def process_text_gradio(text):
+     output = process_text(1, text, device)
+     return output["x_phones"][1::2], output["x"], output["x_lengths"]
+
+
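+ # Decode a mel spectrogram from the encoded text, vocode it with HiFi-GAN,
+ # and write the audio to a temporary WAV at 22.05 kHz for the Gradio player.
+ # spk < 0 means "single-speaker model": no speaker id is passed.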
+ @torch.inference_mode()
+ def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
+     spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
+     output = model.synthesise(
+         text,
+         text_length,
+         n_timesteps=n_timesteps,
+         temperature=temperature,
+         spks=spk,
+         length_scale=length_scale,
+     )
+     output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
+         sf.write(fp.name, output["waveform"], 22050, "PCM_24")
+
+     return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
+
+
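+ # Wrappers used when Gradio pre-caches the example outputs at startup: each
+ # one first ensures the matching checkpoint is the one currently loaded.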
+ def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
+     global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+     if CURRENTLY_LOADED_MODEL != "matcha_vctk":
+         model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
+         CURRENTLY_LOADED_MODEL = "matcha_vctk"
+
+     phones, text, text_lengths = process_text_gradio(text)
+     audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
+     return phones, audio, mel_spectrogram
+
+
+ def akyl_ai_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
+     global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL  # pylint: disable=global-statement
+     if CURRENTLY_LOADED_MODEL != "akyl_ai":
+         model, vocoder, denoiser = load_model("akyl_ai", "hifigan_T2_v1")
+         CURRENTLY_LOADED_MODEL = "akyl_ai"
+
+     phones, text, text_lengths = process_text_gradio(text)
+     audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
+     return phones, audio, mel_spectrogram
+
+
+ def main():
+     description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
+ ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
+ We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
+
+ * Is probabilistic
+ * Has a compact memory footprint
+ * Sounds highly natural
+ * Is very fast to synthesise from
+
+ Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint](https://arxiv.org/abs/2309.03199) for more details.
+ Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models.
+
+ Cached examples are available at the bottom of the page.
+ """
+
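+     # Build the UI. The gr.State values carry the encoded text between the two
+     # chained callbacks wired to the Synthesise button below.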
+     with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
+         processed_text = gr.State(value=None)
+         processed_text_len = gr.State(value=None)
+
+         with gr.Box():
+             with gr.Row():
+                 gr.Markdown(description, scale=3)
+                 with gr.Column():
+                     gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False)
+                     html = '<br><iframe width="560" height="315" src="https://www.youtube.com/embed/xmvJkz3bqw0?si=jN7ILyDsbPwJCGoa" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>'
+                     gr.HTML(html)
+
+         with gr.Box():
+             radio_options = list(RADIO_OPTIONS.keys())
+             model_type = gr.Radio(
+                 radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
+             )
+
+         with gr.Row():
+             gr.Markdown("# Text Input")
+         with gr.Row():
+             text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
+             spk_slider = gr.Slider(
+                 minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
+             )
+
+         with gr.Row():
+             gr.Markdown("### Hyperparameters")
+         with gr.Row():
+             n_timesteps = gr.Slider(
+                 label="Number of ODE steps",
+                 minimum=1,
+                 maximum=100,
+                 step=1,
+                 value=10,
+                 interactive=True,
+             )
+             length_scale = gr.Slider(
+                 label="Length scale (Speaking rate)",
+                 minimum=0.5,
+                 maximum=1.5,
+                 step=0.05,
+                 value=1.0,
+                 interactive=True,
+             )
+             mel_temp = gr.Slider(
+                 label="Sampling temperature",
+                 minimum=0.00,
+                 maximum=2.001,
+                 step=0.16675,
+                 value=0.667,
+                 interactive=True,
+             )
+
+         synth_btn = gr.Button("Synthesise")
+
+         with gr.Box():
+             with gr.Row():
+                 gr.Markdown("### Phonetised text")
+                 phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text")
+
+         with gr.Box():
+             with gr.Row():
+                 mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
+
+             audio = gr.Audio(interactive=False, label="Audio")
+
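+         # Single-speaker (Akyl-AI) examples, hidden while the multi-speaker model
+         # is selected; load_model_ui toggles the two rows. The Kyrgyz prompts read
+         # roughly: "Hello everyone, my name is Akylai. I am very glad to see you
+         # all here at the Innovation Center." / "I am grateful to those who
+         # supported and chose me. We will work for the village, build roads, lay
+         # asphalt," he said.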
+         with gr.Row(visible=False) as example_row_akyl_ai:
+             examples = gr.Examples(  # pylint: disable=unused-variable
+                 examples=[
+                     [
+                         "Баарыңарга салам, менин атым Акылай. Мен бардыгын бул жерде Инновация борборунда көргөнүмө абдан кубанычтамын.",
+                         50,
+                         0.677,
+                         0.95,
+                     ],
+                     [
+                         "Мага колдоо көрсөтүп, мени тандагандарга ыраазымын. Айыл үчүн иштейбиз, жол курабыз, асфальт төшөйбүз”, — деген ал.",
+                         2,
+                         0.677,
+                         0.95,
+                     ],
+                 ],
+                 fn=akyl_ai_example_cacher,
+                 inputs=[text, n_timesteps, mel_temp, length_scale],
+                 outputs=[phonetised_text, audio, mel_spectrogram],
+                 cache_examples=True,
+                 label="Single Speaker Examples",
+             )
+
+         with gr.Row() as example_row_multispeaker:
+             multi_speaker_examples = gr.Examples(  # pylint: disable=unused-variable
+                 examples=[
+                     [
+                         "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
+                         10,
+                         0.677,
+                         0.85,
+                         0,
+                     ],
+                     [
+                         "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
+                         10,
+                         0.677,
+                         0.85,
+                         16,
+                     ],
+                 ],
+                 fn=multispeaker_example_cacher,
+                 inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
+                 outputs=[phonetised_text, audio, mel_spectrogram],
+                 cache_examples=True,
+                 label="Multi Speaker Examples",
+             )
+
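+         # Disable the Synthesise button while a model swap is in progress; it is
+         # re-enabled by load_model_ui once the new checkpoint is ready.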
+         model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
+             load_model_ui,
+             inputs=[model_type, text],
+             outputs=[text, synth_btn, spk_slider, example_row_akyl_ai, example_row_multispeaker, length_scale],
+         )
+
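+         # Two-step synthesis chain: phonemise the text first, then decode and
+         # vocode it with the current slider settings.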
+         synth_btn.click(
+             fn=process_text_gradio,
+             inputs=[text],
+             outputs=[phonetised_text, processed_text, processed_text_len],
+             api_name="matcha_tts",
+             queue=True,
+         ).then(
+             fn=synthesise_mel,
+             inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
+             outputs=[audio, mel_spectrogram],
+         )
+
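+     # Untested sketch: the first step of the chain above is exposed to API
+     # clients under api_name="matcha_tts" (the .then() synthesis step is not
+     # part of that endpoint). A gradio_client call might look like:
+     #
+     #     from gradio_client import Client
+     #     client = Client("http://127.0.0.1:7860/")
+     #     phones = client.predict("Саламатсызбы!", api_name="/matcha_tts")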
+     demo.queue().launch(share=True)
+
+
+ if __name__ == "__main__":
+     main()