kevinwang676 commited on
Commit
dd6f15d
1 Parent(s): bf21026

Upload 20 files

Browse files
speaker_encoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
speaker_encoder/audio.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from speaker_encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ import numpy as np
6
+ import webrtcvad
7
+ import librosa
8
+ import struct
9
+
10
+ int16_max = (2 ** 15) - 1
11
+
12
+
13
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14
+ source_sr: Optional[int] = None):
15
+ """
16
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18
+
19
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20
+ just .wav), either the waveform as a numpy array of floats.
21
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
23
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24
+ this argument will be ignored.
25
+ """
26
+ # Load the wav from disk if needed
27
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28
+ wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29
+ else:
30
+ wav = fpath_or_wav
31
+
32
+ # Resample the wav if needed
33
+ if source_sr is not None and source_sr != sampling_rate:
34
+ wav = librosa.resample(wav, source_sr, sampling_rate)
35
+
36
+ # Apply the preprocessing: normalize volume and shorten long silences
37
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38
+ wav = trim_long_silences(wav)
39
+
40
+ return wav
41
+
42
+
43
+ def wav_to_mel_spectrogram(wav):
44
+ """
45
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46
+ Note: this not a log-mel spectrogram.
47
+ """
48
+ frames = librosa.feature.melspectrogram(
49
+ y=wav,
50
+ sr=sampling_rate,
51
+ n_fft=int(sampling_rate * mel_window_length / 1000),
52
+ hop_length=int(sampling_rate * mel_window_step / 1000),
53
+ n_mels=mel_n_channels
54
+ )
55
+ return frames.astype(np.float32).T
56
+
57
+
58
+ def trim_long_silences(wav):
59
+ """
60
+ Ensures that segments without voice in the waveform remain no longer than a
61
+ threshold determined by the VAD parameters in params.py.
62
+
63
+ :param wav: the raw waveform as a numpy array of floats
64
+ :return: the same waveform with silences trimmed away (length <= original wav length)
65
+ """
66
+ # Compute the voice detection window size
67
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
68
+
69
+ # Trim the end of the audio to have a multiple of the window size
70
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71
+
72
+ # Convert the float waveform to 16-bit mono PCM
73
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74
+
75
+ # Perform voice activation detection
76
+ voice_flags = []
77
+ vad = webrtcvad.Vad(mode=3)
78
+ for window_start in range(0, len(wav), samples_per_window):
79
+ window_end = window_start + samples_per_window
80
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81
+ sample_rate=sampling_rate))
82
+ voice_flags = np.array(voice_flags)
83
+
84
+ # Smooth the voice detection with a moving average
85
+ def moving_average(array, width):
86
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87
+ ret = np.cumsum(array_padded, dtype=float)
88
+ ret[width:] = ret[width:] - ret[:-width]
89
+ return ret[width - 1:] / width
90
+
91
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
92
+ audio_mask = np.round(audio_mask).astype(np.bool)
93
+
94
+ # Dilate the voiced regions
95
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96
+ audio_mask = np.repeat(audio_mask, samples_per_window)
97
+
98
+ return wav[audio_mask == True]
99
+
100
+
101
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102
+ if increase_only and decrease_only:
103
+ raise ValueError("Both increase only and decrease only are set")
104
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106
+ return wav
107
+ return wav * (10 ** (dBFS_change / 20))
speaker_encoder/ckpt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
speaker_encoder/compute_embed.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder import inference as encoder
2
+ from multiprocessing.pool import Pool
3
+ from functools import partial
4
+ from pathlib import Path
5
+ # from utils import logmmse
6
+ # from tqdm import tqdm
7
+ # import numpy as np
8
+ # import librosa
9
+
10
+
11
+ def embed_utterance(fpaths, encoder_model_fpath):
12
+ if not encoder.is_loaded():
13
+ encoder.load_model(encoder_model_fpath)
14
+
15
+ # Compute the speaker embedding of the utterance
16
+ wav_fpath, embed_fpath = fpaths
17
+ wav = np.load(wav_fpath)
18
+ wav = encoder.preprocess_wav(wav)
19
+ embed = encoder.embed_utterance(wav)
20
+ np.save(embed_fpath, embed, allow_pickle=False)
21
+
22
+
23
+ def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24
+
25
+ wav_dir = outdir_root.joinpath("audio")
26
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
27
+ assert wav_dir.exists() and metadata_fpath.exists()
28
+ embed_dir = synthesizer_root.joinpath("embeds")
29
+ embed_dir.mkdir(exist_ok=True)
30
+
31
+ # Gather the input wave filepath and the target output embed filepath
32
+ with metadata_fpath.open("r") as metadata_file:
33
+ metadata = [line.split("|") for line in metadata_file]
34
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35
+
36
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37
+ # Embed the utterances in separate threads
38
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39
+ job = Pool(n_processes).imap(func, fpaths)
40
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
speaker_encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
speaker_encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
speaker_encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
speaker_encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
speaker_encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+
5
+ class SpeakerBatch:
6
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7
+ self.speakers = speakers
8
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9
+
10
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
speaker_encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+ from speaker_encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
speaker_encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
speaker_encoder/hparams.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Mel-filterbank
2
+ mel_window_length = 25 # In milliseconds
3
+ mel_window_step = 10 # In milliseconds
4
+ mel_n_channels = 40
5
+
6
+
7
+ ## Audio
8
+ sampling_rate = 16000
9
+ # Number of spectrogram frames in a partial utterance
10
+ partials_n_frames = 160 # 1600 ms
11
+
12
+
13
+ ## Voice Activation Detection
14
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15
+ # This sets the granularity of the VAD. Should not need to be changed.
16
+ vad_window_length = 30 # In milliseconds
17
+ # Number of frames to average together when performing the moving average smoothing.
18
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
19
+ vad_moving_average_width = 8
20
+ # Maximum number of consecutive silent frames a segment can have.
21
+ vad_max_silence_length = 6
22
+
23
+
24
+ ## Audio volume normalization
25
+ audio_norm_target_dBFS = -30
26
+
27
+
28
+ ## Model parameters
29
+ model_hidden_size = 256
30
+ model_embedding_size = 256
31
+ model_num_layers = 3
speaker_encoder/inference.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_data import *
2
+ from speaker_encoder.model import SpeakerEncoder
3
+ from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ _model = None # type: SpeakerEncoder
12
+ _device = None # type: torch.device
13
+
14
+
15
+ def load_model(weights_fpath: Path, device=None):
16
+ """
17
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
18
+ first call to embed_frames() with the default weights file.
19
+
20
+ :param weights_fpath: the path to saved model weights.
21
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
23
+ If None, will default to your GPU if it"s available, otherwise your CPU.
24
+ """
25
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
26
+ # was saved on. Worth investigating.
27
+ global _model, _device
28
+ if device is None:
29
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ elif isinstance(device, str):
31
+ _device = torch.device(device)
32
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath)
34
+ _model.load_state_dict(checkpoint["model_state"])
35
+ _model.eval()
36
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37
+
38
+
39
+ def is_loaded():
40
+ return _model is not None
41
+
42
+
43
+ def embed_frames_batch(frames_batch):
44
+ """
45
+ Computes embeddings for a batch of mel spectrogram.
46
+
47
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
48
+ (batch_size, n_frames, n_channels)
49
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50
+ """
51
+ if _model is None:
52
+ raise Exception("Model was not loaded. Call load_model() before inference.")
53
+
54
+ frames = torch.from_numpy(frames_batch).to(_device)
55
+ embed = _model.forward(frames).detach().cpu().numpy()
56
+ return embed
57
+
58
+
59
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60
+ min_pad_coverage=0.75, overlap=0.5):
61
+ """
62
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
64
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
66
+ defined in params_data.py.
67
+
68
+ The returned ranges may be indexing further than the length of the waveform. It is
69
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70
+
71
+ :param n_samples: the number of samples in the waveform
72
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73
+ utterance
74
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
76
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
79
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80
+ utterances are entirely disjoint.
81
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
83
+ utterances.
84
+ """
85
+ assert 0 <= overlap < 1
86
+ assert 0 < min_pad_coverage <= 1
87
+
88
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91
+
92
+ # Compute the slices
93
+ wav_slices, mel_slices = [], []
94
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95
+ for i in range(0, steps, frame_step):
96
+ mel_range = np.array([i, i + partial_utterance_n_frames])
97
+ wav_range = mel_range * samples_per_frame
98
+ mel_slices.append(slice(*mel_range))
99
+ wav_slices.append(slice(*wav_range))
100
+
101
+ # Evaluate whether extra padding is warranted or not
102
+ last_wav_range = wav_slices[-1]
103
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
105
+ mel_slices = mel_slices[:-1]
106
+ wav_slices = wav_slices[:-1]
107
+
108
+ return wav_slices, mel_slices
109
+
110
+
111
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ # TODO: handle multiple wavs to benefit from batching on GPU
116
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117
+ :param using_partials: if True, then the utterance is split in partial utterances of
118
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
119
+ normalized average. If False, the utterance is instead computed from feeding the entire
120
+ spectogram to the network.
121
+ :param return_partials: if True, the partial embeddings will also be returned along with the
122
+ wav slices that correspond to the partial embeddings.
123
+ :param kwargs: additional arguments to compute_partial_splits()
124
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
126
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
128
+ instead.
129
+ """
130
+ # Process the entire utterance if not using partials
131
+ if not using_partials:
132
+ frames = audio.wav_to_mel_spectrogram(wav)
133
+ embed = embed_frames_batch(frames[None, ...])[0]
134
+ if return_partials:
135
+ return embed, None, None
136
+ return embed
137
+
138
+ # Compute where to split the utterance into partials and pad if necessary
139
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140
+ max_wave_length = wave_slices[-1].stop
141
+ if max_wave_length >= len(wav):
142
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143
+
144
+ # Split the utterance into partials
145
+ frames = audio.wav_to_mel_spectrogram(wav)
146
+ frames_batch = np.array([frames[s] for s in mel_slices])
147
+ partial_embeds = embed_frames_batch(frames_batch)
148
+
149
+ # Compute the utterance embedding from the partial embeddings
150
+ raw_embed = np.mean(partial_embeds, axis=0)
151
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
152
+
153
+ if return_partials:
154
+ return embed, partial_embeds, wave_slices
155
+ return embed
156
+
157
+
158
+ def embed_speaker(wavs, **kwargs):
159
+ raise NotImplemented()
160
+
161
+
162
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ cbar.set_clim(*color_range)
175
+
176
+ ax.set_xticks([]), ax.set_yticks([])
177
+ ax.set_title(title)
speaker_encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_model import *
2
+ from speaker_encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19
+ hidden_size=model_hidden_size, # 256
20
+ num_layers=model_num_layers, # 3
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
speaker_encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
speaker_encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
speaker_encoder/preprocess.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from speaker_encoder.params_data import *
3
+ from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from speaker_encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if not param_name in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+
121
+ # Function to preprocess utterances for one speaker
122
+ def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
123
+ # Give a name to the speaker that includes its dataset
124
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
125
+
126
+ # Create an output directory with that name, as well as a txt file containing a
127
+ # reference to each source file.
128
+ speaker_out_dir = out_dir.joinpath(speaker_name)
129
+ speaker_out_dir.mkdir(exist_ok=True)
130
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
131
+
132
+ # There's a possibility that the preprocessing was interrupted earlier, check if
133
+ # there already is a sources file.
134
+ # if sources_fpath.exists():
135
+ # try:
136
+ # with sources_fpath.open("r") as sources_file:
137
+ # existing_fnames = {line.split(",")[0] for line in sources_file}
138
+ # except:
139
+ # existing_fnames = {}
140
+ # else:
141
+ # existing_fnames = {}
142
+ existing_fnames = {}
143
+ # Gather all audio files for that speaker recursively
144
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
145
+
146
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
147
+ # Check if the target output file already exists
148
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
149
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
150
+ if skip_existing and out_fname in existing_fnames:
151
+ continue
152
+
153
+ # Load and preprocess the waveform
154
+ wav = audio.preprocess_wav(in_fpath)
155
+ if len(wav) == 0:
156
+ continue
157
+
158
+ # Create the mel spectrogram, discard those that are too short
159
+ frames = audio.wav_to_mel_spectrogram(wav)
160
+ if len(frames) < partials_n_frames:
161
+ continue
162
+
163
+ out_fpath = speaker_out_dir.joinpath(out_fname)
164
+ np.save(out_fpath, frames)
165
+ # logger.add_sample(duration=len(wav) / sampling_rate)
166
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
167
+
168
+ sources_file.close()
169
+ return len(wav)
170
+
171
+ def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
172
+ skip_existing, logger):
173
+ # from multiprocessing import Pool, cpu_count
174
+ from pathos.multiprocessing import ProcessingPool as Pool
175
+ # Function to preprocess utterances for one speaker
176
+ def __preprocess_speaker(speaker_dir: Path):
177
+ # Give a name to the speaker that includes its dataset
178
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
179
+
180
+ # Create an output directory with that name, as well as a txt file containing a
181
+ # reference to each source file.
182
+ speaker_out_dir = out_dir.joinpath(speaker_name)
183
+ speaker_out_dir.mkdir(exist_ok=True)
184
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
185
+
186
+ existing_fnames = {}
187
+ # Gather all audio files for that speaker recursively
188
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
189
+ wav_lens = []
190
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
191
+ # Check if the target output file already exists
192
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
193
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
194
+ if skip_existing and out_fname in existing_fnames:
195
+ continue
196
+
197
+ # Load and preprocess the waveform
198
+ wav = audio.preprocess_wav(in_fpath)
199
+ if len(wav) == 0:
200
+ continue
201
+
202
+ # Create the mel spectrogram, discard those that are too short
203
+ frames = audio.wav_to_mel_spectrogram(wav)
204
+ if len(frames) < partials_n_frames:
205
+ continue
206
+
207
+ out_fpath = speaker_out_dir.joinpath(out_fname)
208
+ np.save(out_fpath, frames)
209
+ # logger.add_sample(duration=len(wav) / sampling_rate)
210
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
211
+ wav_lens.append(len(wav))
212
+ sources_file.close()
213
+ return wav_lens
214
+
215
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
216
+ # Process the utterances for each speaker
217
+ # with ThreadPool(8) as pool:
218
+ # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
219
+ # unit="speakers"))
220
+ pool = Pool(processes=20)
221
+ for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
222
+ for wav_len in wav_lens:
223
+ logger.add_sample(duration=wav_len / sampling_rate)
224
+ print(f'{i}/{len(speaker_dirs)} \r')
225
+
226
+ logger.finalize()
227
+ print("Done preprocessing %s.\n" % dataset_name)
228
+
229
+
230
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
231
+ for dataset_name in librispeech_datasets["train"]["other"]:
232
+ # Initialize the preprocessing
233
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
234
+ if not dataset_root:
235
+ return
236
+
237
+ # Preprocess all speakers
238
+ speaker_dirs = list(dataset_root.glob("*"))
239
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
240
+ skip_existing, logger)
241
+
242
+
243
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
244
+ # Initialize the preprocessing
245
+ dataset_name = "VoxCeleb1"
246
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
247
+ if not dataset_root:
248
+ return
249
+
250
+ # Get the contents of the meta file
251
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
252
+ metadata = [line.split("\t") for line in metafile][1:]
253
+
254
+ # Select the ID and the nationality, filter out non-anglophone speakers
255
+ nationalities = {line[0]: line[3] for line in metadata}
256
+ # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
257
+ # nationality.lower() in anglophone_nationalites]
258
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
259
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
260
+ (len(keep_speaker_ids), len(nationalities)))
261
+
262
+ # Get the speaker directories for anglophone speakers only
263
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
264
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
265
+ speaker_dir.name in keep_speaker_ids]
266
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
267
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
268
+
269
+ # Preprocess all speakers
270
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
271
+ skip_existing, logger)
272
+
273
+
274
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
275
+ # Initialize the preprocessing
276
+ dataset_name = "VoxCeleb2"
277
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
278
+ if not dataset_root:
279
+ return
280
+
281
+ # Get the speaker directories
282
+ # Preprocess all speakers
283
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
284
+ _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
285
+ skip_existing, logger)
speaker_encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.visualizations import Visualizations
2
+ from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from speaker_encoder.params_model import *
4
+ from speaker_encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # FIXME
11
+ return
12
+ # For correct profiling (cuda operations are async)
13
+ if device.type == "cuda":
14
+ torch.cuda.synchronize(device)
15
+
16
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
17
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
18
+ no_visdom: bool):
19
+ # Create a dataset and a dataloader
20
+ dataset = SpeakerVerificationDataset(clean_data_root)
21
+ loader = SpeakerVerificationDataLoader(
22
+ dataset,
23
+ speakers_per_batch, # 64
24
+ utterances_per_speaker, # 10
25
+ num_workers=8,
26
+ )
27
+
28
+ # Setup the device on which to run the forward pass and the loss. These can be different,
29
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
30
+ # hyperparameters) faster on the CPU.
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ # FIXME: currently, the gradient is None if loss_device is cuda
33
+ loss_device = torch.device("cpu")
34
+
35
+ # Create the model and the optimizer
36
+ model = SpeakerEncoder(device, loss_device)
37
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
38
+ init_step = 1
39
+
40
+ # Configure file path for the model
41
+ state_fpath = models_dir.joinpath(run_id + ".pt")
42
+ backup_dir = models_dir.joinpath(run_id + "_backups")
43
+
44
+ # Load any existing model
45
+ if not force_restart:
46
+ if state_fpath.exists():
47
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
48
+ checkpoint = torch.load(state_fpath)
49
+ init_step = checkpoint["step"]
50
+ model.load_state_dict(checkpoint["model_state"])
51
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
52
+ optimizer.param_groups[0]["lr"] = learning_rate_init
53
+ else:
54
+ print("No model \"%s\" found, starting training from scratch." % run_id)
55
+ else:
56
+ print("Starting the training from scratch.")
57
+ model.train()
58
+
59
+ # Initialize the visualization environment
60
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
61
+ vis.log_dataset(dataset)
62
+ vis.log_params()
63
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
64
+ vis.log_implementation({"Device": device_name})
65
+
66
+ # Training loop
67
+ profiler = Profiler(summarize_every=10, disabled=False)
68
+ for step, speaker_batch in enumerate(loader, init_step):
69
+ profiler.tick("Blocking, waiting for batch (threaded)")
70
+
71
+ # Forward pass
72
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
73
+ sync(device)
74
+ profiler.tick("Data to %s" % device)
75
+ embeds = model(inputs)
76
+ sync(device)
77
+ profiler.tick("Forward pass")
78
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
79
+ loss, eer = model.loss(embeds_loss)
80
+ sync(loss_device)
81
+ profiler.tick("Loss")
82
+
83
+ # Backward pass
84
+ model.zero_grad()
85
+ loss.backward()
86
+ profiler.tick("Backward pass")
87
+ model.do_gradient_ops()
88
+ optimizer.step()
89
+ profiler.tick("Parameter update")
90
+
91
+ # Update visualizations
92
+ # learning_rate = optimizer.param_groups[0]["lr"]
93
+ vis.update(loss.item(), eer, step)
94
+
95
+ # Draw projections and save them to the backup folder
96
+ if umap_every != 0 and step % umap_every == 0:
97
+ print("Drawing and saving projections (step %d)" % step)
98
+ backup_dir.mkdir(exist_ok=True)
99
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
100
+ embeds = embeds.detach().cpu().numpy()
101
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
102
+ vis.save()
103
+
104
+ # Overwrite the latest version of the model
105
+ if save_every != 0 and step % save_every == 0:
106
+ print("Saving the model (step %d)" % step)
107
+ torch.save({
108
+ "step": step + 1,
109
+ "model_state": model.state_dict(),
110
+ "optimizer_state": optimizer.state_dict(),
111
+ }, state_fpath)
112
+
113
+ # Make a backup
114
+ if backup_every != 0 and step % backup_every == 0:
115
+ print("Making a backup (step %d)" % step)
116
+ backup_dir.mkdir(exist_ok=True)
117
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
118
+ torch.save({
119
+ "step": step + 1,
120
+ "model_state": model.state_dict(),
121
+ "optimizer_state": optimizer.state_dict(),
122
+ }, backup_fpath)
123
+
124
+ profiler.tick("Extras (visualizations, saving)")
125
+
speaker_encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=np.float) / 255
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from speaker_encoder import params_data
69
+ from speaker_encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
speaker_encoder/voice_encoder.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.hparams import *
2
+ from speaker_encoder import audio
3
+ from pathlib import Path
4
+ from typing import Union, List
5
+ from torch import nn
6
+ from time import perf_counter as timer
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ class SpeakerEncoder(nn.Module):
12
+ def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
13
+ """
14
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
15
+ If None, defaults to cuda if it is available on your machine, otherwise the model will
16
+ run on cpu. Outputs are always returned on the cpu, as numpy arrays.
17
+ """
18
+ super().__init__()
19
+
20
+ # Define the network
21
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
22
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
23
+ self.relu = nn.ReLU()
24
+
25
+ # Get the target device
26
+ if device is None:
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ elif isinstance(device, str):
29
+ device = torch.device(device)
30
+ self.device = device
31
+
32
+ # Load the pretrained model'speaker weights
33
+ # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
34
+ # if not weights_fpath.exists():
35
+ # raise Exception("Couldn't find the voice encoder pretrained model at %s." %
36
+ # weights_fpath)
37
+
38
+ start = timer()
39
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
40
+
41
+ self.load_state_dict(checkpoint["model_state"], strict=False)
42
+ self.to(device)
43
+
44
+ if verbose:
45
+ print("Loaded the voice encoder model on %s in %.2f seconds." %
46
+ (device.type, timer() - start))
47
+
48
+ def forward(self, mels: torch.FloatTensor):
49
+ """
50
+ Computes the embeddings of a batch of utterance spectrograms.
51
+ :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
52
+ (batch_size, n_frames, n_channels)
53
+ :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
54
+ Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
55
+ """
56
+ # Pass the input through the LSTM layers and retrieve the final hidden state of the last
57
+ # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
58
+ _, (hidden, _) = self.lstm(mels)
59
+ embeds_raw = self.relu(self.linear(hidden[-1]))
60
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
61
+
62
+ @staticmethod
63
+ def compute_partial_slices(n_samples: int, rate, min_coverage):
64
+ """
65
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to
66
+ obtain partial utterances of <partials_n_frames> each. Both the waveform and the
67
+ mel spectrogram slices are returned, so as to make each partial utterance waveform
68
+ correspond to its spectrogram.
69
+
70
+ The returned ranges may be indexing further than the length of the waveform. It is
71
+ recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
72
+
73
+ :param n_samples: the number of samples in the waveform
74
+ :param rate: how many partial utterances should occur per second. Partial utterances must
75
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
76
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
77
+ the minimum rate is thus 0.625.
78
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
79
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
80
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
81
+ it will be discarded. If there aren't enough frames for one partial utterance,
82
+ this parameter is ignored so that the function always returns at least one slice.
83
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
84
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
85
+ utterances.
86
+ """
87
+ assert 0 < min_coverage <= 1
88
+
89
+ # Compute how many frames separate two partial utterances
90
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
91
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
92
+ frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
93
+ assert 0 < frame_step, "The rate is too high"
94
+ assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
95
+ (sampling_rate / (samples_per_frame * partials_n_frames))
96
+
97
+ # Compute the slices
98
+ wav_slices, mel_slices = [], []
99
+ steps = max(1, n_frames - partials_n_frames + frame_step + 1)
100
+ for i in range(0, steps, frame_step):
101
+ mel_range = np.array([i, i + partials_n_frames])
102
+ wav_range = mel_range * samples_per_frame
103
+ mel_slices.append(slice(*mel_range))
104
+ wav_slices.append(slice(*wav_range))
105
+
106
+ # Evaluate whether extra padding is warranted or not
107
+ last_wav_range = wav_slices[-1]
108
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
109
+ if coverage < min_coverage and len(mel_slices) > 1:
110
+ mel_slices = mel_slices[:-1]
111
+ wav_slices = wav_slices[:-1]
112
+
113
+ return wav_slices, mel_slices
114
+
115
+ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
116
+ """
117
+ Computes an embedding for a single utterance. The utterance is divided in partial
118
+ utterances and an embedding is computed for each. The complete utterance embedding is the
119
+ L2-normed average embedding of the partial utterances.
120
+
121
+ TODO: independent batched version of this function
122
+
123
+ :param wav: a preprocessed utterance waveform as a numpy array of float32
124
+ :param return_partials: if True, the partial embeddings will also be returned along with
125
+ the wav slices corresponding to each partial utterance.
126
+ :param rate: how many partial utterances should occur per second. Partial utterances must
127
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
128
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
129
+ the minimum rate is thus 0.625.
130
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
131
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
132
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
133
+ it will be discarded. If there aren't enough frames for one partial utterance,
134
+ this parameter is ignored so that the function always returns at least one slice.
135
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
136
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
137
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
138
+ returned.
139
+ """
140
+ # Compute where to split the utterance into partials and pad the waveform with zeros if
141
+ # the partial utterances cover a larger range.
142
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
143
+ max_wave_length = wav_slices[-1].stop
144
+ if max_wave_length >= len(wav):
145
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
146
+
147
+ # Split the utterance into partials and forward them through the model
148
+ mel = audio.wav_to_mel_spectrogram(wav)
149
+ mels = np.array([mel[s] for s in mel_slices])
150
+ with torch.no_grad():
151
+ mels = torch.from_numpy(mels).to(self.device)
152
+ partial_embeds = self(mels).cpu().numpy()
153
+
154
+ # Compute the utterance embedding from the partial embeddings
155
+ raw_embed = np.mean(partial_embeds, axis=0)
156
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
157
+
158
+ if return_partials:
159
+ return embed, partial_embeds, wav_slices
160
+ return embed
161
+
162
+ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
163
+ """
164
+ Compute the embedding of a collection of wavs (presumably from the same speaker) by
165
+ averaging their embedding and L2-normalizing it.
166
+
167
+ :param wavs: list of wavs a numpy arrays of float32.
168
+ :param kwargs: extra arguments to embed_utterance()
169
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
170
+ """
171
+ raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
172
+ for wav in wavs], axis=0)
173
+ return raw_embed / np.linalg.norm(raw_embed, 2)