"""Streamlit app that finds news headlines similar to a user-supplied one.

The input headline is embedded with the sentence-transformers model
``all-mpnet-base-v2`` and the nearest stored article embeddings are fetched
from a Milvus collection using L2 (Euclidean) distance.

Environment variables read at start-up:
    URI              -- passed as ``uri`` to ``pymilvus.connections.connect``.
    TOKEN            -- passed as ``token`` to ``pymilvus.connections.connect``.
    COLLECTION_NAME  -- name of the Milvus collection to search (required).
"""
import html
import os

import streamlit as st
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    utility,
)
from sentence_transformers import SentenceTransformer

MODEL_NAME = 'all-mpnet-base-v2'


class TextVectorizer:
    '''Sentence transformers wrapper used to extract sentence embeddings.'''

    def __init__(self, model=None):
        """Store the encoder to use.

        Args:
            model: a ``SentenceTransformer`` (anything with ``.encode``).
                When ``None``, the module-level ``sent_model`` is looked up
                at call time, which preserves the original no-argument usage.
        """
        self.model = model

    def vectorize(self, x):
        """Return the embedding(s) produced by ``model.encode(x)``."""
        model = self.model if self.model is not None else sent_model
        return model.encode(x)


def get_milvus_collection():
    """Connect to Milvus and return the configured collection, loaded for search.

    Raises:
        RuntimeError: if the ``COLLECTION_NAME`` environment variable is unset
            or empty.
    """
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")

    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        # Fail with an actionable message instead of an opaque pymilvus error.
        raise RuntimeError("COLLECTION_NAME environment variable is not set")

    collection = Collection(name=collection_name)
    collection.load()
    return collection


def find_similar_news(text: str, top_n: int = 5) -> str:
    """Search Milvus for the ``top_n`` articles nearest to ``text``.

    Args:
        text: the headline to embed and search with.
        top_n: maximum number of hits to return (passed as ``limit``).

    Returns:
        An HTML fragment: a heading followed by a ``<ul>`` with one
        ``<li>description (category)</li>`` per hit. DB values are
        HTML-escaped because the caller renders the fragment with
        ``unsafe_allow_html=True``.
    """
    search_params = {"metric_type": "L2"}
    search_vec = vectorizer.vectorize(text)
    result = collection.search(
        [search_vec],
        anns_field='article_embed',  # vector field defined in the collection schema
        param=search_params,
        limit=top_n,
        # NOTE(review): presumably a tiny timestamp so the search does not wait
        # for recent writes to become visible — confirm against pymilvus docs.
        guarantee_timestamp=1,
        output_fields=['article_desc', 'article_category'],  # fields returned per hit
    )

    # One query vector was sent, but flatten over all result sets to match
    # the original behaviour.
    hits = [hit for hits_per_query in result for hit in hits_per_query]
    similar_texts = [hit.entity.get('article_desc') for hit in hits]
    text_categories = [hit.entity.get('article_category') for hit in hits]

    items = ''.join(
        f'<li>{html.escape(str(txt))} ({html.escape(str(cat))})</li>'
        for txt, cat in zip(similar_texts, text_categories)
    )
    if not items:
        items = '<li>No similar articles found.</li>'
    return f"<h3>Similar News Articles</h3><ul>{items}</ul>"


# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the model load and DB connection below run on each rerun; wrapping them
# in ``st.cache_resource`` (if the installed Streamlit provides it) would avoid that.
sent_model = SentenceTransformer(MODEL_NAME)
vectorizer = TextVectorizer(sent_model)
collection = get_milvus_collection()


def main():
    """Render the page: title, description, input widgets and search results."""
    st.markdown(
        "<h2>Find Similar News With Sentence Transformers (all-mpnet-base-v2)</h2>",
        unsafe_allow_html=True,
    )
    desc = (
        "<p>"
        "Embeddings of 300,000 news headlines are stored in Milvus vector database, "
        "used as a feature store. Embeddings of the input headline are computed using "
        "sentence transformers (all-mpnet-base-v2). Similar news headlines are retrieved "
        "from the vector database using Euclidean distance as similarity metric. "
        "This method (all-mpnet-base-v2) has the best performance compared to "
        "multi-qa-distilbert-cos-v1 fine-tuned using TSDAE and extracting embeddings "
        "from fine-tuned DistilBERT classifier."
        "</p>"
    )
    st.markdown(desc, unsafe_allow_html=True)

    # Recent Streamlit versions reject text_area heights below 68 px.
    news_txt = st.text_area("Paste the headline of a news article:", "", height=68)
    top_n = st.slider('Select number of similar articles to display', 1, 100, 10)

    if st.button("Submit"):
        if not news_txt.strip():
            # Searching with an empty string would return arbitrary neighbours.
            st.warning("Please paste a headline before submitting.")
            return
        result = find_similar_news(news_txt, top_n)
        st.markdown(result, unsafe_allow_html=True)


if __name__ == "__main__":
    main()