In [2]:
from transformers import DistilBertTokenizerFast
from tensorflow.keras.models import load_model, Model
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from dotenv import load_dotenv
import os
import pandas as pd
from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema

In [4]:
# Tokenizer for the DistilBERT checkpoint. Presumably this is the checkpoint the
# TFLite model was fine-tuned from — TODO confirm, since a mismatched vocabulary
# would silently produce bad embeddings.
model_checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

# Load the converted TFLite model. Tensors must be allocated before any
# set_tensor()/invoke() call.
tflite_model_path = "news_classification_hf_distilbert.tflite"
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
# Input tensor metadata (dicts carrying at least an "index" key), used to feed the model.
input_details = interpreter.get_input_details()

In [5]:
class TextVectorizer:
    """Turn news text into a dense vector using the TFLite DistilBERT model.

    Text is tokenized with a DistilBERT tokenizer and fed through a TFLite
    interpreter. The vector is then read back from one of the interpreter's
    tensors by index.

    By default the module-level ``tokenizer``, ``interpreter`` and
    ``input_details`` objects are used, looked up at call time. Any of them can
    be injected instead, e.g. to vectorize with a different model or to test
    with stubs.
    """

    def __init__(self, max_length: int = 80, embedding_tensor_index: int = 711,
                 tokenizer=None, interpreter=None, input_details=None):
        """
        Args:
            max_length: sequence length the text is padded/truncated to.
                Presumably this must match the sequence length the TFLite model
                was converted with — TODO confirm.
            embedding_tensor_index: interpreter tensor index the vector is read
                from. It is hard-coded, not derived from
                ``get_output_details()``. Re-check it with
                ``interpreter.get_tensor_details()`` whenever the .tflite file is
                re-exported, because tensor indices are model-specific.
            tokenizer: optional override for the module-level ``tokenizer``.
            interpreter: optional override for the module-level ``interpreter``.
            input_details: optional override for the module-level
                ``input_details`` (sequence of dicts with an "index" key).
        """
        self.max_length = max_length
        self.embedding_tensor_index = embedding_tensor_index
        self._tokenizer = tokenizer
        self._interpreter = interpreter
        self._input_details = input_details

    def vectorize(self, text):
        """Return the vector for ``text``.

        Args:
            text: a string or a list of strings, as accepted by the tokenizer.
                Only row 0 of the result tensor is returned, so pass a single
                text (or a one-element list).

        Returns:
            list of the elements of row 0 of the selected interpreter tensor.
        """
        tok = self._tokenizer if self._tokenizer is not None else tokenizer
        interp = self._interpreter if self._interpreter is not None else interpreter
        in_details = self._input_details if self._input_details is not None else input_details

        tokens = tok(text, max_length=self.max_length, padding="max_length",
                     truncation=True, return_tensors="tf")
        # Inputs are bound positionally: [0] attention_mask, [1] input_ids.
        # NOTE(review): this relies on the interpreter's input ordering; confirm
        # via in_details[i]["name"] if the model is re-exported.
        interp.set_tensor(in_details[0]["index"], tokens["attention_mask"])
        interp.set_tensor(in_details[1]["index"], tokens["input_ids"])
        interp.invoke()
        embedding = interp.get_tensor(self.embedding_tensor_index)[0]
        return list(embedding)

In [6]:
# Shared vectorizer instance; find_similar_news calls vectorizer.vectorize on it.
vectorizer = TextVectorizer()

In [7]:
# Read the Milvus URI and API token from secrets.env, keeping credentials out of
# the notebook. Variables already set in the environment are also picked up.
load_dotenv('secrets.env')
uri = os.environ.get("URI")
token = os.environ.get("TOKEN")
# Fail fast on a missing URI. Otherwise the problem only shows up later, less
# clearly, in connections.connect. TOKEN may legitimately be empty (e.g. a Milvus
# server without authentication), so it is not enforced here.
if not uri:
    raise RuntimeError("URI is not set: add it to secrets.env or the environment")

In [8]:
# Connect to Milvus under the "default" alias, using the URI/token read above.
connections.connect(alias="default", uri=uri, token=token)
print("Connected to DB")

Connected to DB


In [9]:
# Name of the Milvus collection holding the article embeddings (from secrets.env).
collection_name = os.environ.get("COLLECTION_NAME")
# Fail with a clear message instead of passing None into has_collection.
if not collection_name:
    raise RuntimeError("COLLECTION_NAME is not set: add it to secrets.env or the environment")
check_collection = utility.has_collection(collection_name)
check_collection  # displays whether the collection exists

True

In [10]:
# Bind to the existing collection and load it into memory; the collection has
# to be loaded before it can be searched.
collection = Collection(collection_name)
collection.load()

In [11]:
def find_similar_news(text: str, top_n: int=3):
    """Search the Milvus collection for the articles nearest to ``text``.

    Embeds ``text`` with the module-level ``vectorizer``. It then runs an L2
    (Euclidean) vector search on the ``article_embed`` field of the module-level
    ``collection``, and prints the input alongside the retrieved
    "description (category)" pairs.

    Args:
        text: query text.
        top_n: maximum number of hits to return.

    Returns:
        A tuple ``(output_dict, search_vec)``:
        ``output_dict`` has keys "input_text" (the query), "similar_texts" (hit
        descriptions) and "text_category" (hit categories, aligned with
        "similar_texts"). ``search_vec`` is the query embedding.
    """
    query_vector = vectorizer.vectorize([text])
    search_results = collection.search(
        [query_vector],
        # Vector field used for approximate-nearest-neighbour (ANN) search;
        # must match the field name in the collection schema.
        anns_field='article_embed',
        # NOTE(review): presumably must match the metric the index was built with.
        param={"metric_type": "L2"},
        limit=top_n,
        # NOTE(review): a guarantee timestamp of 1 presumably means Milvus need
        # not wait for recent writes (eventual consistency) — confirm against
        # the pymilvus version in use.
        guarantee_timestamp=1,
        # Scalar fields to return alongside each hit.
        output_fields=['article_desc', 'article_category'],
    )

    descriptions = []
    categories = []
    for query_hits in search_results:
        for hit in query_hits:
            descriptions.append(hit.entity.get('article_desc'))
            categories.append(hit.entity.get('article_category'))

    output_dict = {
        "input_text": text,
        "similar_texts": descriptions,
        "text_category": categories,
    }
    similar_txt = '\n\n'.join(f'{txt} ({cat})' for txt, cat in zip(descriptions, categories))
    print(f"INPUT\n{'-'*5}\n{text}\n\nSIMILAR NEWS\n{'-'*12}\n{similar_txt}")
    return output_dict, query_vector

In [13]:
# Example query: retrieve the 5 stored articles nearest to this headline.
text = '''HMD Global raises $230 million to bolster 5G smartphone business across US, emerging markets'''

# Bind the result dict to a real name rather than `_`. Assigning `_` in IPython
# shadows its "last output" variable, and IPython then stops updating it.
similar_news, sv = find_similar_news(text, top_n=5)

INPUT
-----
HMD Global raises $230 million to bolster 5G smartphone business across US, emerging markets

SIMILAR NEWS
------------
HMD Global raises $230 million to bolster 5G smartphone business across US, emerging markets (TECHNOLOGY)

HMD Global raises $230 million from Google, Qualcomm investment (TECHNOLOGY)

POV: The latest investment in HMD Global might be the first step in creating a European mobile giant (TECHNOLOGY)

HMD Global receives huge investments from top tech companies (TECHNOLOGY)

A new hope: HMD Global, creator of Nokia phones, receives USD 230 million strategic investment (TECHNOLOGY)
