Kamtera committed on
Commit 056c529
1 Parent(s): 020636d

Create new file

Files changed (1)
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
import pickle

import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from utils import SpeechRecognition

# Local helper (utils.py) used by the third tab.
sp = SpeechRecognition()
sp.load_model()

model_name = "voidful/wav2vec2-xlsr-multilingual-56"
processor_name = "voidful/wav2vec2-xlsr-multilingual-56"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Mapping from language code to the token ids allowed for that language,
# used to restrict decoding in the "Manual" tab.
with open("lang_ids.pk", "rb") as f:
    lang_ids = pickle.load(f)

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)

model.eval()


def load_file_to_data(file, sampling_rate=16_000):
    """Load an audio file and resample it to the 16 kHz rate the model expects."""
    batch = {}
    speech, orig_freq = torchaudio.load(file)
    if orig_freq != 16_000:
        resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=16_000)
        batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
        batch["sampling_rate"] = resampler.new_freq
    else:
        batch["speech"] = speech.squeeze(0).numpy()
        batch["sampling_rate"] = 16_000
    return batch


def predict(data):
    """Transcribe an audio file, letting the model use its full multilingual vocabulary."""
    data = load_file_to_data(data, sampling_rate=16_000)
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # Keep only frames whose prediction is not the padding token (id 0).
        mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax(torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1)
        comb_pred_ids = torch.argmax(voice_prob, dim=-1)
        decoded_results.append(processor.decode(comb_pred_ids))

    return decoded_results


def predict_lang_specific(data, lang_code):
    """Transcribe an audio file, restricting decoding to the tokens of one language."""
    data = load_file_to_data(data, sampling_rate=16_000)
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        mask = ~pred_ids.eq(processor.tokenizer.pad_token_id).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax(torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1)
        filtered_input = pred_ids[pred_ids != processor.tokenizer.pad_token_id].view(1, -1).to(device)
        if len(filtered_input[0]) == 0:
            decoded_results.append("")
        else:
            # Zero out the probability of every token that does not belong to lang_code.
            lang_mask = torch.empty(voice_prob.shape[-1]).fill_(0)
            lang_index = torch.tensor(sorted(lang_ids[lang_code]))
            lang_mask.index_fill_(0, lang_index, 1)
            lang_mask = lang_mask.to(device)
            comb_pred_ids = torch.argmax(lang_mask * voice_prob, dim=-1)
            decoded_results.append(processor.decode(comb_pred_ids))

    return decoded_results


def recognition(audio_file):
    """Transcribe an uploaded file with the SpeechRecognition helper from utils.py."""
    print("audio_file", audio_file.name)
    speech, rate = sp.load_speech_with_file(audio_file.name)

    result = sp.predict_audio_file(speech)
    print(result)

    return result


# predict('audio file path')  # beware of the audio file sampling rate

# predict_lang_specific('audio file path', 'en')  # beware of the audio file sampling rate

with gr.Blocks() as demo:
    gr.Markdown("Multilingual Speech Recognition")
    with gr.Tab("Auto"):
        gr.Markdown("Automatically detects your language")
        inputs_speech = gr.Audio(source="upload", type="filepath", optional=True)
        output_transcribe = gr.HTML(label="")
        transcribe_audio = gr.Button("Submit")
    with gr.Tab("Manual"):
        gr.Markdown("Set your speech language")
        inputs_speech1 = [
            gr.Audio(source="upload", type="filepath"),
            gr.Dropdown(choices=["ar","as","br","ca","cnh","cs","cv","cy","de","dv","el","en","eo","es","et","eu","fa","fi","fr","fy-NL","ga-IE","hi","hsb","hu","ia","id","it","ja","ka","ky","lg","lt","lv","mn","mt","nl","or","pa-IN","pl","pt","rm-sursilv","rm-vallader","ro","ru","sah","sl","sv-SE","ta","th","tr","tt","uk","vi","zh-CN","zh-HK","zh-TW"],
                        value="fa", label="language code"),
        ]
        output_transcribe1 = gr.Textbox(label="output")
        transcribe_audio1 = gr.Button("Submit")
    with gr.Tab("Auto1"):
        gr.Markdown("Automatically detects your language")
        inputs_speech2 = gr.Audio(label="Input Audio", type="file")
        output_transcribe2 = gr.Textbox()
        transcribe_audio2 = gr.Button("Submit")

    transcribe_audio.click(fn=predict,
                           inputs=inputs_speech,
                           outputs=output_transcribe)

    transcribe_audio1.click(fn=predict_lang_specific,
                            inputs=inputs_speech1,
                            outputs=output_transcribe1)

    transcribe_audio2.click(fn=recognition,
                            inputs=inputs_speech2,
                            outputs=output_transcribe2)


if __name__ == "__main__":
    demo.launch(share=True)
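
The commented-out calls above hint at how the two prediction paths can be exercised without the Gradio UI. A minimal sketch, assuming the module-level globals in this file are already loaded and a local audio file exists (the path below is a placeholder, not part of the repo):

    # Hypothetical quick check; load_file_to_data resamples the file to 16 kHz internally.
    auto_text = predict("my_recording.wav")                        # auto language detection
    fa_text = predict_lang_specific("my_recording.wav", "fa")      # restrict decoding to Persian
    print(auto_text, fa_text)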