ksvmuralidhar commited on
Commit
daf30d2
1 Parent(s): ab21b59

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import os
3
+ from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
4
+ import streamlit as st
5
+
6
+
7
+ class TextVectorizer:
8
+ '''
9
+ sentence transformers to extract sentence embeddings
10
+ '''
11
+ def vectorize(self, x):
12
+ sen_embeddings = sent_model.encode(x)
13
+ return sen_embeddings
14
+
15
+ def get_milvus_collection():
16
+ uri = os.environ.get("URI")
17
+ token = os.environ.get("TOKEN")
18
+ connections.connect("default", uri=uri, token=token)
19
+ print(f"Connected to DB")
20
+ collection_name = os.environ.get("COLLECTION_NAME")
21
+ collection = Collection(name=collection_name)
22
+ collection.load()
23
+ return collection
24
+
25
+ def find_similar_news(text: str, top_n: int=5):
26
+ search_params = {"metric_type": "L2"}
27
+ search_vec = vectorizer.vectorize(text)
28
+ result = collection.search([search_vec],
29
+ anns_field='article_embed', # annotations field specified in the schema definition
30
+ param=search_params,
31
+ limit=top_n,
32
+ guarantee_timestamp=1,
33
+ output_fields=['article_desc', 'article_category']) # which fields to return in output
34
+
35
+
36
+ output_dict = {"input_text": text, "similar_texts": [hit.entity.get('article_desc') for hits in result for hit in hits],
37
+ "text_category": [hit.entity.get('article_category') for hits in result for hit in hits]}
38
+ txt_category = [f'<li><b>{txt}</b> (<i>{cat}</i>)</li>' for txt, cat in zip(output_dict.get('similar_texts'), output_dict.get('text_category'))]
39
+ similar_txt = ''.join(txt_category)
40
+ return f"<h4>Similar news Articles</h4><ol>{similar_txt}</ol>"
41
+ # return output_dict
42
+
43
+
44
+ vectorizer = TextVectorizer()
45
+ collection = get_milvus_collection()
46
+ sent_model = SentenceTransformer('all-mpnet-base-v2')
47
+
48
+ def main():
49
+
50
+ # st.title("Find Similar News")
51
+ st.markdown("<h3>Find Similar News</h3>", unsafe_allow_html=True)
52
+ desc = '''<p style="font-size: 13px;">
53
+ Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.
54
+ Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
55
+ Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric.
56
+ <span style="color: red;">This method is found to be more accurate and faster compared to the method of extracting embeddings
57
+ from fine-tuned classification model, discussed </span><a href="https://huggingface.co/spaces/ksvmuralidhar/vector-db-search">here.</a>
58
+ </p>
59
+ '''
60
+ st.markdown(desc, unsafe_allow_html=True)
61
+ news_txt = st.text_area("Paste the headline of a news article:", "", height=50)
62
+ top_n = st.slider('Select number of similar articles to display', 1, 100, 10)
63
+
64
+ if st.button("Submit"):
65
+ result = find_similar_news(news_txt, top_n)
66
+ # st.write(result)
67
+ st.markdown(result, unsafe_allow_html=True)
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()