{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "aebbbcff-1e73-4c19-b11e-fc16784bd669", "metadata": {}, "outputs": [], "source": [ "from transformers import DistilBertTokenizerFast\n", "from tensorflow.keras.models import load_model, Model\n", "import numpy as np\n", "import tensorflow as tf\n", "from tqdm import tqdm\n", "from dotenv import load_dotenv\n", "import os\n", "import pandas as pd\n", "from pymilvus import connections, utility\n", "from pymilvus import Collection, DataType, FieldSchema, CollectionSchema" ] }, { "cell_type": "code", "execution_count": 4, "id": "942a5712-5271-4e23-a959-a8e5192710a0", "metadata": {}, "outputs": [], "source": [ "model_checkpoint = \"distilbert-base-uncased\"\n", "tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)\n", "\n", "interpreter = tf.lite.Interpreter(model_path=\"news_classification_hf_distilbert.tflite\")\n", "interpreter.allocate_tensors()\n", "input_details = interpreter.get_input_details()" ] }, { "cell_type": "code", "execution_count": 5, "id": "efa3786e-c7c2-4ec5-b7bb-cc97ec63914d", "metadata": {}, "outputs": [], "source": [ "class TextVectorizer:\n", " '''\n", " sentence transformers to extract sentence embeddings\n", " '''\n", " def vectorize(self, text): \n", " tokens = tokenizer(text, max_length=80, padding=\"max_length\", truncation=True, return_tensors=\"tf\")\n", " attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']\n", " interpreter.set_tensor(input_details[0][\"index\"], attention_mask)\n", " interpreter.set_tensor(input_details[1][\"index\"], input_ids)\n", " interpreter.invoke()\n", " tflite_embeds = interpreter.get_tensor(711)[0]\n", " return [*tflite_embeds]" ] }, { "cell_type": "code", "execution_count": 6, "id": "91903ad2-4093-4014-ba08-7fdcdef60500", "metadata": {}, "outputs": [], "source": [ "vectorizer = TextVectorizer()" ] }, { "cell_type": "code", "execution_count": 7, "id": "1a3f7e18-cd3f-449f-a698-2cdbc4c54adf", "metadata": {}, "outputs": [], "source": [ "# Reading milvus URI & API token From secrets.env\n", "load_dotenv('secrets.env')\n", "uri = os.environ.get(\"URI\")\n", "token = os.environ.get(\"TOKEN\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "3afe8a0f-61cc-4018-ad61-5b5d9c904f1f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Connected to DB\n" ] } ], "source": [ "# connecting to db\n", "connections.connect(\"default\", uri=uri, token=token)\n", "print(f\"Connected to DB\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "98ff14a7-9300-4c24-83fe-c54effb14ac8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "collection_name = os.environ.get(\"COLLECTION_NAME\")\n", "check_collection = utility.has_collection(collection_name)\n", "check_collection # checks if collection exisits" ] }, { "cell_type": "code", "execution_count": 10, "id": "a216d220-247c-4d19-808b-7d1cf08e28f8", "metadata": {}, "outputs": [], "source": [ "# load the collection before querying\n", "collection = Collection(name=collection_name)\n", "collection.load()" ] }, { "cell_type": "code", "execution_count": 11, "id": "e9215af5-49b7-4984-aa4d-d3ff03446012", "metadata": {}, "outputs": [], "source": [ "def find_similar_news(text: str, top_n: int=3):\n", " search_params = {\"metric_type\": \"L2\"}\n", " search_vec = vectorizer.vectorize([text])\n", " result = collection.search([search_vec],\n", " anns_field='article_embed', # annotations field specified in the schema definition\n", " param=search_params,\n", " limit=top_n,\n", " guarantee_timestamp=1, \n", " output_fields=['article_desc', 'article_category']) # which fields to return in output\n", "\n", " \n", " output_dict = {\"input_text\": text, \"similar_texts\": [hit.entity.get('article_desc') for hits in result for hit in hits], \n", " \"text_category\": [hit.entity.get('article_category') for hits in result for hit in hits]} \n", " txt_category = [f'{txt} ({cat})' for txt, cat in zip(output_dict.get('similar_texts'), output_dict.get('text_category'))]\n", " similar_txt = '\\n\\n'.join(txt_category)\n", " print(f\"INPUT\\n{'-'*5}\\n{text}\\n\\nSIMILAR NEWS\\n{'-'*12}\\n{similar_txt}\")\n", " return output_dict, search_vec" ] }, { "cell_type": "code", "execution_count": 13, "id": "6c3aa64b-235d-471a-a8af-f6a71c02fdd5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INPUT\n", "-----\n", "HMD Global raises $230 million to bolster 5G smartphone business across US, emerging markets\n", "\n", "SIMILAR NEWS\n", "------------\n", "HMD Global raises $230 million to bolster 5G smartphone business across US, emerging markets (TECHNOLOGY)\n", "\n", "HMD Global raises $230 million from Google, Qualcomm investment (TECHNOLOGY)\n", "\n", "POV: The latest investment in HMD Global might be the first step in creating a European mobile giant (TECHNOLOGY)\n", "\n", "HMD Global receives huge investments from top tech companies (TECHNOLOGY)\n", "\n", "A new hope: HMD Global, creator of Nokia phones, receives USD 230 million strategic investment (TECHNOLOGY)\n" ] } ], "source": [ "text = '''HMD Global raises $230 million to bolster 5G smartphone business across US, emerging markets'''\n", "\n", "_ , sv = find_similar_news(text, top_n=5)" ] }, { "cell_type": "code", "execution_count": null, "id": "0307e6e6-2ec3-486f-81fb-6e796cadf643", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python (tf_gpu)", "language": "python", "name": "tf_gpu" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }