ksvmuralidhar's picture
Update app.py
4cb94f8 verified
raw
history blame
No virus
3.12 kB
import html
import os

import streamlit as st
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer
class TextVectorizer:
    """Thin wrapper that turns text into sentence-transformer embeddings.

    The encoder is not stored on the instance: ``vectorize`` looks up the
    module-level ``sent_model`` every time it is called, so that name must be
    bound before the first call.
    """

    def vectorize(self, x):
        """Return ``sent_model.encode(x)`` for the input ``x``."""
        return sent_model.encode(x)
def get_milvus_collection():
    """Connect to Milvus and return the configured collection, loaded.

    Configuration is read from environment variables:

    * ``URI`` and ``TOKEN`` are passed to ``connections.connect`` for the
      ``"default"`` connection alias.
    * ``COLLECTION_NAME`` names the collection to open.

    Returns:
        The ``Collection`` object after ``load()`` has been called on it.

    Raises:
        RuntimeError: If ``COLLECTION_NAME`` is unset or empty.
    """
    # Validate configuration before touching the network so a missing setting
    # fails immediately with a message that names the variable.
    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        raise RuntimeError(
            "COLLECTION_NAME environment variable must be set to the "
            "Milvus collection name"
        )

    connections.connect(
        "default",
        uri=os.environ.get("URI"),
        token=os.environ.get("TOKEN"),
    )
    print("Connected to DB")

    collection = Collection(name=collection_name)
    collection.load()
    return collection
def find_similar_news(text: str, top_n: int = 5):
    """Search the Milvus collection for articles similar to ``text``.

    The text is embedded with the module-level ``vectorizer`` and used as the
    query vector for an L2-metric search on the ``article_embed`` field of the
    module-level ``collection``.

    Args:
        text: Headline (or other text) to find neighbours for.
        top_n: Maximum number of hits requested from the search.

    Returns:
        An HTML fragment: an ``<h4>`` heading followed by an ordered list with
        one ``<li>`` per hit showing the article description in bold and its
        category in italics. Both values are HTML-escaped.
    """
    search_params = {"metric_type": "L2"}
    search_vec = vectorizer.vectorize(text)
    result = collection.search(
        [search_vec],
        anns_field='article_embed',  # vector field the query is compared against
        param=search_params,
        limit=top_n,
        # NOTE(review): value kept from the original; presumably relaxes the
        # read-consistency wait -- confirm against the pymilvus documentation.
        guarantee_timestamp=1,
        output_fields=['article_desc', 'article_category'],  # fields returned per hit
    )

    # The fragment is rendered by the caller with unsafe_allow_html=True, so
    # stored text must be escaped or characters such as '<' and '&' would be
    # interpreted as markup.
    items = []
    for hits in result:
        for hit in hits:
            desc = html.escape(str(hit.entity.get('article_desc')))
            category = html.escape(str(hit.entity.get('article_category')))
            items.append(f'<li><b>{desc}</b> (<i>{category}</i>)</li>')

    similar_txt = ''.join(items)
    return f"<h4>Similar News Articles</h4><ol>{similar_txt}</ol>"
# Module-level objects shared by the functions in this file.
# `vectorizer` and `collection` are read as globals by find_similar_news.
# NOTE(review): these statements run every time the script is executed;
# presumably Streamlit re-executes the script on each interaction, which would
# reconnect and reload the model each time -- confirm and consider caching.
vectorizer = TextVectorizer()
collection = get_milvus_collection()
# `sent_model` is looked up by TextVectorizer.vectorize at call time, so binding
# it after `vectorizer` is created works as long as it precedes the first call.
sent_model = SentenceTransformer('all-mpnet-base-v2')
def main():
    """Render the Streamlit page: heading, description, inputs and results.

    Shows a text area for the headline and a slider for the number of results
    (1-100, default 10). When "Submit" is pressed, the headline is passed to
    ``find_similar_news`` and the returned HTML fragment is rendered. A blank
    or whitespace-only headline produces a warning instead of a search.
    """
    st.markdown("<h3>Find Similar News With Sentence Transformers (all-mpnet-base-v2)</h3>", unsafe_allow_html=True)
    desc = '''<p style="font-size: 13px;">
Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.
Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric.
<span style="color: red;">This method (all-mpnet-base-v2) has the best performance compared to multi-qa-distilbert-cos-v1 fine-tuned using TSDAE
and extracting embeddings from fine-tuned DistilBERT classifier.</span>
</p>
'''
    st.markdown(desc, unsafe_allow_html=True)

    # NOTE(review): newer Streamlit releases may enforce a larger minimum
    # text_area height -- confirm 50 is accepted by the deployed version.
    news_txt = st.text_area("Paste the headline of a news article:", "", height=50)
    top_n = st.slider('Select number of similar articles to display', 1, 100, 10)

    if st.button("Submit"):
        if not news_txt.strip():
            # Nothing meaningful to embed; skip the encoder and the DB query.
            st.warning("Please enter a news headline before submitting.")
        else:
            result = find_similar_news(news_txt, top_n)
            st.markdown(result, unsafe_allow_html=True)
# Script entry point: build the page only when this module's __name__ is
# "__main__", not when the file is imported by another module.
if __name__ == "__main__":
    main()