# Hugging Face Space "QA" — app.py (commit 1b276b4, "Update app.py", by kwojtasik; 1.61 kB).
import os
import gradio as gr
from transformers import T5ForConditionalGeneration, AutoTokenizer
# from transformers import pipeline
# Access token for the model repository, taken from the CLARIN_KNEXT
# environment variable; None when the variable is not set.
auth_token = os.getenv("CLARIN_KNEXT")

# Checkpoint served by this app (presumably a Polish T5-large fine-tuned on
# PoQuAD, judging by the name — confirm on the model card).
# Alternative checkpoint: "clarin-knext/plt5-large-poquad-ext-qa-autotoken"
model_name = "clarin-knext/plt5-large-poquad"

# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token`; switch once the pinned transformers version supports it.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_auth_token=auth_token,
)
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    use_auth_token=auth_token,
)
# Decoding settings passed to every model.generate() call: beam search with
# two beams, output capped at 192 tokens, no length normalisation of beam
# scores (length_penalty=0), and early stopping once enough beams finish.
default_generate_kwargs = dict(
    max_length=192,
    num_beams=2,
    length_penalty=0,
    early_stopping=True,
)
# Sample inputs shown in the UI; each entry is [question, context], matching
# the order of the Interface inputs.
examples = [
    [
        "Jakie miasto jest stolicą Polski?",
        "Polska ma wiele wspaniałych miast, Wrocław, Poznań czy Gdańsk. Jednak stolicą jest Warszawa.",
    ],
]
def generate(question, context):
    """Answer ``question`` from ``context`` with the loaded seq2seq QA model.

    The two strings are packed into one ``"question: ... context: ... </s>"``
    prompt, tokenized (at most 512 tokens), decoded with
    ``default_generate_kwargs``, and the first returned sequence is turned
    back into text.

    Args:
        question: The question to answer.
        context: The passage the answer should be drawn from.

    Returns:
        The decoded answer string, with special tokens stripped.
    """
    # NOTE(review): the prompt ends with a literal "</s>" while the tokenizer
    # is also asked to add special tokens, which may produce a doubled EOS
    # depending on the tokenizer implementation. Kept as-is on the assumption
    # that it mirrors the fine-tuning input format — confirm before changing.
    prompt = f"question: {question} context: {context} </s>"

    # One unpadded sequence as PyTorch tensors. Anything beyond 512 tokens is
    # truncated (by default from the end), so an answer located late in a
    # very long context cannot be produced.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False,
        add_special_tokens=True,
    )

    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        **default_generate_kwargs,
    )

    # A single prompt was submitted, so the first returned sequence is the answer.
    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return answer
# Web UI: the two text boxes are passed positionally to generate(question,
# context); its return value fills the "Answer" box.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=1, label="Question"),
        gr.Textbox(lines=5, label="Context"),
    ],
    outputs=gr.Textbox(label="Answer"),
    examples=examples,
)

# Start the Gradio app.
demo.launch()