File size: 7,451 Bytes
304e51f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import csv
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import langchain
from langchain.agents import AgentType
from langchain.agents import create_csv_agent
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.schema import HumanMessage
from langchain.tools import StructuredTool
from langchain.vectorstores import Chroma
from PIL import Image

from utils.functions import Matcha_model

class Bot:
    """Agent that answers questions from text documents, CSV tables and charts.

    On every call the bot asks the LLM which of the described files (if any)
    can help with the question, builds a tool for that file (a CSV agent for
    ``.csv`` paths, the chart model for anything else) and runs a
    structured-chat agent equipped with that tool plus the text-search tool.
    """

    # Pulls the file number out of replies such as "File №2!".
    _FILE_NUMBER_RE = re.compile(r"№\s*(\d+)")
    # Reply the LLM is told to give when no file is relevant (without the "!").
    _NO_FILE_REPLY = "None of the files"

    def __init__(
            self,
            openai_api_key: str,
            file_descriptions: List[Dict[str, Any]],
            text_documents: "List[langchain.schema.Document]",
            verbose: bool = False
    ):
        """Build the LLM, the text retriever and the chart model.

        Args:
            openai_api_key: Key passed to ``ChatOpenAI``.
            file_descriptions: One dict per file. ``'number'`` and ``'path'``
                are read by ``__call__``; every key/value pair is shown to
                the LLM when it picks a file. ``'title'`` and ``'columns'``
                are used for tool descriptions when present
                (NOTE(review): these two key names are an assumption --
                confirm against the code that builds the descriptions).
            text_documents: Documents indexed into the Chroma vector store.
            verbose: Forwarded to the agents; also enables debug printing of
                the file-selection prompt and reply.
        """
        self.verbose = verbose
        self.file_descriptions = file_descriptions

        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0,
            model_name="gpt-3.5-turbo"
        )
        # Documents are embedded with a local sentence-transformer model
        # rather than with OpenAI embeddings.
        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        vector_store = Chroma.from_documents(text_documents, embedding_function)
        self.text_retriever = langchain.chains.RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type='stuff',
            retriever=vector_store.as_retriever()
        )
        self.text_search_tool = langchain.agents.Tool(
            func=self._text_search,
            description="Use this tool when searching for text information",
            name="search text information"
        )

        self.chart_model = Matcha_model()

    def __call__(
            self,
            question: str
    ):
        """Answer ``question`` with the text tool plus one file-specific tool.

        The tool list and the agent are rebuilt on every call because the
        file-specific tool depends on the question.

        Args:
            question: The user's question.

        Returns:
            The value returned by the agent executor for ``question``.
        """
        self.tools = [self.text_search_tool]

        file = self._define_appropriate_file(question)
        file_description = self._find_description(self._parse_file_number(file))

        # No usable file (none relevant, unparsable reply, or unknown number):
        # the agent simply runs with the text tool only.
        if file_description is not None:
            file_path = file_description['path']

            # The suffix has to come from the real path; the LLM reply
            # ("File №2") never carries a file extension.
            if Path(file_path).suffix.lower() == '.csv':
                self.csv_agent = create_csv_agent(
                    llm=self.llm,
                    path=file_path,
                    verbose=self.verbose
                )

                self._init_tabular_search_tool(file_description)
                self.tools.append(self.tabular_search_tool)
            else:
                self._init_chart_search_tool(file_description)
                self.tools.append(self.chart_search_tool)

        self._init_chatbot()
        return self.agent(question)

    @classmethod
    def _parse_file_number(cls, reply: str) -> Optional[int]:
        """Return the file number named in an LLM reply, or ``None``.

        ``None`` is returned for an empty reply, for the "None of the files"
        reply, and for any reply that contains no "№<digits>" token.
        """
        if not reply or cls._NO_FILE_REPLY.lower() in reply.lower():
            return None
        match = cls._FILE_NUMBER_RE.search(reply)
        return int(match.group(1)) if match else None

    def _find_description(self, number: Optional[int]) -> Optional[Dict[str, Any]]:
        """Return the file description whose ``'number'`` equals ``number``.

        Returns ``None`` when ``number`` is ``None`` or no description matches.
        """
        if number is None:
            return None
        return next(
            (d for d in self.file_descriptions if d['number'] == number),
            None
        )

    @staticmethod
    def _file_title(file_: Dict[str, Any]) -> str:
        """Return a human-readable title for a file description.

        Uses the ``'title'`` entry when present, otherwise the stem of
        ``'path'``, otherwise "File №<number>".
        """
        title = file_.get('title')
        if title:
            return str(title)
        path = file_.get('path')
        if path:
            return Path(path).stem
        return f"File №{file_.get('number')}"

    @staticmethod
    def _read_csv_header(path) -> List[str]:
        """Return the first row of the CSV file at ``path`` (``[]`` if empty)."""
        with open(path, newline='', encoding='utf-8') as csv_file:
            return next(csv.reader(csv_file), [])

    def _init_chatbot(self):
        """Create ``self.agent`` over ``self.tools`` and install the system prompt."""
        # NOTE(review): the memory object is created anew on every call, so
        # no history is carried over between questions -- confirm intended.
        conversational_memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=5,
            return_messages=True
        )

        self.agent = langchain.agents.initialize_agent(
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            tools=self.tools,
            llm=self.llm,
            verbose=self.verbose,
            max_iterations=5,
            early_stopping_method='generate',
            memory=conversational_memory
        )
        # Adjacent literals are concatenated verbatim, so each sentence must
        # end with its own separating space.
        sys_msg = (
            "You are an expert summarizer and deliverer of information. "
            "Yet, the reason you are so intelligent is that you make complex "
            "information incredibly simple to understand. It's actually rather incredible. "
            "When users ask information you refer to the relevant tools. "
            "if one of the tools helped you with only a part of the necessary information, you must "
            "try to find the missing information using another tool. "
            "if you can't find the information using the provided tools, you MUST "
            "say 'I don't know'. Don't try to make up an answer."
        )
        prompt = self.agent.agent.create_prompt(
            tools=self.tools,
            prefix=sys_msg
        )
        self.agent.agent.llm_chain.prompt = prompt

    def _text_search(
            self,
            query: str
    ) -> str:
        """Run ``query`` through the retrieval chain and return its ``'answer'``."""
        query = self.text_retriever.prep_inputs(query)
        res = self.text_retriever(query)['answer']
        return res

    def _tabular_search(
            self,
            query: str
    ) -> str:
        """Run ``query`` through the CSV agent created in ``__call__``."""
        res = self.csv_agent.run(query)
        return res

    def _chart_search(
            self,
            image,
            query: str
    ) -> str:
        """Ask the chart model ``query`` about the image at ``image``.

        Args:
            image: Anything accepted by ``PIL.Image.open`` (e.g. a path).
            query: Question about the chart.
        """
        # The context manager closes the underlying file once the model is done.
        with Image.open(image) as chart_image:
            return self.chart_model.chart_qa(chart_image, query)

    def _init_chart_search_tool(
            self,
            title: Union[str, Dict[str, Any]]
    ) -> None:
        """Create ``self.chart_search_tool``.

        Args:
            title: Either the chart title, or a file description dict (this is
                what ``__call__`` passes). With a dict, the title is derived
                from it and its ``'path'`` is bound to the tool, so the agent
                only has to supply the question. With a plain title the tool
                exposes both ``image`` and ``query``.
        """
        image_path = None
        if isinstance(title, dict):
            image_path = title.get('path')
            title = self._file_title(title)

        description = f"""
            Use this tool when searching for information on charts.
            With this tool you can answer the question about related chart. 
            You should ask simple question about a chart, then the tool will give you number.
            This chart is called {title}.
        """

        if image_path is not None:
            def ask_chart(query: str) -> str:
                """Answer ``query`` about the chart selected for this question."""
                return self._chart_search(image_path, query)

            tool_func = ask_chart
        else:
            tool_func = self._chart_search

        # ``from_function`` derives the argument schema from the callable's
        # signature; constructing StructuredTool directly requires one.
        self.chart_search_tool = StructuredTool.from_function(
            func=tool_func,
            name="Ask over charts",
            description=description
        )

    def _init_tabular_search_tool(
            self,
            file_: Dict[str, Any]
    ) -> None:
        """Create ``self.tabular_search_tool`` for the table described by ``file_``.

        The tool description names the table title and its columns. The
        columns come from ``file_['columns']`` when present, otherwise from
        the header row of the CSV at ``file_['path']``.
        """
        title = self._file_title(file_)
        columns = file_.get('columns')
        if columns is None:
            columns = self._read_csv_header(file_['path'])

        description = f"""
            Use this tool when searching for tabular information.
            With this tool you could get access to table.
            This table title is "{title}" and the names of the columns in this table: {columns} 
        """

        self.tabular_search_tool = langchain.agents.Tool(
            func=self._tabular_search,
            description=description,
            name="search tabular information"
        )

    def _define_appropriate_file(
            self,
            question: str
    ) -> str:
        """Ask the LLM which described file may contain the answer.

        Every file description is listed in the prompt and the model is told
        to reply with either "File №<n>!" or "None of the files!".

        Returns:
            The model's reply with surrounding whitespace and trailing
            "!"/"." removed, e.g. "File №1" or "None of the files".
        """
        parts = ['I have list of descriptions: \n']
        for k, description in enumerate(self.file_descriptions, start=1):
            parts.append(f"""  {k}) description for File №{description['number']}: """)
            parts.extend(
                f"{key!s} : {value!s}\n" for key, value in description.items()
            )
        message = ''.join(parts)
        if self.verbose:
            print(message)

        question = f""" How do you think, which file can help answer the question: "{question}" .
        Your answer MUST be specific, 
        for example if you think that File №2 can help answer the question, you MUST just write  "File №2!". 
        If you think that none of the files can help answer the question just write "None of the files!"
        Don't include to answer information about your thinking.
        """
        message += question

        res = self.llm([HumanMessage(content=message)])
        if self.verbose:
            print(res.content)

        # Strip the requested "!" only if it is really there; blindly cutting
        # the last character would eat the file number when it is missing.
        return res.content.strip().rstrip('!.').strip()