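# Source page header (web-viewer residue): author ilia_khristoforov,
# "На ветке pr/5" (on branch pr/5), commit 304e51f, file size 7.45 kB.
# Viewer links: raw | history | blame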
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import langchain
from langchain.agents import AgentType
from langchain.agents import create_csv_agent
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.schema import HumanMessage
from langchain.tools import StructuredTool
from langchain.vectorstores import Chroma
from PIL import Image

from utils.functions import Matcha_model


class Bot:
    """Question-answering bot that routes a question to text, table or chart tools.

    On construction it indexes ``text_documents`` in a Chroma vector store and
    wraps them in a retrieval-QA chain. On each call it asks the LLM which of the
    described files (if any) is relevant. It then adds a CSV-agent tool (``.csv``
    files) or a chart-QA tool (anything else) for that file, and runs a
    structured-chat agent over the selected tools.
    """

    # Canonical answer meaning "no described file is relevant".
    NO_FILE = "None of the files"

    def __init__(
        self,
        openai_api_key: str,
        file_descriptions: List[Dict[str, Any]],
        text_documents: "List[langchain.schema.Document]",
        verbose: bool = False,
    ):
        """Build the LLM, the text retriever/tool and the chart model.

        Args:
            openai_api_key: key passed to ``ChatOpenAI``.
            file_descriptions: one dict per file. The code reads the ``'number'``
                and ``'path'`` keys, and shows every key/value pair to the LLM
                when choosing a file.
            text_documents: documents indexed for the text-search tool.
            verbose: forwarded to the LangChain agents; also enables debug prints.
        """
        self.verbose = verbose
        self.file_descriptions = file_descriptions
        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0,
            model_name="gpt-3.5-turbo",
        )
        # Local sentence-transformer embeddings (OpenAIEmbeddings is the
        # imported alternative if remote embeddings are preferred).
        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        vector_store = Chroma.from_documents(text_documents, embedding_function)
        self.text_retriever = langchain.chains.RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type='stuff',
            retriever=vector_store.as_retriever(),
        )
        self.text_search_tool = langchain.agents.Tool(
            func=self._text_search,
            description="Use this tool when searching for text information",
            name="search text information",
        )
        self.chart_model = Matcha_model()

    def __call__(self, question: str) -> Dict[str, Any]:
        """Answer ``question`` and return the agent's result dict."""
        self.tools = [self.text_search_tool]
        file_description = self._find_file_description(
            self._define_appropriate_file(question)
        )
        if file_description is not None:
            file_path = file_description['path']
            # Route on the real file's extension. The LLM answer ("File №N")
            # never carries one, so checking it always picked the chart tool.
            if Path(file_path).suffix.lower() == '.csv':
                self.csv_agent = create_csv_agent(
                    llm=self.llm,
                    path=file_path,
                    verbose=self.verbose,
                )
                self._init_tabular_search_tool(file_description)
                self.tools.append(self.tabular_search_tool)
            else:
                self._init_chart_search_tool(file_description)
                self.tools.append(self.chart_search_tool)
        self._init_chatbot()
        return self.agent(question)

    def _find_file_description(self, file: str) -> Optional[Dict[str, Any]]:
        """Return the description matching ``file`` ("File №N"), or None.

        None is also returned when the LLM named a number that no description
        has, instead of raising IndexError.
        """
        if file == self.NO_FILE:
            return None
        number = self._parse_file_number(file)
        if number is None:
            return None
        # Compare as strings so numbers stored as int or str both match.
        return next(
            (d for d in self.file_descriptions if str(d['number']) == str(number)),
            None,
        )

    def _init_chatbot(self):
        """(Re)build ``self.agent`` over the current ``self.tools`` with a custom prefix."""
        # NOTE(review): a fresh memory is created here, and __call__ calls this
        # for every question, so chat history does not carry over between
        # questions. Confirm whether that is intended.
        conversational_memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=5,
            return_messages=True,
        )
        self.agent = langchain.agents.initialize_agent(
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            tools=self.tools,
            llm=self.llm,
            verbose=self.verbose,
            max_iterations=5,
            early_stopping_method='generate',
            memory=conversational_memory,
        )
        # Every fragment ends with a space so the concatenated sentences do
        # not run together.
        sys_msg = (
            "You are an expert summarizer and deliverer of information. "
            "Yet, the reason you are so intelligent is that you make complex "
            "information incredibly simple to understand. It's actually rather incredible. "
            "When users ask information you refer to the relevant tools. "
            "If one of the tools helped you with only a part of the necessary information, you must "
            "try to find the missing information using another tool. "
            "If you can't find the information using the provided tools, you MUST "
            "say 'I don't know'. Don't try to make up an answer."
        )
        # Replace the agent's default prompt prefix with the system message above.
        prompt = self.agent.agent.create_prompt(tools=self.tools, prefix=sys_msg)
        self.agent.agent.llm_chain.prompt = prompt

    def _text_search(self, query: str) -> str:
        """Answer ``query`` from the indexed text documents (sources are dropped)."""
        # Chain.__call__ runs prep_inputs itself, mapping a bare string onto the
        # chain's single input key, so no separate preparation step is needed.
        return self.text_retriever(query)['answer']

    def _tabular_search(self, query: str) -> str:
        """Answer ``query`` with the CSV agent built for the selected file."""
        return self.csv_agent.run(query)

    def _chart_search(self, image, query: str) -> str:
        """Ask the chart model ``query`` about the image at ``image`` (path or file object)."""
        # Close the image file once the model has answered.
        with Image.open(image) as chart:
            return self.chart_model.chart_qa(chart, query)

    def _init_chart_search_tool(
        self,
        title: Union[str, Dict[str, Any]],
        image_path: Optional[str] = None,
    ) -> None:
        """Create ``self.chart_search_tool`` for asking questions about one chart.

        Args:
            title: text inserted into the tool description. ``__call__`` passes
                the whole file-description dict; its ``'path'`` is used as the
                chart image when ``image_path`` is not given.
            image_path: chart image to bind to the tool.
        """
        if image_path is None and isinstance(title, dict):
            image_path = title.get('path')
        description = (
            "Use this tool when searching for information on charts.\n"
            "With this tool you can answer the question about related chart.\n"
            "You should ask simple question about a chart, then the tool will give you number.\n"
            f"This chart is called {title}.\n"
        )
        if image_path is None:
            # No known image: the agent must supply both `image` and `query`.
            func = self._chart_search
        else:
            bound_path = image_path

            def _ask_chart(query: str) -> str:
                # The agent never sees the file path, so bind it here and
                # let the agent supply only the question.
                return self._chart_search(bound_path, query)

            func = _ask_chart
        # StructuredTool needs an args schema; from_function infers it from func.
        self.chart_search_tool = StructuredTool.from_function(
            func=func,
            name="Ask over charts",
            description=description,
        )

    def _init_tabular_search_tool(self, file_: Dict[str, Any]) -> None:
        """Create ``self.tabular_search_tool`` backed by ``self._tabular_search``.

        NOTE(review): reads optional ``'title'`` and ``'columns'`` keys from the
        file description (assumed schema, confirm). It falls back to the file
        number and "unknown" when they are missing.
        """
        title = file_.get('title', f"File №{file_.get('number')}")
        columns = file_.get('columns')
        if isinstance(columns, (list, tuple)):
            columns = ', '.join(map(str, columns))
        elif columns is None:
            columns = 'unknown'
        description = (
            "Use this tool when searching for tabular information.\n"
            "With this tool you could get access to table.\n"
            f'This table title is "{title}" and the names of the columns in this table: {columns}\n'
        )
        self.tabular_search_tool = langchain.agents.Tool(
            func=self._tabular_search,
            description=description,
            name="search tabular information",
        )

    @staticmethod
    def _parse_file_number(answer: str) -> Optional[int]:
        """Extract N from text such as "File №N!"; return None if there is no number."""
        match = re.search(r'№\s*(\d+)', answer)
        return int(match.group(1)) if match else None

    @classmethod
    def _normalize_file_answer(cls, answer: str) -> str:
        """Map a free-form LLM reply to "File №N" or ``cls.NO_FILE``."""
        if cls.NO_FILE.lower() in answer.lower():
            return cls.NO_FILE
        number = cls._parse_file_number(answer)
        return cls.NO_FILE if number is None else f"File №{number}"

    def _define_appropriate_file(self, question: str) -> str:
        """Ask the LLM which described file may contain the answer to ``question``.

        Returns:
            ``"File №N"`` for a file number N, or ``"None of the files"``.
        """
        message = 'I have list of descriptions: \n'
        for index, description in enumerate(self.file_descriptions, start=1):
            message += f""" {index}) description for File №{description['number']}: """
            message += ''.join(
                f"{key!s} : {value!s}\n" for key, value in description.items()
            )
        if self.verbose:
            print(message)
        message += (
            f' How do you think, which file can help answer the question: "{question}" .\n'
            "Your answer MUST be specific,\n"
            "for example if you think that File №2 can help answer the question, "
            'you MUST just write "File №2!".\n'
            "If you think that none of the files can help answer the question "
            'just write "None of the files!"\n'
            "Don't include to answer information about your thinking.\n"
        )
        res = self.llm([HumanMessage(content=message)])
        if self.verbose:
            print(res.content)
        # Normalise instead of blindly dropping the last character: the reply
        # may lack the "!" or carry extra whitespace or text.
        return self._normalize_file_answer(res.content)