VietnameseVITS / utils.py
chnk58hoang's picture
convert onnx
db3dea6
raw
history blame
No virus
3.35 kB
from TTS.tts.models.vits import Vits
from TTS.tts.configs.vits_config import VitsConfig
import numpy as np
import unicodedata
import regex
# Numeric token: any run of digits, "." and "," that ends in a digit, captured
# as group 1. Requiring a final digit keeps trailing punctuation out of the match.
num_re = regex.compile(r"([0-9.,]*[0-9])")
# Spoken form of the digits 0-9, indexed by digit value.
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]


def _read_two_digits(num: str) -> str:
    """Read a string of exactly two decimal digits."""
    tens, units = int(num[0]), int(num[1])
    if tens == 0:
        # "05" is not fifteen: read a zero-padded pair digit by digit.
        return digits[0] + " " + digits[units]
    if tens == 1:
        if units == 0:
            return "mười"
        return "mười " + ("lăm" if units == 5 else digits[units])
    if units == 0:
        return digits[tens] + " mươi"
    # After "mươi" the units 1 and 5 take the special forms "mốt" and "lăm".
    if units == 1:
        end = "mốt"
    elif units == 5:
        end = "lăm"
    else:
        end = digits[units]
    return digits[tens] + " mươi " + end


def _read_digit_string(num: str) -> str:
    """Read a string made only of decimal digits (1 to 6 digits supported)."""
    size = len(num)
    if size == 1:
        return digits[int(num)]
    if size == 2:
        return _read_two_digits(num)
    if size == 3:
        hundreds = digits[int(num[0])] + " trăm"
        if num[1:] == "00":
            return hundreds
        if num[1] == "0":
            return hundreds + " lẻ " + digits[int(num[2])]
        return hundreds + " " + _read_two_digits(num[1:])
    if 4 <= size <= 6:
        # int() drops leading zeros of the thousands part, e.g. "012345" -> 12.
        thousands = read_number(str(int(num[:-3]))) + " ngàn"
        if num[-3:] == "000":
            # Round thousands: "1000" is "một ngàn", matching how "1.000" is read.
            return thousands
        return thousands + " " + _read_digit_string(num[-3:])
    # Seven or more digits are not supported; hand the text back untouched.
    return num


def _read_fraction(frac: str) -> str:
    """Read the part after a decimal comma, speaking each leading zero."""
    rest = frac.lstrip("0")
    words = [digits[0]] * (len(frac) - len(rest))
    if rest:
        words.append(read_number(rest))
    return " ".join(words)


def read_number(num: str) -> str:
    """Translate numeric text into its written (spoken) Vietnamese form.

    Supported shapes:
      * 1 to 6 plain digits, e.g. "15", "105", "1234";
      * one decimal comma, e.g. "3,5" or "3,05";
      * "." as thousands separator with two or three groups, e.g. "1.000"
        or "1.234.567".

    Args:
        num (str): numeric text.

    Returns:
        str: the written form of ``num``. Text in an unsupported shape (seven
        or more plain digits, several commas, empty groups, four or more
        dot-separated groups, non-numeric text) is returned unchanged instead
        of raising.
    """
    if num.isdecimal():
        return _read_digit_string(num)

    if "," in num:
        pieces = num.split(",")
        # Exactly one comma with text on both sides; anything else is not a
        # decimal number and must not crash the tuple unpacking.
        if len(pieces) != 2 or not all(pieces):
            return num
        whole, frac = pieces
        return read_number(whole) + " phẩy " + _read_fraction(frac)

    if "." in num:
        parts = num.split(".")
        if not all(parts):
            return num
        if len(parts) == 2:
            head, tail = parts
            if tail == "000":
                return read_number(head) + " ngàn"
            # Only a genuine three-digit "00X" group takes the short "lẻ X"
            # form; checking the length avoids int("") / out-of-range lookups.
            if len(tail) == 3 and tail.startswith("00") and tail[2].isdecimal():
                return read_number(head) + " ngàn lẻ " + digits[int(tail[2])]
            return read_number(head) + " ngàn " + read_number(tail)
        if len(parts) == 3:
            millions, thousands, units = parts
            words = [read_number(millions) + " triệu"]
            # All-zero groups are silent: "1.000.000" is just "một triệu".
            if thousands != "000":
                words.append(read_number(thousands) + " ngàn")
            if units != "000":
                words.append(read_number(units))
            return " ".join(words)

    return num
def load_model(config_path="vits/config.json", onnx_path="vits/coqui_vits.onnx"):
    """Build a VITS model from its JSON config and load the exported ONNX graph.

    Args:
        config_path (str): path of the VITS JSON config file. Defaults to the
            location this module has always used.
        onnx_path (str): path of the exported ONNX model file. Defaults to the
            location this module has always used.

    Returns:
        Vits: the model after ``load_onnx`` and one trial ``inference_onnx`` call.
    """
    config = VitsConfig()
    config.load_json(config_path)
    vits = Vits.init_from_config(config)
    vits.load_onnx(onnx_path)

    # One trial synthesis on a fixed sentence; the audio is discarded.
    # NOTE(review): presumably a warm-up / smoke test of the ONNX session - confirm.
    text = "xin chào tôi là hoàng đây"
    token_ids = vits.tokenizer.text_to_ids(text)
    # [None, :] prepends an axis of length 1 in front of the token ids.
    text_inputs = np.asarray(token_ids, dtype=np.int64)[None, :]
    vits.inference_onnx(text_inputs)
    return vits
# Punctuation that is always removed from the text.
_STRIP_PUNCT_TABLE = str.maketrans("", "", ";:!?()")

# A "." or "," that does NOT have a digit on both sides, i.e. sentence
# punctuation rather than a separator inside a number.
_LOOSE_SEPARATOR_RE = regex.compile(r"(?<![0-9])[.,]|[.,](?![0-9])")

# Separator layouts that are kept when spelling a number: plain digits, one
# decimal comma, or one/two dot-separated groups of exactly three digits.
_READABLE_NUMBER_RE = regex.compile(r"[0-9]+(?:,[0-9]+|(?:\.[0-9]{3}){1,2})?")


def normalize_text(text):
    """Normalize the input text.

    The text is lowercased, NFKC-normalized, stripped of punctuation, and every
    numeric token is replaced by its written form via ``read_number``.

    Args:
        text (str): the input text.

    Returns:
        str: the normalized text, with words separated by single spaces.
    """
    # lowercase
    text = text.lower()
    # unicode normalize
    text = unicodedata.normalize("NFKC", text)
    # Both brackets are dropped (previously only "(" was removed).
    text = text.translate(_STRIP_PUNCT_TABLE)
    # Drop "." and "," unless they sit between two digits, so that "3,5" and
    # "1.000" reach read_number with their separator intact.
    text = _LOOSE_SEPARATOR_RE.sub("", text)

    # Convert numeric text into written form. Numbers are first split off from
    # any letters glued to them.
    text = num_re.sub(r" \1 ", text)
    words = []
    for word in text.split():
        if num_re.fullmatch(word):
            if not _READABLE_NUMBER_RE.fullmatch(word):
                # Odd separator layout ("1.5", "1,2,3", ...): fall back to the
                # bare digits.
                word = word.replace(".", "").replace(",", "")
            word = read_number(word)
        words.append(word)
    return " ".join(words)