# michalilski's picture
# nlg models removal
# dc7ce01
# raw
# history blame
# No virus
# 3.22 kB
import os
from typing import Any, Dict
from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
pipeline)
# Access token read from the CLARIN_KNEXT environment variable; it is None
# when the variable is not set. It is forwarded to every from_pretrained call
# in this module.
auth_token = os.getenv("CLARIN_KNEXT")
# Example prompts, one per language. Each prompt is a single string made of
# three tagged segments: the user utterance ([U]), a domain description
# ([Dziedzina] / [Domain]) and a slot description ([Atrybut] / [Slot]).
_POLISH_EXAMPLE_PARTS = (
    "[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. ",
    "[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji ",
    "[Atrybut] Czas: Wstępny czas rezerwacji restauracji",
)
_ENGLISH_EXAMPLE_PARTS = (
    "[U] I want to book a table for 4 people on Friday, 6:30 pm. ",
    "[Domain] Restaurants: A popular restaurant search and reservation service ",
    "[Slot] Time: Tentative time of restaurant reservation",
)

# Maps a language name to its default example prompt.
DEFAULT_DST_INPUTS: Dict[str, str] = {
    "polish": "".join(_POLISH_EXAMPLE_PARTS),
    "english": "".join(_ENGLISH_EXAMPLE_PARTS),
}
# Registry of the available DST checkpoints, keyed by display name.
# Each entry holds:
#   "model"         - the T5ForConditionalGeneration instance,
#   "tokenizer"     - the matching T5Tokenizer,
#   "default_input" - the example prompt from DEFAULT_DST_INPUTS.
# Every model and tokenizer is loaded with from_pretrained while this module
# is being imported, in the order listed below (model first, then tokenizer).
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token` - confirm against the pinned transformers version.
DST_MODELS: Dict[str, Dict[str, Any]] = {
    display_name: {
        "model": T5ForConditionalGeneration.from_pretrained(repo_id, use_auth_token=auth_token),
        "tokenizer": T5Tokenizer.from_pretrained(repo_id, use_auth_token=auth_token),
        "default_input": DEFAULT_DST_INPUTS[language],
    }
    # (display name, repository id, language of the default example prompt)
    for display_name, repo_id, language in (
        ("plt5-small", "clarin-knext/plt5-small-dst", "polish"),
        ("plt5-base", "clarin-knext/plt5-base-dst", "polish"),
        ("plt5-base-poquad-dst-v2", "clarin-knext/plt5-base-poquad-dst-v2", "polish"),
        ("t5-small", "clarin-knext/t5-small-dst", "english"),
        ("t5-base", "clarin-knext/t5-base-dst", "english"),
        ("flant5-small [EN/PL]", "clarin-knext/flant5-small-dst", "english"),
        ("flant5-base [EN/PL]", "clarin-knext/flant5-base-dst", "english"),
    )
}
# One "text2text-generation" pipeline per registered model, keyed by the same
# display names as DST_MODELS and built from each entry's model and tokenizer.
PIPELINES: Dict[str, Pipeline] = {
    display_name: pipeline(
        "text2text-generation",
        model=entry["model"],
        tokenizer=entry["tokenizer"],
    )
    for display_name, entry in DST_MODELS.items()
}