# emotion_recognition_pipeline.py
# Custom AudioClassificationPipeline for the KELONMYOSA/wav2vec2-xls-r-300m-emotion-ru
# speech emotion recognition model (Russian).
from io import BytesIO

import librosa
import requests
import torch
import torch.nn.functional as F
from transformers import AudioClassificationPipeline, AutoConfig, Wav2Vec2Processor

# Run inference on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name_or_path = "KELONMYOSA/wav2vec2-xls-r-300m-emotion-ru"

# The config supplies the id2label mapping; the processor handles feature extraction.
config = AutoConfig.from_pretrained(model_name_or_path)
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path)
sampling_rate = processor.feature_extractor.sampling_rate


class SpeechEmotionRecognitionPipeline(AudioClassificationPipeline):
    def _sanitize_parameters(self, **pipeline_parameters):
        # No tunable parameters: preprocess, _forward and postprocess use defaults.
        return {}, {}, {}

    def preprocess(self, inputs):
        # Accept an http(s) URL, a local file path, or a file-like object.
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            # Download remote audio into memory; librosa reads the buffer directly.
            inputs = BytesIO(requests.get(inputs).content)
        # librosa handles both paths and file-like objects, resampling the audio
        # to the rate the model was trained with.
        speech, _ = librosa.load(inputs, sr=sampling_rate)
        features = processor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
        return features.input_values.to(device)

    def _forward(self, model_inputs):
        # The tensor is passed positionally as input_values.
        return self.model(model_inputs)

    def postprocess(self, model_outputs):
        # Convert logits to class probabilities and pair each with its emotion label.
        logits = model_outputs.logits
        scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
        return [
            {"label": config.id2label[i], "score": round(float(score), 5)}
            for i, score in enumerate(scores)
        ]
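

# Usage sketch. Assumptions not in the file above: the checkpoint loads via
# transformers' AutoModelForAudioClassification (swap in the concrete model
# class if the repo ships a custom head), and "example.wav" is a placeholder
# path.
if __name__ == "__main__":
    from transformers import AutoModelForAudioClassification

    model = AutoModelForAudioClassification.from_pretrained(model_name_or_path).to(device)
    pipe = SpeechEmotionRecognitionPipeline(
        model=model, feature_extractor=processor.feature_extractor
    )
    # preprocess() also accepts an http(s) URL or a file-like object.
    print(pipe("example.wav"))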