Why is the audio output file not created?(We need your HELP!!!!!!!๐Ÿ˜ญ)

#74
by Rfy23 - opened

My code below looks like this:

It is a code that includes training and inference at the same time. However, I need evaluation metrics such as precision, recall, and f1, and an audio file, but the following code does not create an audio file. What should we modify to get additional evaluation indicators? help me๐Ÿ˜ญ

import os
import sys
import io
import torch
import torchaudio
import wandb # wandb ์ž„ํฌํŠธ
import torch.nn.functional as F

from TTS.config.shared_configs import BaseDatasetConfig, BaseAudioConfig
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

stdout์˜ ์ธ์ฝ”๋”ฉ์„ UTF-8๋กœ ์„ค์ •

sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

class TrainerArgs:
pass

class Trainer:
def init(self, args, config, output_path, model, train_samples, eval_samples):
self.args = args
self.config = config
self.output_path = output_path
self.model = model
self.train_samples = train_samples
self.eval_samples = eval_samples
self.optimizer = optimizer # Placeholder for the optimizer

def set_optimizer(self, optimizer):
    self.optimizer = optimizer

def load_audio(self, audio_file):
    waveform, sample_rate = torchaudio.load(audio_file)
    if sample_rate != self.config.audio.sample_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.config.audio.sample_rate)(waveform)
    return waveform

def train_epoch(self):
    self.model.train()
    total_loss = 0
    for sample in self.train_samples:
        self.optimizer.zero_grad()
        input_data = self.load_audio(sample['audio_file'])
        x_lengths = torch.tensor([input_data.shape[1]])
        target_data = self.load_audio(sample['audio_file'])  
        y = target_data.unsqueeze(0)
        output = self.model(input_data, x_lengths, y)
        loss = self.compute_loss(output, y)
        loss.backward()
        self.optimizer.step()
        total_loss += loss.item()
    average_loss = total_loss / len(self.train_samples)
    print(f"Training Loss: {average_loss}")
    return average_loss

def evaluate(self, samples):
    self.model.eval()
    total_loss = 0
    with torch.no_grad():
        for sample in samples:
            input_data = self.load_audio(sample['audio_file'])
            x_lengths = torch.tensor([input_data.shape[1]])
            target_data = self.load_audio(sample['audio_file'])  
            y = target_data.unsqueeze(0)
            output = self.model(input_data, x_lengths, y)
            loss = self.compute_loss(output, y)
            total_loss += loss.item()
    average_loss = total_loss / len(samples)
    return average_loss

def compute_loss(self, output, target):
    loss = F.mse_loss(output, target)
    return loss

def formatter(root_path, manifest_file, **kwargs):
"""Assumes each line as
.wav|
"""
txt_file = os.path.join(root_path, manifest_file)
items = []
speaker_name = "my_Sherlock"
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.strip().split("|")
wav_file = os.path.join(root_path, "wavs", cols[0])
text = cols[1]
# ๋””๋ฒ„๊น…์„ ์œ„ํ•ด ํŒŒ์ผ ๊ฒฝ๋กœ์™€ ํ…์ŠคํŠธ ์ถœ๋ ฅ
print(f"Processing file: {wav_file}, Text: {text}")
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items

def load_model(config, checkpoint_path):
# ์˜ค๋””์˜ค ํ”„๋กœ์„ธ์„œ ์ดˆ๊ธฐํ™”
ap = AudioProcessor.init_from_config(config)

# ํ† ํฌ๋‚˜์ด์ € ์ดˆ๊ธฐํ™” ๋ฐ ์„ค์ • ๊ฐฑ์‹ 
tokenizer, config = TTSTokenizer.init_from_config(config)

# GlowTTS ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
model.load_checkpoint(checkpoint_path)

return model, ap, tokenizer

def infer(model, text, ap, tokenizer):
# ์ž…๋ ฅ ํ…์ŠคํŠธ๋ฅผ ํ† ํฐ์œผ๋กœ ๋ณ€ํ™˜
tokens = tokenizer.text_to_ids(text)

# ํ† ํฐ์„ ํ…์„œ๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  ๋ฐฐ์น˜ ์ฐจ์› ์ถ”๊ฐ€
tokens = torch.LongTensor(tokens).unsqueeze(0)

# ์ถ”๋ก  ๋ชจ๋“œ ์„ค์ • (๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ ๋น„ํ™œ์„ฑํ™”)
with torch.no_grad():
    # ๋ชจ๋ธ ์ถ”๋ก  ์ˆ˜ํ–‰
    outputs = model.inference(tokens)

# ๋ชจ๋ธ ์ถœ๋ ฅ์„ ์ด์šฉํ•ด ์˜ค๋””์˜ค ํŒŒํ˜• ์ƒ์„ฑ
waveform = ap.invert_spectrogram(outputs["model_outputs"].squeeze(0))

# ์ƒ์„ฑ๋œ ์˜ค๋””์˜ค ํŒŒํ˜• ๋ฐ˜ํ™˜
return waveform

def perform_inference(config, checkpoint_path):
# ๋ชจ๋ธ, ์˜ค๋””์˜ค ํ”„๋กœ์„ธ์„œ, ํ† ํฌ๋‚˜์ด์ € ์ดˆ๊ธฐํ™”
model, ap, tokenizer = load_model(config, checkpoint_path)

# ์‚ฌ์šฉ์ž๋กœ๋ถ€ํ„ฐ ํ…์ŠคํŠธ ์ž…๋ ฅ ๋ฐ›๊ธฐ
text = input("๊ฒฐ๊ณผ text ๋‚ด์šฉ์„ ์ ์–ด์ฃผ์„ธ์š”!: ")

# ์ถ”๋ก  ์ˆ˜ํ–‰
waveform = infer(model, text, ap, tokenizer)

# ๊ฒฐ๊ณผ ์˜ค๋””์˜ค ํŒŒ์ผ ์ €์žฅ
output_file = os.path.join(config.output_path, "output_0807_1511.wav")  # ํŒŒ์ผ ํ˜•์‹: output_๋‚ ์งœ_์‹œ๊ฐ„.wav ๋กœ ํ•  ๊ฒƒ!
ap.save_wav(waveform, output_file)
print(f"ํ•™์Šต ์ดํ›„ ๋‹˜์ด ์„ค์ •ํ•œ text๋กœ ์ถ”๋ก  ์˜ค๋””์˜ค๊ฐ€ ์ƒ์„ฑ๋์Šต๋‹ˆ๋‹ค~!..์ถ”๋ก  ์˜ค๋””์˜ค: {output_file}")

def main():
# wandb ์„ค์ • ๋ฐ ์ดˆ๊ธฐํ™”
wandb.init(project="sherlock")
wandb.require("core") # ์ƒˆ๋กœ์šด ๋ฐฑ์—”๋“œ ์‚ฌ์šฉ ์„ค์ •

# torch์˜ ๊ธฐ๋ณธ ํ…์„œ ์œ ํ˜•์„ CPU ํ…์„œ๋กœ ์„ค์ •
torch.set_default_tensor_type(torch.FloatTensor)

# ํ˜„์žฌ ์Šคํฌ๋ฆฝํŠธ์˜ ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ ์ถœ๋ ฅ ๊ฒฝ๋กœ๋กœ ์„ค์ •
output_path = os.path.dirname(os.path.abspath(__file__))

# ๋ฐ์ดํ„ฐ์…‹ ์„ค์ • ์ •์˜
dataset_config = BaseDatasetConfig(
    formatter="Sherlock", 
    meta_file_train="metadata.txt",  # ์˜ฌ๋ฐ”๋ฅธ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •
    path="MyTTSDataset"  # ์˜ฌ๋ฐ”๋ฅธ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •
)

audio_config = BaseAudioConfig(
    sample_rate=12050,
)

character_config = CharactersConfig(
    characters_class="TTS.tts.utils.text.characters.Graphemes",
    pad="_",
    eos="~",
    bos="^",
    blank="@",
    characters="Iabdfgijklmnprstuvxzษ”ษ›ษฃษจษซษฑส‚สสฒหˆหฬฏอกฮฒ",
    punctuations="!,.?: -โ€’โ€“โ€”โ€ฆ",
)

# phoneme_cache_path ์„ค์ •
phoneme_cache_path = os.path.join(output_path, "phoneme_cache")

# phoneme_cache ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
if not os.path.exists(phoneme_cache_path):
    os.makedirs(phoneme_cache_path)
    
# ํ•™์Šต ์„ค์ • ์ดˆ๊ธฐํ™”
config = GlowTTSConfig(
    run_name="Testrun",
    run_description="Desc",
    batch_size=32,
    eval_batch_size=16,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    print_step=25,
    print_eval=False,
    mixed_precision=True,
    output_path=output_path,
    datasets=[dataset_config],
    audio=audio_config,
    characters=character_config,
    eval_split_size=0.2,
    test_sentences=[]
)

# ์˜ค๋””์˜ค ํ”„๋กœ์„ธ์„œ ์ดˆ๊ธฐํ™”
ap = AudioProcessor.init_from_config(config)

# ํ† ํฌ๋‚˜์ด์ € ์ดˆ๊ธฐํ™”
tokenizer, config = TTSTokenizer.init_from_config(config)

# ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ ๋กœ๋“œ
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    formatter=formatter,
    eval_split_size=config.eval_split_size,
    eval_split_max_size=config.eval_split_max_size,
)

# ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
model = GlowTTS(config, ap, tokenizer, speaker_manager=None)

# ํŠธ๋ ˆ์ด๋„ˆ ์ดˆ๊ธฐํ™”
trainer_args = TrainerArgs()
trainer = Trainer(
    trainer_args, config, output_path, model=model, optimizer, train_samples=train_samples
)

 # ์˜ตํ‹ฐ๋งˆ์ด์ € ์ดˆ๊ธฐํ™” (Adam ์˜ตํ‹ฐ๋งˆ์ด์ €)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
trainer.set_optimizer(optimizer)

# ํ•™์Šต ์‹œ์ž‘
for epoch in range(config.epochs):
    train_loss = trainer.train_epoch()
    eval_loss = trainer.evaluate(eval_samples)

    # ํ•™์Šต ์†์‹ค๊ณผ ํ‰๊ฐ€ ์†์‹ค ๋กœ๊ทธ ๊ธฐ๋ก
    wandb.log({"epoch": epoch, "train_loss": train_loss, "eval_loss": eval_loss})
    print(f"Epoch {epoch}, Train Loss: {train_loss}, Eval Loss: {eval_loss}")


# ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ
checkpoint_path = os.path.join(output_path, "checkpoints", "์…œ๋ก_model.pth")  # ์‹ค์ œ ์ฒดํฌํฌ์ธํŠธ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •
torch.save(model.state_dict(), checkpoint_path)
wandb.save(checkpoint_path)  # wandb์— ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ

# ์ถ”๋ก  ์ˆ˜ํ–‰
perform_inference(config, checkpoint_path)

if name == 'main':
main()

You know, you can simply record your OS's stereo mix.

Sign up or log in to comment