from huggingface_hub import from_pretrained_keras
import numpy as np
import gradio as gr
import transformers
import tensorflow as tf
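
# Gradio demo for semantic similarity (natural language inference) with a BERT
# model fine-tuned on the SNLI corpus, loaded from the Hugging Face Hub
# checkpoint keras-io/bert-semantic-similarity.
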
class BertSemanticDataGenerator(tf.keras.utils.Sequence):
    """Generates batches of data for the BERT semantic similarity model."""

    def __init__(
        self,
        sentence_pairs,
        labels,
        batch_size=32,
        shuffle=True,
        include_targets=True,
    ):
        self.sentence_pairs = sentence_pairs
        self.labels = labels
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.include_targets = include_targets
        # Load the BERT tokenizer to encode the text.
        # We use the bert-base-uncased pretrained model.
        self.tokenizer = transformers.BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )
        self.indexes = np.arange(len(self.sentence_pairs))
        self.on_epoch_end()

    def __len__(self):
        # Denotes the number of batches per epoch.
        return len(self.sentence_pairs) // self.batch_size

    def __getitem__(self, idx):
        # Retrieve the batch at the given index.
        indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]
        sentence_pairs = self.sentence_pairs[indexes]
        # With the tokenizer's batch_encode_plus, both sentences of each pair are
        # encoded together, separated by the [SEP] token.
        encoded = self.tokenizer.batch_encode_plus(
            sentence_pairs.tolist(),
            add_special_tokens=True,
            max_length=128,
            return_attention_mask=True,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_tensors="tf",
        )
        # Convert the batch of encoded features to numpy arrays.
        input_ids = np.array(encoded["input_ids"], dtype="int32")
        attention_masks = np.array(encoded["attention_mask"], dtype="int32")
        token_type_ids = np.array(encoded["token_type_ids"], dtype="int32")
        # Targets are only included when the generator is used for training/validation.
        if self.include_targets:
            labels = np.array(self.labels[indexes], dtype="int32")
            return [input_ids, attention_masks, token_type_ids], labels
        else:
            return [input_ids, attention_masks, token_type_ids]
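

# Illustrative use of the generator on its own (an assumption for clarity, not
# part of the original app flow). With include_targets=False, each item is a
# list of three int32 arrays of shape (batch_size, 128):
# input_ids, attention_masks, token_type_ids.
#
#   pairs = np.array([["A man is playing guitar.", "A person plays an instrument."]])
#   gen = BertSemanticDataGenerator(pairs, labels=None, batch_size=1,
#                                   shuffle=False, include_targets=False)
#   features = gen[0]  # [input_ids, attention_masks, token_type_ids]
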
model = from_pretrained_keras("keras-io/bert-semantic-similarity")
labels = ["contradiction", "entailment", "neutral"]


def predict(*sentences):
    if len(sentences) != 6:
        return {"error": "6 sentences are expected"}
    # Pair each user sentence with its expected response (defined below; the
    # module-level name is resolved when the function is called).
    sentence_pairs = np.array(
        [[str(sentences[i]), str(expected_responses[i])] for i in range(6)]
    )
    test_data = BertSemanticDataGenerator(
        sentence_pairs, labels=None, batch_size=1, shuffle=False, include_targets=False,
    )
    # Score the first batch from the generator (a single sentence pair).
    probs = model.predict(test_data[0])[0]
    labels_probs = {labels[i]: float(probs[i]) for i, _ in enumerate(labels)}
    return labels_probs
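

# The returned dict maps each NLI label to its probability and is rendered by the
# Label output, e.g. {"contradiction": ..., "entailment": ..., "neutral": ...}
# (the float values depend on the model and sum to 1).
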
expected_responses = [
    'respuesta1a', 'respuesta2a', 'respuesta3a', 'respuesta4a', 'respuesta5a', 'respuesta6a'
]
# Each example must supply one value per input textbox (six in total).
examples = [
    [
        "Two women are observing something together.",
        "A smiling costumed woman is holding an umbrella",
        "A soccer game with multiple males playing",
        "Some men are playing a sport",
        "Another example sentence",
        "One more example for the sixth input",
    ]
]
# Gradio interface
gr.Interface(
    fn=predict,
    title="Semantic Similarity with BERT",
    description="Natural Language Inference by fine-tuning BERT model on SNLI Corpus 📰",
    inputs=[gr.Textbox(label=f"Input {i+1}") for i in range(6)],
    examples=examples,
    outputs=gr.Label(num_top_classes=3, label="Semantic similarity"),
    cache_examples=False,
    article='Author: <a href="https://huggingface.co/vumichien">Vu Minh Chien</a>. Based on the keras example from <a href="https://keras.io/examples/nlp/semantic_similarity_with_bert/">Mohamad Merchant</a>',
).queue().launch(debug=True)