test02 / model.py
manishgupta006's picture
Update model.py
9789b35 verified
raw
history blame contribute delete
No virus
3.08 kB
import os

import requests
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI  # model server
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_groq import ChatGroq

from config import app_config
import mongo_utils as mongo
# Groq API key, read from the environment rather than hard-coded in source.
# NOTE: the key that used to be committed here is public and must be revoked.
# Evaluates to None when the GROQ_API_KEY environment variable is unset.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
def __image2text(image):
    """Return a short text description (caption) of an image.

    Posts the raw image bytes to ``app_config.I2T_API_URL`` with
    ``app_config.HF_TOKEN`` as the ``Authorization`` header, then reads the
    ``generated_text`` field of the first element of the JSON response.

    Args:
        image: Raw image bytes, sent as the request body.

    Returns:
        The caption string. On any failure (network error, non-JSON body, or a
        payload that is not ``[{"generated_text": ...}, ...]``) the error is
        printed and an empty string is returned, so callers always get a str.
    """
    headers = {"Authorization": app_config.HF_TOKEN}
    try:
        response = requests.post(app_config.I2T_API_URL, headers=headers, data=image)
        caption = response.json()[0]["generated_text"]
    except Exception as e:
        # Best-effort: report and fall back to an empty caption instead of
        # returning an unbound name or the raw Response object.
        print(e)
        return ""
    return caption
def __text2story(image_desc, genre, style, word_count, creativity):
    """Generate a short story from an image-description text prompt.

    Args:
        image_desc: Description of the image, used as the story context.
        genre: Story genre inserted into the system prompt.
        style: Writing style inserted into the system prompt.
        word_count: Maximum story length, in words, requested in the prompt.
        creativity: Sampling temperature passed to the chat model.

    Returns:
        The output of ``LLMChain.run`` for the composed prompt (the story).
    """
    # Chat LLM served by Groq. `creativity` is used as the temperature, as the
    # earlier OpenAI-based setup did; it was previously ignored (hard-coded 0.0).
    story_model = ChatGroq(
        model="llama3-8b-8192",
        temperature=creativity,
        api_key=GROQ_API_KEY,
    )

    # System message: length, genre and style constraints for the story.
    sys_prompt = PromptTemplate(
        template=(
            "You are an expert story writer, write a maximum of {word_count}\n"
            "words long story in {genre} genre in {style} writing style, based on the user\n"
            "provided story-context.\n"
        ),
        input_variables=["word_count", "genre", "style"],
    )
    system_msg_prompt = SystemMessagePromptTemplate(prompt=sys_prompt)

    # Human message: the image description the story must be based on.
    human_prompt = PromptTemplate(
        template="story-context: {context}", input_variables=["context"]
    )
    human_msg_prompt = HumanMessagePromptTemplate(prompt=human_prompt)

    chat_prompt = ChatPromptTemplate.from_messages(
        [system_msg_prompt, human_msg_prompt]
    )

    story_chain = LLMChain(llm=story_model, prompt=chat_prompt)
    response = story_chain.run(
        genre=genre, style=style, word_count=word_count, context=image_desc
    )
    return response
def generate_story(image_file, genre, style, word_count, creativity):
    """Produce a story for the image stored at ``image_file``.

    The image file is read as bytes and captioned, and the caption is turned
    into a story. Then ``mongo.increment_curr_access_count()`` is called and
    the access counters are read from ``app_config``.

    Args:
        image_file: Path of the image file, opened in binary mode.
        genre: Story genre forwarded to the story generator.
        style: Writing style forwarded to the story generator.
        word_count: Maximum story length forwarded to the story generator.
        creativity: Creativity setting forwarded to the story generator.

    Returns:
        Tuple ``(story, max_count, curr_count, available_count)``, where
        ``available_count`` is ``max_count - curr_count``.
    """
    with open(image_file, "rb") as handle:
        image_bytes = handle.read()

    # Caption the image and echo the caption between separator lines.
    caption = __image2text(image=image_bytes)
    separator = "++++++++++++++++++++++++++++++++++++++"
    print(separator)
    print(caption)
    print(separator)

    story = __text2story(
        image_desc=caption,
        genre=genre,
        style=style,
        word_count=word_count,
        creativity=creativity,
    )

    # Record the access, then read the counters back from app_config.
    # NOTE(review): assumes increment_curr_access_count() refreshes
    # app_config.openai_curr_access_count — confirm in mongo_utils.
    mongo.increment_curr_access_count()
    limit = app_config.openai_max_access_count
    used = app_config.openai_curr_access_count
    return story, limit, used, limit - used