# docker01 / app.py
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
from threading import Thread
import uvicorn
import requests
# Configure FastAPI
app = FastAPI()
# Load the model and tokenizer
model_name = "mdarhri00/named-entity-recognition"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
class TextInput(BaseModel):
    text: str
@app.post("/predict")
async def predict(input: TextInput):
    text = input.text
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt")
    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
    # Process the results
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)
    # Map label ids to label names
    id2label = model.config.id2label
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    entities = [
        {"token": token, "label": id2label[prediction.item()]}
        for token, prediction in zip(tokens, predictions[0])
    ]
    return {"entities": entities}
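# Example (illustrative; the sentence and labels below are not from the original file):
# a request such as
#   POST /predict   {"text": "Hugging Face is based in New York."}
# returns one entry per tokenizer token,
#   {"entities": [{"token": "Hugging", "label": "..."}, ...]}
# where the labels come from model.config.id2label, so the exact label set
# depends on the mdarhri00/named-entity-recognition checkpoint.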
# Start the FastAPI server in a separate thread
def start_api():
    uvicorn.run(app, host="0.0.0.0", port=8000)
api_thread = Thread(target=start_api, daemon=True)
api_thread.start()
# Configure Gradio
def predict_gradio(text):
    # Make sure this URL points at the deployed Space
    response = requests.post(
        "https://asmalljob-docker01.hf.space/predict", json={"text": text}
    )
    entities = response.json().get("entities", [])
    return entities
demo = gr.Interface(fn=predict_gradio, inputs="text", outputs="json")
demo.launch(share=True)
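# Quick local check (a sketch, assuming the uvicorn thread is up on port 8000;
# the example sentence is made up):
#   requests.post("http://localhost:8000/predict",
#                 json={"text": "Barack Obama visited Paris."}).json()
# Note that predict_gradio above posts to the deployed Space URL rather than
# to localhost, so the UI only returns results once the Space is reachable.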