ksvmuralidhar
committed on
Commit
•
daf30d2
1
Parent(s):
ab21b59
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer
|
2 |
+
import os
|
3 |
+
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
|
4 |
+
import streamlit as st
|
5 |
+
|
6 |
+
|
7 |
+
class TextVectorizer:
    """Thin wrapper that produces sentence embeddings via sentence transformers."""

    def vectorize(self, x):
        """Encode ``x`` with the module-level ``sent_model`` and return the result.

        ``x`` is passed to ``sent_model.encode`` unchanged, and whatever that
        call returns is handed straight back to the caller.
        """
        return sent_model.encode(x)
|
14 |
+
|
15 |
+
def get_milvus_collection():
    """Connect to Milvus and return the collection named by the environment.

    Reads ``URI``, ``TOKEN`` and ``COLLECTION_NAME`` from the environment,
    opens the ``"default"`` connection alias with the URI and token, then
    builds the ``Collection`` and calls ``load()`` on it.

    Returns:
        The ``Collection`` object after ``load()`` has been called.

    Raises:
        RuntimeError: if ``COLLECTION_NAME`` is unset or empty.
    """
    # Fail fast with an actionable message instead of passing ``name=None``
    # to pymilvus after a connection has already been opened.
    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        raise RuntimeError(
            "The COLLECTION_NAME environment variable must be set to the "
            "name of the Milvus collection to search."
        )

    connections.connect(
        "default",
        uri=os.environ.get("URI"),
        token=os.environ.get("TOKEN"),
    )
    print("Connected to DB")

    collection = Collection(name=collection_name)
    collection.load()
    return collection
|
24 |
+
|
25 |
+
def find_similar_news(text: str, top_n: int = 5):
    """Return an HTML fragment listing the stored articles closest to ``text``.

    ``text`` is embedded with the module-level ``vectorizer`` and used as the
    query vector for an L2-metric search on the module-level ``collection``.

    Args:
        text: The headline to search with.
        top_n: Maximum number of hits requested from the search.

    Returns:
        A string of the form
        ``<h4>Similar news Articles</h4><ol>...</ol>`` with one ``<li>`` per
        hit showing the article description and its category.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    search_vec = vectorizer.vectorize(text)
    result = collection.search(
        [search_vec],
        anns_field='article_embed',  # vector field of the collection to search against
        param={"metric_type": "L2"},
        limit=top_n,
        # NOTE(review): value kept from the original; presumably tunes read
        # consistency -- confirm against the pymilvus documentation.
        guarantee_timestamp=1,
        output_fields=['article_desc', 'article_category'],  # fields returned with each hit
    )

    items = []
    for hits in result:
        for hit in hits:
            # The returned string is rendered as raw HTML by the caller, so
            # escape stored text to stop characters such as "<" or "&" from
            # breaking (or injecting) markup.
            desc = escape(str(hit.entity.get('article_desc')))
            category = escape(str(hit.entity.get('article_category')))
            items.append(f'<li><b>{desc}</b> (<i>{category}</i>)</li>')

    similar_txt = ''.join(items)
    return f"<h4>Similar news Articles</h4><ol>{similar_txt}</ol>"
|
42 |
+
|
43 |
+
|
44 |
+
def _load_sentence_model():
    """Build the sentence-transformer model that ``TextVectorizer`` encodes with."""
    return SentenceTransformer('all-mpnet-base-v2')


# Streamlit re-executes the script on every widget interaction. Where the
# installed Streamlit provides ``st.cache_resource``, reuse one model instance
# across reruns instead of rebuilding it each time; older releases keep the
# original load-on-every-run behaviour.
if hasattr(st, "cache_resource"):
    _load_sentence_model = st.cache_resource(_load_sentence_model)

# Module-level objects used by TextVectorizer.vectorize and find_similar_news.
vectorizer = TextVectorizer()
collection = get_milvus_collection()
sent_model = _load_sentence_model()
|
47 |
+
|
48 |
+
def main():
    """Render the "Find Similar News" Streamlit page.

    Shows a heading and a short description, collects a headline and the
    number of results to display, and on "Submit" renders the HTML returned
    by ``find_similar_news``. A blank headline produces a warning instead of
    running a search.
    """
    st.markdown("<h3>Find Similar News</h3>", unsafe_allow_html=True)

    desc = (
        '<p style="font-size: 13px;">\n'
        'Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.\n'
        'Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).\n'
        'Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric.\n'
        '<span style="color: red;">This method is found to be more accurate and faster compared to the method of extracting embeddings\n'
        'from fine-tuned classification model, discussed </span><a href="https://huggingface.co/spaces/ksvmuralidhar/vector-db-search">here.</a>\n'
        '</p>\n'
    )
    st.markdown(desc, unsafe_allow_html=True)

    news_txt = st.text_area("Paste the headline of a news article:", "", height=50)
    top_n = st.slider('Select number of similar articles to display', 1, 100, 10)

    if st.button("Submit"):
        # An empty or whitespace-only headline has nothing meaningful to
        # embed, so ask for input rather than searching on it.
        if not news_txt.strip():
            st.warning("Please enter a news headline before submitting.")
            return
        result = find_similar_news(news_txt, top_n)
        st.markdown(result, unsafe_allow_html=True)


if __name__ == "__main__":
    main()
|