asigalov61 commited on
Commit
c1fc4e4
1 Parent(s): 40bf57f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -234
app.py CHANGED
@@ -1,4 +1,6 @@
 
1
  # https://huggingface.co/spaces/asigalov61/Chords-Progressions-Generator
 
2
 
3
  import os
4
  import time as reqtime
@@ -7,15 +9,19 @@ from pytz import timezone
7
 
8
  import gradio as gr
9
 
 
 
 
10
  import random
 
11
 
12
  import TMIDIX
13
 
14
- import numpy as np
15
-
16
  # =================================================================================================
17
 
18
- def ClassifyMIDI(input_midi, input_sampling_resolution):
19
 
20
  print('=' * 70)
21
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
@@ -33,105 +39,9 @@ def ClassifyMIDI(input_midi, input_sampling_resolution):
33
  print('Input MIDI file name:', fn)
34
 
35
  print('=' * 70)
36
- print('Loading MIDI file...')
37
-
38
- midi_name = fn
39
-
40
- raw_score = TMIDIX.midi2single_track_ms_score(open(input_midi.name, 'rb').read())
41
-
42
- escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
43
-
44
- #===============================================================================
45
- # Augmented enhanced score notes
46
-
47
- escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
48
-
49
- escore_notes = [e for e in escore_notes if e[6] < 80 or e[6] == 128]
50
-
51
- #=======================================================
52
- # Augmentation
53
-
54
- #=======================================================
55
- # FINAL PROCESSING
56
-
57
- melody_chords = []
58
-
59
- #=======================================================
60
- # MAIN PROCESSING CYCLE
61
- #=======================================================
62
-
63
- pe = escore_notes[0]
64
-
65
- pitches = []
66
-
67
- notes_counter = 0
68
-
69
- for e in escore_notes:
70
-
71
- #=======================================================
72
- # Timings...
73
-
74
- delta_time = max(0, min(127, e[1]-pe[1]))
75
-
76
- if delta_time != 0:
77
- pitches = []
78
-
79
- # Durations and channels
80
-
81
- dur = max(1, min(127, e[2]))
82
-
83
- # Patches
84
- pat = max(0, min(128, e[6]))
85
-
86
- # Pitches
87
-
88
- if pat == 128:
89
- ptc = max(1, min(127, e[4]))+128
90
- else:
91
- ptc = max(1, min(127, e[4]))
92
-
93
- #=======================================================
94
- # FINAL NOTE SEQ
95
-
96
- # Writing final note synchronously
97
-
98
- if ptc not in pitches:
99
- melody_chords.extend([delta_time, dur+128, ptc+256])
100
- pitches.append(ptc)
101
- notes_counter += 1
102
-
103
- pe = e
104
-
105
- #==============================================================
106
-
107
- print('Done!')
108
- print('=' * 70)
109
-
110
- print('Sampling score...')
111
-
112
- chunk_size = 1020
113
-
114
- score = melody_chords
115
-
116
- input_data = []
117
-
118
- for i in range(0, len(score)-chunk_size, chunk_size // input_sampling_resolution):
119
- schunk = score[i:i+chunk_size]
120
-
121
- if len(schunk) == chunk_size:
122
-
123
- td = [937]
124
-
125
- td.extend(schunk)
126
 
127
- td.extend([938])
128
-
129
- input_data.append(td)
130
-
131
  print('Done!')
132
  print('=' * 70)
133
-
134
- #==============================================================
135
 
136
  classification_summary_string = '=' * 70
137
  classification_summary_string += '\n'
@@ -153,79 +63,11 @@ def ClassifyMIDI(input_midi, input_sampling_resolution):
153
  classification_summary_string += '=' * 70
154
  classification_summary_string += '\n'
155
 
156
- print('Loading model...')
157
-
158
- SEQ_LEN = 1026
159
- PAD_IDX = 940
160
- DEVICE = 'cuda' # 'cuda'
161
-
162
- # instantiate the model
163
-
164
- model = TransformerWrapper(
165
- num_tokens = PAD_IDX+1,
166
- max_seq_len = SEQ_LEN,
167
- attn_layers = Decoder(dim = 1024, depth = 24, heads = 32, attn_flash = True)
168
- )
169
-
170
- model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
171
-
172
- model = torch.nn.DataParallel(model)
173
-
174
- model.to(DEVICE)
175
-
176
- print('=' * 70)
177
-
178
- print('Loading model checkpoint...')
179
-
180
- model.load_state_dict(
181
- torch.load('Ultimate_MIDI_Classifier_Trained_Model_29886_steps_0.556_loss_0.8339_acc.pth',
182
- map_location=DEVICE))
183
- print('=' * 70)
184
-
185
- if DEVICE == 'cpu':
186
- dtype = torch.bfloat16
187
- else:
188
- dtype = torch.bfloat16
189
-
190
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
191
-
192
- print('Done!')
193
- print('=' * 70)
194
-
195
  #==================================================================
196
 
197
  print('=' * 70)
198
  print('Ultimate MIDI Classifier')
199
  print('=' * 70)
200
- print('Classifying...')
201
-
202
- torch.cuda.empty_cache()
203
-
204
- model.eval()
205
-
206
- artist_results = []
207
- song_results = []
208
-
209
- results = []
210
-
211
- for input in input_data:
212
-
213
- x = torch.tensor(input[:1022], dtype=torch.long, device='cuda')
214
-
215
- with ctx:
216
- out = model.module.generate(x,
217
- 2,
218
- filter_logits_fn=top_k,
219
- filter_kwargs={'k': 1},
220
- temperature=0.9,
221
- return_prime=False,
222
- verbose=False)
223
-
224
- result = tuple(out[0].tolist())
225
-
226
- results.append(result)
227
-
228
- final_result = mode(results)
229
 
230
  print('=' * 70)
231
  print('Done!')
@@ -302,93 +144,39 @@ if __name__ == "__main__":
302
  print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
303
  print('=' * 70)
304
 
305
- #===============================================================================
306
- # Helper functions
307
- #===============================================================================
308
-
309
- def str_strip_song(string):
310
- if string is not None:
311
- string = string.replace('-', ' ').replace('_', ' ').replace('=', ' ')
312
- str1 = re.compile('[^a-zA-Z ]').sub('', string)
313
- return re.sub(' +', ' ', str1).strip().title()
314
- else:
315
- return ''
316
-
317
- def str_strip_artist(string):
318
- if string is not None:
319
- string = string.replace('-', ' ').replace('_', ' ').replace('=', ' ')
320
- str1 = re.compile('[^0-9a-zA-Z ]').sub('', string)
321
- return re.sub(' +', ' ', str1).strip().title()
322
- else:
323
- return ''
324
-
325
- def song_artist_to_song_artist_tokens(file_name):
326
- idx = classifier_labels.index(file_name)
327
-
328
- tok1 = idx // 424
329
- tok2 = idx % 424
330
-
331
- return [tok1, tok2]
332
-
333
- def song_artist_tokens_to_song_artist(file_name_tokens):
334
-
335
- tok1 = file_name_tokens[0]
336
- tok2 = file_name_tokens[1]
337
-
338
- idx = (tok1 * 424) + tok2
339
-
340
- return classifier_labels[idx]
341
-
342
- #===============================================================================
343
-
344
  print('=' * 70)
345
  print('Loading Ultimate MIDI Classifier labels...')
346
  print('=' * 70)
347
- classifier_labels = TMIDIX.Tegridy_Any_Pickle_File_Reader('Ultimate_MIDI_Classifier_Song_Artist_Labels')
348
- print('=' * 70)
349
- genre_labels = TMIDIX.Tegridy_Any_Pickle_File_Reader('Ultimate_MIDI_Classifier_Music_Genre_Labels')
350
- genre_labels_fnames = [f[0] for f in genre_labels]
351
  print('=' * 70)
352
  print('Done!')
353
  print('=' * 70)
354
 
355
  app = gr.Blocks()
356
  with app:
357
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Ultimate MIDI Classifier</h1>")
358
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Classify absolutely any MIDI by genre, song and artist</h1>")
359
  gr.Markdown(
360
- "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Ultimate-MIDI-Classifier&style=flat)\n\n"
361
- "This is a demo for Ultimate MIDI Classifier\n\n"
362
- "Check out [Ultimate MIDI Classifier](https://github.com/asigalov61/Ultimate-MIDI-Classifier) on GitHub!\n\n"
363
  "[Open In Colab]"
364
- "(https://colab.research.google.com/github/asigalov61/Ultimate-MIDI-Classifier/blob/main/Ultimate_MIDI_Classifier.ipynb)"
365
- " for all options, faster execution and endless classification"
366
  )
367
 
368
- gr.Markdown("## Upload any MIDI to classify")
369
- gr.Markdown("### Please note that the MIDI file must have at least 340 notes for this demo to work")
370
-
371
- input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
372
  input_sampling_resolution = gr.Slider(1, 5, value=2, step=1, label="Classification sampling resolution")
373
 
374
  run_btn = gr.Button("classify", variant="primary")
375
 
376
- gr.Markdown("## Classification results")
377
 
378
  output_midi_cls_summary = gr.Textbox(label="MIDI classification results")
379
 
380
  run_event = run_btn.click(ClassifyMIDI, [input_midi, input_sampling_resolution],
381
  [output_midi_cls_summary])
382
- gr.Examples(
383
- [["Honesty.kar", 2],
384
- ["House Of The Rising Sun.mid", 2],
385
- ["Nothing Else Matters.kar", 2],
386
- ["Sharing The Night Together.kar", 2]
387
- ],
388
- [input_midi, input_sampling_resolution],
389
- [output_midi_cls_summary],
390
- ClassifyMIDI,
391
- cache_examples=True,
392
- )
393
-
394
  app.queue().launch()
 
1
+ # =================================================================================================
2
  # https://huggingface.co/spaces/asigalov61/Chords-Progressions-Generator
3
+ # =================================================================================================
4
 
5
  import os
6
  import time as reqtime
 
9
 
10
  import gradio as gr
11
 
12
+ import numpy as np
13
+
14
+ import os
15
  import random
16
+ from collections import Counter
17
 
18
  import TMIDIX
19
 
20
+ from midi_to_colab_audio import midi_to_colab_audio
21
+
22
  # =================================================================================================
23
 
24
+ def Generate_Chords_Progression(input_midi, input_sampling_resolution):
25
 
26
  print('=' * 70)
27
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
 
39
  print('Input MIDI file name:', fn)
40
 
41
  print('=' * 70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
 
 
 
 
43
  print('Done!')
44
  print('=' * 70)
 
 
45
 
46
  classification_summary_string = '=' * 70
47
  classification_summary_string += '\n'
 
63
  classification_summary_string += '=' * 70
64
  classification_summary_string += '\n'
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  #==================================================================
67
 
68
  print('=' * 70)
69
  print('Ultimate MIDI Classifier')
70
  print('=' * 70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  print('=' * 70)
73
  print('Done!')
 
144
  print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
145
  print('=' * 70)
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  print('=' * 70)
148
  print('Loading Ultimate MIDI Classifier labels...')
149
  print('=' * 70)
150
+ good_chords_chunks = TMIDIX.Tegridy_Any_Pickle_File_Reader('pitches_chords_progressions_5_3_15')
151
+
 
 
152
  print('=' * 70)
153
  print('Done!')
154
  print('=' * 70)
155
 
156
  app = gr.Blocks()
157
  with app:
158
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Chords Progressions Generator</h1>")
159
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate unique chords progressions</h1>")
160
  gr.Markdown(
161
+ "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Chords-Progressions-Generator&style=flat)\n\n"
162
+ "This is a demo for Tegridy MIDI Dataset\n\n"
163
+ "Check out [Tegridy MIDI Dataset](https://github.com/asigalov61/Tegridy-MIDI-Dataset) on GitHub!\n\n"
164
  "[Open In Colab]"
165
+ "(https://colab.research.google.com/github/asigalov61/Tegridy-MIDI-Dataset/blob/master/Chords-Progressions/Pitches_Chords_Progressions_Generator.ipynb)"
166
+ " for all options, faster execution and endless generation"
167
  )
168
 
169
+ gr.Markdown("## Select generation options")
170
+
 
 
171
  input_sampling_resolution = gr.Slider(1, 5, value=2, step=1, label="Classification sampling resolution")
172
 
173
  run_btn = gr.Button("classify", variant="primary")
174
 
175
+ gr.Markdown("## Generation results")
176
 
177
  output_midi_cls_summary = gr.Textbox(label="MIDI classification results")
178
 
179
  run_event = run_btn.click(ClassifyMIDI, [input_midi, input_sampling_resolution],
180
  [output_midi_cls_summary])
181
+
 
 
 
 
 
 
 
 
 
 
 
182
  app.queue().launch()