r"""
Compute mel-spectrogram normalisation statistics for a dataset.

The script composes a data config from ``configs/data`` with Hydra, builds the
``TextMelDataModule`` described by it, iterates once over the training dataloader and
writes the resulting mean and standard deviation to a JSON file. The output file is named
after the config (``vctk.yaml`` -> ``vctk.json``) and is created relative to the current
working directory.
"""
import argparse
import json
import os
import sys
from pathlib import Path

import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from tqdm.auto import tqdm

from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.utils.logging_utils import pylogger

log = pylogger.get_pylogger(__name__)


def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
    """Generate data mean and standard deviation helpful in data normalisation

    Args:
        data_loader (torch.utils.data.Dataloader): iterable of batches; each batch is a
            mapping with a ``"y"`` tensor (the mel spectrograms) and a ``"y_lengths"``
            tensor (the per-item number of frames).
        out_channels (int): mel spectrogram channels

    Returns:
        dict: ``{"mel_mean": float, "mel_std": float}``.

    Raises:
        ValueError: if the loader yields no frames (or ``out_channels`` is not positive),
            since the statistics are undefined in that case.
    """
    # Accumulate as Python floats (double precision): this keeps the running sums accurate
    # over many batches and works regardless of the device/dtype of the batch tensors.
    total_mel_sum = 0.0
    total_mel_sq_sum = 0.0
    total_mel_len = 0

    for batch in tqdm(data_loader, leave=False):
        mels = batch["y"].to(torch.float64)
        mel_lengths = batch["y_lengths"]

        total_mel_len += int(torch.sum(mel_lengths).item())
        # NOTE(review): every element of `mels` is summed while the divisor only counts
        # `y_lengths` frames, so this presumably relies on padded frames being zero —
        # confirm against the collate function.
        total_mel_sum += torch.sum(mels).item()
        total_mel_sq_sum += torch.sum(mels * mels).item()

    n_values = total_mel_len * out_channels
    if n_values <= 0:
        raise ValueError(
            "Cannot compute data statistics: the data loader produced no mel frames "
            f"(total frames={total_mel_len}, out_channels={out_channels})."
        )

    data_mean = total_mel_sum / n_values
    # E[x^2] - E[x]^2 can come out marginally negative through rounding; clamp it so the
    # square root never yields NaN / a complex number.
    variance = max(total_mel_sq_sum / n_values - data_mean**2, 0.0)
    data_std = variance**0.5

    return {"mel_mean": data_mean, "mel_std": data_std}


def main():
    """Parse the command line, compute the statistics and write them to a JSON file.

    Exits with status 1 if the output file already exists and ``--force`` was not given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    output_file = Path(args.input_config).with_suffix(".json")

    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        # Drop the entries that are not constructor arguments of the datamodule.
        del cfg["hydra"]
        del cfg["_target_"]
        # The statistics are what is being computed here, so none are passed in.
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        # Prefix the file lists from the config with the project root.
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)

    # Use a context manager so the file is flushed and closed deterministically.
    with open(output_file, "w", encoding="utf-8") as stats_file:
        json.dump(params, stats_file)


if __name__ == "__main__":
    main()