# talk-to-data/src/inference.py
from typing import Dict, Tuple

import pandas as pd
from transformers import TapasForQuestionAnswering, TapasTokenizer

from src.constants import id2aggregation


def infer(
    query: str, file_name: str, model_name: str = "google/tapas-base-finetuned-wtq"
) -> Tuple[Dict[str, str], pd.DataFrame]:
    # Load the table and cast every cell to string, as TAPAS expects textual input
    table = pd.read_csv(file_name, delimiter=",")
    table = table.astype(str)

    # Load the model and tokenizer
    model = TapasForQuestionAnswering.from_pretrained(model_name)
    tokenizer = TapasTokenizer.from_pretrained(model_name)

    # Tokenize the table together with the query and run the model
    queries = [query]
    inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
    outputs = model(**inputs)

    # predicted_answer_coordinates: (row, column) coordinates of the answer cells per query
    # predicted_aggregation_indices: index of the aggregation operator predicted for each query
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs, outputs.logits.detach(), outputs.logits_aggregation.detach()
    )
    aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]

    # Map the predicted cell coordinates back to the table values
    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            # Only a single cell
            answers.append(table.iat[coordinates[0]])
        else:
            # Multiple cells: join their values into one string
            cell_values = []
            for coordinate in coordinates:
                cell_values.append(table.iat[coordinate])
            answers.append(", ".join(cell_values))

    # Build the answer string, prefixing the aggregation operator when one is predicted
    answer_str = ""
    for query, answer, predicted_agg in zip(queries, answers, aggregation_predictions_string):
        if predicted_agg == "NONE":
            answer_str = answer
        else:
            answer_str = f"{predicted_agg} : {answer}"

    return {
        "query": query,
        "answer": answer_str,
    }, table
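

# A minimal usage sketch (assumptions: a local "sales.csv" file exists and the
# example query matches its columns; both are hypothetical and only illustrate
# how infer() might be called, they are not part of the original module).
if __name__ == "__main__":
    result, table = infer(
        query="What is the total revenue?",
        file_name="sales.csv",
    )
    # result is a dict with the query and its (optionally aggregated) answer;
    # table is the stringified DataFrame the model reasoned over.
    print(result["query"])
    print(result["answer"])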