import math

import torch

from speechbrain.inference.interfaces import Pretrained


class AttentionMLP(torch.nn.Module):
    """A small MLP that scores each codebook position so that per-codebook
    embeddings can be attention-pooled into a single vector per time step.

    Arguments
    ---------
    input_dim : int
        Dimensionality of the input embeddings.
    hidden_dim : int
        Dimensionality of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super(AttentionMLP, self).__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, x):
        x = self.layers(x)
        # Normalize the scores over the codebook dimension
        att_w = torch.nn.functional.softmax(x, dim=2)
        return att_w
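
# Illustrative usage sketch for AttentionMLP (the shapes below are assumptions
# chosen for illustration): the returned weights are softmax-normalized over the
# codebook dimension, so they can collapse (batch, time, num_codebooks, emb_dim)
# embeddings into (batch, time, emb_dim), mirroring encode_batch() further below.
#
# >>> att_pool = AttentionMLP(input_dim=32, hidden_dim=64)
# >>> embs = torch.randn(4, 25, 2, 32)    # (batch, time, num_codebooks, emb_dim)
# >>> att_w = att_pool(embs)              # (batch, time, num_codebooks, 1)
# >>> torch.matmul(att_w.transpose(2, -1), embs).squeeze(-2).shape
# torch.Size([4, 25, 32])
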

class Discrete_EmbeddingLayer(torch.nn.Module):
    """This class handles embedding layers for discrete tokens.

    Arguments
    ---------
    num_codebooks : int
        number of codebooks of the tokenizer.
    vocab_size : int
        size of the dictionary of embeddings.
    emb_dim : int
        the size of each embedding vector.
    pad_index : int (default: 0)
        If specified, the entries at padding_idx do not contribute to the
        gradient.
    init : boolean (default: False)
        If set to True, init the embedding with the tokenizer embedding,
        otherwise init randomly.
    freeze : boolean (default: False)
        If True, the embedding is frozen. If False, the model will be trained
        alongside the rest of the pipeline.
    available_layers : list, optional
        The list of layers provided by the tokenizer.
    layers : list, optional
        The subset of tokenizer layers (codebooks) to use.
    chunk_size : int
        The size of lengthwise chunks used when evaluating via Gumbel softmax.

    Example
    -------
    >>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec
    >>> model_hub = "facebook/encodec_24khz"
    >>> save_path = "savedir"
    >>> model = Encodec(model_hub, save_path)
    >>> audio = torch.randn(4, 1000)
    >>> length = torch.tensor([1.0, .5, .75, 1.0])
    >>> tokens, emb = model.encode(audio, length)
    >>> print(tokens.shape)
    torch.Size([4, 4, 2])
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
    >>> in_emb = emb(tokens)
    >>> print(in_emb.shape)
    torch.Size([4, 4, 2, 1024])
    """

    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
        available_layers=None,
        layers=None,
        chunk_size=100,
    ):
        super(Discrete_EmbeddingLayer, self).__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init
        self.layers = layers
        self.available_layers = available_layers
        self.register_buffer("offsets", self.build_offsets())
        self.register_buffer("layer_embs", self.compute_layer_embs())
        self.chunk_size = chunk_size

    def init_embedding(self, weights):
        """Initializes the embedding table with the provided weights."""
        with torch.no_grad():
            self.embedding.weight = torch.nn.Parameter(weights)

    def build_offsets(self):
        """Builds the per-codebook offsets into the flat embedding table."""
        offsets = torch.arange(
            0,
            self.num_codebooks * self.vocab_size,
            self.vocab_size,
        )
        if self.layers:
            selected_layers = set(self.layers)
            indexes = [
                idx
                for idx, layer in enumerate(self.available_layers)
                if layer in selected_layers
            ]
            offsets = offsets[indexes]
        return offsets

    def forward(self, in_tokens):
        """Computes the embedding for discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            A (Batch x Time x num_codebooks) tensor of discrete tokens.

        Returns
        -------
        in_embs : torch.Tensor
            The (Batch x Time x num_codebooks x emb_dim) token embeddings.
        """
        with torch.set_grad_enabled(not self.freeze):
            # Make token IDs unique across codebooks by adding
            # codebook_index * vocab_size to each codebook's tokens
            in_tokens_offset = in_tokens + self.offsets.to(in_tokens.device)
            # Forward pass through the embedding layer
            in_embs = self.embedding(in_tokens_offset.int())
            return in_embs

    def compute_layer_embs(self):
        """Returns the embedding table split per selected layer, shaped for
        broadcasting against one-hot token representations."""
        weight = self.embedding.weight

        # Compute offsets for the selected layers; if no explicit layer
        # selection is provided, fall back to all codebooks in order.
        if self.layers and self.available_layers:
            layer_idx_map = {
                layer: idx for idx, layer in enumerate(self.available_layers)
            }
            layer_idx = [layer_idx_map[layer] for layer in self.layers]
        else:
            layer_idx = list(range(self.num_codebooks))

        offsets = [idx * self.vocab_size for idx in layer_idx]

        layer_embs = torch.stack(
            [weight[offset : offset + self.vocab_size] for offset in offsets]
        )

        # To (1 x 1 x num_layers x vocab_size x emb_dim) for broadcasting
        layer_embs = layer_embs.unsqueeze(0).unsqueeze(0)
        return layer_embs

    def encode_logits(self, logits, length=None):
        """Computes embeddings from a batch of discrete unit logits using a
        straight-through Gumbel-softmax estimator, keeping the operation
        differentiable with respect to the logits.

        Arguments
        ---------
        logits : torch.Tensor
            Batch of discrete unit logits [batch, length, head, token]
        length : torch.Tensor
            Relative lengths [batch] (currently unused).

        Returns
        -------
        emb : torch.Tensor
            Batch of embeddings [batch, length, head, emb_dim]
        """
        # Convert logits to one-hot representations
        # without losing the gradient
        units_gumbel = torch.nn.functional.gumbel_softmax(
            logits, hard=False, dim=-1
        )

        # Straight-through trick
        _, argmax_idx = logits.max(dim=-1, keepdim=True)
        units_ref = torch.zeros_like(logits).scatter_(
            dim=-1, index=argmax_idx, src=torch.ones_like(logits)
        )
        units_hard = units_ref - units_gumbel.detach() + units_gumbel

        # Sum over embeddings for each layer, in length-wise chunks
        # to limit peak memory usage
        units_hard_chunked = units_hard.chunk(
            math.ceil(units_hard.size(1) / self.chunk_size), dim=1
        )
        emb = torch.cat(
            [
                (self.layer_embs * units_hard_chunk.unsqueeze(-1)).sum(-2)
                for units_hard_chunk in units_hard_chunked
            ],
            dim=1,
        )
        return emb

    def load_state_dict(self, state_dict, strict=True):
        """Loads the state dict and refreshes the cached layer embeddings."""
        result = super().load_state_dict(state_dict, strict)
        self.layer_embs = self.compute_layer_embs()
        return result
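
# Illustrative usage sketch for Discrete_EmbeddingLayer (small, assumed sizes):
# each codebook owns its own vocab_size-wide slice of the flat embedding table,
# so token t from codebook k is looked up at index k * vocab_size + t.
# encode_logits() mixes the same table with straight-through Gumbel-softmax
# one-hots, so gradients can flow back to the logits.
#
# >>> emb_layer = Discrete_EmbeddingLayer(
# ...     num_codebooks=2, vocab_size=64, emb_dim=32,
# ...     available_layers=[1, 7], layers=[1, 7],
# ... )
# >>> tokens = torch.randint(0, 64, (4, 25, 2))   # (batch, time, num_codebooks)
# >>> emb_layer(tokens).shape
# torch.Size([4, 25, 2, 32])
# >>> logits = torch.randn(4, 25, 2, 64)          # (batch, time, heads, tokens)
# >>> emb_layer.encode_logits(logits).shape
# torch.Size([4, 25, 2, 32])
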

class DiscreteSpkEmb(Pretrained):
    """A ready-to-use class for utterance-level classification (e.g.,
    speaker-id, language-id, emotion recognition, keyword spotting, etc).

    The class assumes that a self-supervised encoder like wav2vec2/hubert
    and a classifier model are defined in the yaml file. If you want to
    convert the predicted index into a corresponding text label, please
    provide the path of the label_encoder in a variable called
    'lab_encoder_file' within the yaml.

    The class can be used either to run only the encoder (encode_batch()) to
    extract embeddings or to run a classification step (classify_batch()).

    Example
    -------
    >>> import torchaudio
    >>> from speechbrain.pretrained import EncoderClassifier
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> classifier = EncoderClassifier.from_hparams(
    ...     source="speechbrain/spkrec-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... )
    >>> # Compute embeddings
    >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
    >>> embeddings = classifier.encode_batch(signal)
    >>> # Classification
    >>> prediction = classifier.classify_batch(signal)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode_batch(self, audio, length=None):
        """Encodes the input audio tokens into a single vector embedding.
        The tokens should already be in the model's desired format.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of tokenized audio [batch, time, heads]
        length : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch
        """
        # Embed the discrete tokens, attention-pool over the codebook
        # dimension, then run the speaker embedding model
        embeddings = self.mods.discrete_embedding_layer(audio)
        att_w = self.mods.attention_mlp(embeddings)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
        embeddings = self.mods.embedding_model(feats, length)
        return embeddings.squeeze(1)

    def encode_logits(self, logits, length=None):
        """Encodes the input token logits into a single vector embedding.

        Arguments
        ---------
        logits : torch.Tensor
            Batch of token logits [batch, time, heads, tokens]
        length : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch
        """
        embeddings = self.mods.discrete_embedding_layer.encode_logits(logits)
        att_w = self.mods.attention_mlp(embeddings)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
        embeddings = self.mods.embedding_model(feats, length)
        return embeddings.squeeze(1)

    def forward(self, audio, length=None):
        """Encodes the input audio into a single vector embedding.
        The tokens or logits should already be in the model's desired format.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of tokenized audio [batch, time, heads] or logits
            [batch, time, heads, tokens]
        length : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch
        """
        # Dispatch on dimensionality: 3-D inputs are token indices,
        # 4-D inputs are token logits
        audio_dim = audio.dim()
        if audio_dim == 3:
            embeddings = self.encode_batch(audio, length)
        elif audio_dim == 4:
            embeddings = self.encode_logits(audio, length)
        else:
            raise ValueError(f"Unsupported audio shape {audio.shape}")
        return embeddings
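
# Illustrative usage sketch (the source path below is a placeholder):
# DiscreteSpkEmb expects an hparams file defining the modules referenced above,
# i.e. discrete_embedding_layer, attention_mlp and embedding_model. Its
# forward() accepts either token indices (3-D) or token logits (4-D).
#
# >>> spk_encoder = DiscreteSpkEmb.from_hparams(
# ...     source="path/to/pretrained/discrete-spk-emb", savedir="savedir"
# ... )
# >>> tokens = torch.randint(0, 1024, (4, 25, 2))   # (batch, time, heads)
# >>> spk_embs = spk_encoder(tokens)                # routed to encode_batch()
# >>> logits = torch.randn(4, 25, 2, 1024)          # (batch, time, heads, tokens)
# >>> spk_embs = spk_encoder(logits)                # routed to encode_logits()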