import pandas as pd
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from pathlib import Path


def make_descriptions(file, title):
    """Build lightweight metadata descriptions for a CSV table or a generic file."""
    if Path(file).suffix == '.csv':
        df = pd.read_csv(file)
        print(df.head())
        columns = list(df.columns)
        print(columns)

        # Dummy description kept as a distractor alongside the real table.
        table_description0 = {
            'path': 'random',
            'number': 1,
            'columns': ["clothes", "animals", "students"],
            'title': "fashionable student clothes"
        }

        # Description of the actual CSV file.
        table_description1 = {
            'path': file,
            'number': 2,
            'columns': columns,
            'title': title
        }

        table_descriptions = [table_description0, table_description1]
        return table_descriptions
    else:
        file_description = {
            'path': file,
            'number': 1,
            'title': title
        }
        file_descriptions = [file_description]
        return file_descriptions


def make_documents(pdf):
    """Load a PDF and split it into newline-separated chunks of ~500 characters."""
    loader = PyPDFLoader(pdf)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0, separator='\n')
    documents = text_splitter.split_documents(documents)
    return documents


class Matcha_model:
    """Wrapper around google/matcha-chartqa for answering questions about chart images."""

    def __init__(self) -> None:
        # Example chart images that can be downloaded for quick testing:
        # torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/20294671002019.png', 'chart_example.png')
        # torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/multi_col_1081.png', 'chart_example_2.png')
        # torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/18143564004789.png', 'chart_example_3.png')
        # torch.hub.download_url_to_file('https://sharkcoder.com/files/article/matplotlib-bar-plot.png', 'chart_example_4.png')

        self.model_name = "google/matcha-chartqa"
        self.model = Pix2StructForConditionalGeneration.from_pretrained(self.model_name)
        self.processor = Pix2StructProcessor.from_pretrained(self.model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _filter_output(self, output):
        # Strip the newline token the model sometimes emits in decoded output.
        return output.replace("<0x0A>", "")

    def chart_qa(self, image, question: str) -> str:
        inputs = self.processor(images=image, text=question, return_tensors="pt").to(self.device)
        predictions = self.model.generate(**inputs, max_new_tokens=512)
        return self._filter_output(self.processor.decode(predictions[0], skip_special_tokens=True))
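

# --- Usage sketch (assumption: the file paths and the question below are placeholders,
# --- and the chart image is loaded with Pillow; adapt to your own data) ---
if __name__ == "__main__":
    from PIL import Image

    # Describe a CSV table and split a PDF into chunks (hypothetical paths/titles).
    descriptions = make_descriptions("data/sales.csv", "quarterly sales")
    chunks = make_documents("data/report.pdf")
    print(len(descriptions), len(chunks))

    # Ask a question about a chart image (hypothetical image file).
    qa = Matcha_model()
    chart = Image.open("chart_example.png")
    print(qa.chart_qa(chart, "Which category has the highest value?"))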