admin committed on
Commit 58ca3ce
1 Parent(s): 8e45898
Files changed (7)
  1. .gitattributes +11 -11
  2. .gitignore +7 -0
  3. README.md +7 -4
  4. app.py +220 -0
  5. model.py +146 -0
  6. requirements.txt +6 -0
  7. utils.py +67 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ *.pt
+ __pycache__/*
+ tmp/*
+ flagged/*
+ test.py
+ ffmpeg/*
+ rename.sh
README.md CHANGED
@@ -1,13 +1,16 @@
  ---
- title: Music Genre
- emoji: 🔥
+ title: Music Genre Classifier
+ emoji: 🎶
  colorFrom: pink
  colorTo: pink
  sdk: gradio
- sdk_version: 4.13.0
+ sdk_version: 4.36.0
  app_file: app.py
  pinned: false
  license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## Maintenance
+ ```bash
+ GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:spaces/ccmusic-database/music-genre
+ ```
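A note on the maintenance command: `GIT_LFS_SKIP_SMUDGE=1` clones the repo without downloading LFS blobs, which is safe here because the Space never reads its own LFS files at runtime; `utils.py` in this commit fetches the checkpoints and example MP3s from ModelScope instead. A minimal sketch of that runtime fetch, mirroring the call in `utils.py`:

```python
# Sketch of the runtime download that makes a skip-smudge clone safe;
# this mirrors the snapshot_download call added in utils.py below.
from modelscope import snapshot_download

MODEL_DIR = snapshot_download(
    "ccmusic-database/music_genre",  # same repo id as in utils.py
    cache_dir="./__pycache__",
)
print(MODEL_DIR)  # local path holding the model logs and examples/
```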
app.py ADDED
@@ -0,0 +1,220 @@
+ import os
+ import sys
+ import torch
+ import random
+ import shutil
+ import librosa
+ import warnings
+ import subprocess
+ import numpy as np
+ import gradio as gr
+ import librosa.display
+ import matplotlib.pyplot as plt
+ import torchvision.transforms as transforms
+ from utils import get_modelist, find_mp3_files, download
+ from collections import Counter
+ from model import EvalNet
+ from PIL import Image
+
+
+ TRANSLATE = {
+     "Symphony": "交响乐 Symphony",
+     "Opera": "戏曲 Opera",
+     "Solo": "独奏 Solo",
+     "Chamber": "室内乐 Chamber",
+     "Pop_vocal_ballad": "芭乐 Pop vocal ballad",
+     "Adult_contemporary": "成人时代 Adult contemporary",
+     "Teen_pop": "青少年流行 Teen pop",
+     "Contemporary_dance_pop": "当代流行舞曲 Contemporary dance pop",
+     "Dance_pop": "流行舞曲 Dance pop",
+     "Classic_indie_pop": "经典独立流行 Classic indie pop",
+     "Chamber_cabaret_and_art_pop": "室内卡巴莱与艺术流行乐 Chamber cabaret & art pop",
+     "Soul_or_r_and_b": "灵魂乐或节奏布鲁斯 Soul / R&B",
+     "Adult_alternative_rock": "成人另类摇滚 Adult alternative rock",
+     "Uplifting_anthemic_rock": "迷幻民族摇滚 Uplifting anthemic rock",
+     "Soft_rock": "慢摇滚 Soft rock",
+     "Acoustic_pop": "原声流行 Acoustic pop",
+ }
+
+ CLASSES = list(TRANSLATE.keys())
+
+
+ def most_common_element(input_list):
+     counter = Counter(input_list)
+     mce, _ = counter.most_common(1)[0]
+     return mce
+
+
+ def mp3_to_mel(audio_path: str, width=11.4):
+     os.makedirs("./flagged", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path)
+         mel_spec = librosa.feature.melspectrogram(y=y, sr=sr)
+         log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_mel_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_mel_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./flagged/mel_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def mp3_to_cqt(audio_path: str, width=11.4):
+     os.makedirs("./flagged", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path)
+         cqt_spec = librosa.cqt(y=y, sr=sr)
+         log_cqt_spec = librosa.power_to_db(np.abs(cqt_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_cqt_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_cqt_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./flagged/cqt_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def mp3_to_chroma(audio_path: str, width=11.4):
+     os.makedirs("./flagged", exist_ok=True)
+     try:
+         y, sr = librosa.load(audio_path)
+         chroma_spec = librosa.feature.chroma_stft(y=y, sr=sr)
+         log_chroma_spec = librosa.power_to_db(np.abs(chroma_spec) ** 2, ref=np.max)
+         dur = librosa.get_duration(y=y, sr=sr)
+         total_frames = log_chroma_spec.shape[1]
+         step = int(width * total_frames / dur)
+         count = int(total_frames / step)
+         begin = int(0.5 * (total_frames - count * step))
+         end = begin + step * count
+         for i in range(begin, end, step):
+             librosa.display.specshow(log_chroma_spec[:, i : i + step])
+             plt.axis("off")
+             plt.savefig(
+                 f"./flagged/chroma_{round(dur, 2)}_{i}.jpg",
+                 bbox_inches="tight",
+                 pad_inches=0.0,
+             )
+             plt.close()
+
+     except Exception as e:
+         print(f"Error converting {audio_path} : {e}")
+
+
+ def embed_img(img_path, input_size=224):
+     transform = transforms.Compose(
+         [
+             transforms.Resize([input_size, input_size]),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+     img = Image.open(img_path).convert("RGB")
+     return transform(img).unsqueeze(0)
+
+
+ def inference(mp3_path, log_name: str, folder_path="./flagged"):
+     if os.path.exists(folder_path):
+         shutil.rmtree(folder_path)
+
+     if not mp3_path:
+         return None, "请输入音频 Please input an audio!"
+
+     network = EvalNet(log_name)
+     spec = log_name.split("_")[-1]
+     eval("mp3_to_%s" % spec)(mp3_path)
+     outputs = []
+     all_files = os.listdir(folder_path)
+     for file_name in all_files:
+         if file_name.lower().endswith(".jpg"):
+             file_path = os.path.join(folder_path, file_name)
+             img_tensor = embed_img(file_path)
+             output: torch.Tensor = network.model(img_tensor)
+             pred_id = torch.max(output.data, 1)[1]
+             outputs.append(int(pred_id))
+
+     max_count_item = most_common_element(outputs)
+     shutil.rmtree(folder_path)
+     return os.path.basename(mp3_path), TRANSLATE[CLASSES[max_count_item]]
+
+
+ if __name__ == "__main__":
+     warnings.filterwarnings("ignore")
+     ffmpeg = "ffmpeg-release-amd64-static"
+     if sys.platform.startswith("linux"):
+         if not os.path.exists(f"./{ffmpeg}.tar.xz"):
+             download(
+                 f"https://www.modelscope.cn/studio/ccmusic-database/music_genre/resolve/master/{ffmpeg}.tar.xz"
+             )
+
+         folder_path = f"{os.getcwd()}/{ffmpeg}"
+         if not os.path.exists(folder_path):
+             subprocess.call(f"tar -xvf {ffmpeg}.tar.xz", shell=True)
+
+         os.environ["PATH"] = f"{folder_path}:{os.environ.get('PATH', '')}"
+
+     models = get_modelist()
+     examples = []
+     example_mp3s = find_mp3_files()
+     model_num = len(models)
+     for mp3 in example_mp3s:
+         examples.append([mp3, models[random.randint(0, model_num - 1)]])
+
+     with gr.Blocks() as demo:
+         gr.Interface(
+             fn=inference,
+             inputs=[
+                 gr.Audio(label="上传MP3音频 Upload MP3", type="filepath"),
+                 gr.Dropdown(
+                     choices=models, label="选择模型 Select a model", value=models[6]
+                 ),
+             ],
+             outputs=[
+                 gr.Textbox(label="音频文件名 Audio filename", show_copy_button=True),
+                 gr.Textbox(label="流派识别 Genre recognition", show_copy_button=True),
+             ],
+             examples=examples,
+             cache_examples=False,
+             allow_flagging="never",
+             title="建议录音时长保持在 15s 以内, 过长会影响识别效率<br>It is recommended to keep the recording within 15s; overly long audio slows down recognition.",
+         )
+
+         gr.Markdown(
+             """
+             # 引用 Cite
+             ```bibtex
+             @dataset{zhaorui_liu_2021_5676893,
+               author = {Monan Zhou, Shenyang Xu, Zhaorui Liu, Zhaowen Wang, Feng Yu, Wei Li and Baoqiang Han},
+               title = {CCMusic: an Open and Diverse Database for Chinese and General Music Information Retrieval Research},
+               month = {mar},
+               year = {2024},
+               publisher = {HuggingFace},
+               version = {1.2},
+               url = {https://huggingface.co/ccmusic-database}
+             }
+             ```"""
+         )
+
+     demo.launch()
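The three `mp3_to_*` converters above share the same windowing arithmetic: each saved image spans roughly `width` (11.4) seconds of spectrogram, only full windows are kept, and the leftover frames are split evenly between the two ends so the window run sits centered in the clip. A minimal sketch of that arithmetic, assuming librosa's defaults (sr=22050, hop_length=512):

```python
# Windowing arithmetic shared by mp3_to_mel / mp3_to_cqt / mp3_to_chroma.
def spectrogram_windows(total_frames: int, dur: float, width: float = 11.4):
    step = int(width * total_frames / dur)            # frames per ~11.4 s image
    count = int(total_frames / step)                  # full windows that fit
    begin = int(0.5 * (total_frames - count * step))  # center the window run
    return [(i, i + step) for i in range(begin, begin + step * count, step)]

# e.g. a 30 s clip at sr=22050 with hop_length=512 has ~1292 frames:
print(spectrogram_windows(1292, 30.0))  # [(156, 646), (646, 1136)] -> 2 images
```

`inference()` then classifies each saved window independently and reports the majority vote via `most_common_element`.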
model.py ADDED
@@ -0,0 +1,146 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ from modelscope.msdatasets import MsDataset
+ from utils import MODEL_DIR
+
+
+ class EvalNet:
+     model: nn.Module = None
+     m_type = "squeezenet"
+     input_size = 224
+     output_size = 512
+
+     def __init__(self, log_name: str, cls_num=16):
+         saved_model_path = f"{MODEL_DIR}/{log_name}/save.pt"
+         m_ver = "_".join(log_name.split("_")[:-1])
+         self.m_type, self.input_size = self._model_info(m_ver)
+
+         if not hasattr(models, m_ver):
+             print("Unsupported model.")
+             exit()
+
+         self.model = eval("models.%s()" % m_ver)
+         linear_output = self._set_outsize()
+         self._set_classifier(cls_num, linear_output)
+         checkpoint = torch.load(saved_model_path, map_location="cpu")
+         if torch.cuda.is_available():
+             checkpoint = torch.load(saved_model_path)
+
+         self.model.load_state_dict(checkpoint, strict=False)
+         self.model.eval()
+
+     def _get_backbone(self, ver, backbone_list):
+         for bb in backbone_list:
+             if ver == bb["ver"]:
+                 return bb
+
+         print("Backbone name not found, using default option: alexnet.")
+         return backbone_list[0]
+
+     def _model_info(self, m_ver):
+         backbone_list = MsDataset.load(
+             "monetjoe/cv_backbones",
+             split="v1",
+             trust_remote_code=True,
+         )
+         backbone = self._get_backbone(m_ver, backbone_list)
+         m_type = str(backbone["type"])
+         input_size = int(backbone["input_size"])
+         return m_type, input_size
+
+     def _classifier(self, cls_num: int, output_size: int, linear_output: bool):
+         q = (1.0 * output_size / cls_num) ** 0.25
+         l1 = int(q * cls_num)
+         l2 = int(q * l1)
+         l3 = int(q * l2)
+         if linear_output:
+             return torch.nn.Sequential(
+                 nn.Dropout(),
+                 nn.Linear(output_size, l3),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(),
+                 nn.Linear(l3, l2),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(),
+                 nn.Linear(l2, l1),
+                 nn.ReLU(inplace=True),
+                 nn.Linear(l1, cls_num),
+             )
+
+         else:
+             return torch.nn.Sequential(
+                 nn.Dropout(),
+                 nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
+                 nn.ReLU(inplace=True),
+                 nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+                 nn.Flatten(),
+                 nn.Linear(l3, l2),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(),
+                 nn.Linear(l2, l1),
+                 nn.ReLU(inplace=True),
+                 nn.Linear(l1, cls_num),
+             )
+
+     def _set_outsize(self, debug_mode=False):
+         for name, module in self.model.named_modules():
+             if (
+                 str(name).__contains__("classifier")
+                 or str(name).__eq__("fc")
+                 or str(name).__contains__("head")
+             ):
+                 if isinstance(module, torch.nn.Linear):
+                     self.output_size = module.in_features
+                     if debug_mode:
+                         print(
+                             f"{name}(Linear): {self.output_size} -> {module.out_features}"
+                         )
+                     return True
+
+                 if isinstance(module, torch.nn.Conv2d):
+                     self.output_size = module.in_channels
+                     if debug_mode:
+                         print(
+                             f"{name}(Conv2d): {self.output_size} -> {module.out_channels}"
+                         )
+                     return False
+
+         return False
+
+     def _set_classifier(self, cls_num, linear_output):
+         if self.m_type == "convnext":
+             del self.model.classifier[2]
+             self.model.classifier = nn.Sequential(
+                 *list(self.model.classifier)
+                 + list(self._classifier(cls_num, self.output_size, linear_output))
+             )
+             return
+
+         if hasattr(self.model, "classifier"):
+             self.model.classifier = self._classifier(
+                 cls_num, self.output_size, linear_output
+             )
+             return
+
+         elif hasattr(self.model, "fc"):
+             self.model.fc = self._classifier(cls_num, self.output_size, linear_output)
+             return
+
+         elif hasattr(self.model, "head"):
+             self.model.head = self._classifier(cls_num, self.output_size, linear_output)
+             return
+
+         self.model.heads.head = self._classifier(
+             cls_num, self.output_size, linear_output
+         )
+
+     def forward(self, x):
+         if torch.cuda.is_available():
+             x = x.cuda()
+             self.model = self.model.cuda()
+
+         if self.m_type == "googlenet" and self.model.training:  # EvalNet itself is not an nn.Module
+             return self.model(x)[0]
+         else:
+             return self.model(x)
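`EvalNet._classifier` sizes its hidden layers from a single ratio: `q` is the fourth root of `output_size / cls_num`, so the head tapers roughly geometrically from the backbone's feature width down to the class count. A worked example with the defaults used here (`output_size=512`, `cls_num=16`):

```python
# Hidden-layer sizing in EvalNet._classifier, evaluated for the defaults.
output_size, cls_num = 512, 16
q = (1.0 * output_size / cls_num) ** 0.25  # 32 ** 0.25 ≈ 2.378
l1 = int(q * cls_num)  # 38
l2 = int(q * l1)       # 90
l3 = int(q * l2)       # 214
print(f"{output_size} -> {l3} -> {l2} -> {l1} -> {cls_num}")
# 512 -> 214 -> 90 -> 38 -> 16
```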
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ librosa
+ torch
+ matplotlib
+ torchvision
+ pillow
+ modelscope==1.15
utils.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import torch
+ import requests
+ from modelscope import snapshot_download
+
+ MODEL_DIR = snapshot_download(
+     "ccmusic-database/music_genre",
+     cache_dir="./__pycache__",
+ )
+
+
+ def toCUDA(x):
+     if hasattr(x, "cuda"):
+         if torch.cuda.is_available():
+             return x.cuda()
+
+     return x
+
+
+ def find_mp3_files(folder_path=f"{MODEL_DIR}/examples"):
+     mp3_files = []
+     for root, _, files in os.walk(folder_path):
+         for file in files:
+             if file.endswith(".mp3"):
+                 file_path = os.path.join(root, file)
+                 mp3_files.append(file_path)
+
+     return mp3_files
+
+
+ def get_modelist(model_dir=MODEL_DIR):
+     try:
+         entries = os.listdir(model_dir)
+     except OSError as e:
+         print(f"Cannot access {model_dir}: {e}")
+         return []
+
+     # Iterate over all entries
+     output = []
+     for entry in entries:
+         # Build the full path
+         full_path = os.path.join(model_dir, entry)
+
+         # Skip the .git and examples folders
+         if entry == ".git" or entry == "examples":
+             print(f"Skipping .git / examples folder: {full_path}")
+             continue
+
+         # Check whether the entry is a file or a directory
+         if os.path.isdir(full_path):
+             # Record the directory name
+             output.append(os.path.basename(full_path))
+
+     return output
+
+
+ def download(url: str):
+     filename = url.split("/")[-1]
+     response = requests.get(url, stream=True)
+     if response.status_code == 200:
+         with open(filename, "wb") as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
+
+         print(f"Downloaded to {os.getcwd()}/{filename}")
+     else:
+         print(f"Download failed with status code {response.status_code}")
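For reference, this is how `app.py` and `model.py` consume the folder names returned by `get_modelist()`: the final `_`-separated segment selects the spectrogram type (`mel`, `cqt`, or `chroma`) and the rest names a torchvision backbone. A small sketch; `squeezenet1_1_mel` is a hypothetical log name, the real ones come from the downloaded snapshot:

```python
# How a model log name is split by app.py / model.py; the name below is
# a hypothetical example, not necessarily present in the snapshot.
log_name = "squeezenet1_1_mel"
m_ver = "_".join(log_name.split("_")[:-1])  # "squeezenet1_1" -> models.squeezenet1_1()
spec = log_name.split("_")[-1]              # "mel" -> mp3_to_mel()
print(m_ver, spec)
```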