{ "cells": [ { "cell_type": "markdown", "id": "f0dc396f", "metadata": {}, "source": [ "### TSDAE (Transformer-based Sequential Denoising Auto-Encoder): unsupervised fine-tuning of a sentence transformer with PyTorch\n", "Reference: https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html" ] }, { "cell_type": "code", "execution_count": 2, "id": "34329058", "metadata": { "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "execution": { "iopub.execute_input": "2024-01-13T16:51:37.887254Z", "iopub.status.busy": "2024-01-13T16:51:37.886390Z", "iopub.status.idle": "2024-01-13T16:51:37.891827Z", "shell.execute_reply": "2024-01-13T16:51:37.890706Z", "shell.execute_reply.started": "2024-01-13T16:51:37.887212Z" } }, "outputs": [], "source": [ "# !pip install sentence_transformers==2.2.2" ] }, { "cell_type": "code", "execution_count": 4, "id": "ebd138d8", "metadata": { "execution": { "iopub.execute_input": "2024-01-13T16:51:50.310221Z", "iopub.status.busy": "2024-01-13T16:51:50.309586Z", "iopub.status.idle": "2024-01-13T16:51:50.315850Z", "shell.execute_reply": "2024-01-13T16:51:50.314927Z", "shell.execute_reply.started": "2024-01-13T16:51:50.310185Z" } }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import string\n", "from tqdm import tqdm\n", "from numpy.linalg import norm\n", "from sentence_transformers import SentenceTransformer, LoggingHandler\n", "from sentence_transformers import models, util, datasets, evaluation, losses\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "code", "execution_count": 6, "id": "7ef2d063", "metadata": { "execution": { "iopub.execute_input": "2024-01-13T13:24:44.770211Z", "iopub.status.busy": "2024-01-13T13:24:44.769806Z", "iopub.status.idle": "2024-01-13T13:24:44.775042Z", "shell.execute_reply": "2024-01-13T13:24:44.773860Z", "shell.execute_reply.started": "2024-01-13T13:24:44.770177Z" } }, "outputs": [], "source": [ "# import nltk\n", "# nltk.download('punkt')" ] }, { 
"cell_type": "code", "execution_count": 5, "id": "453c1add", "metadata": { "execution": { "iopub.execute_input": "2024-01-13T16:51:54.070689Z", "iopub.status.busy": "2024-01-13T16:51:54.069945Z", "iopub.status.idle": "2024-01-13T16:51:54.809726Z", "shell.execute_reply": "2024-01-13T16:51:54.808920Z", "shell.execute_reply.started": "2024-01-13T16:51:54.070657Z" } }, "outputs": [], "source": [ "data = pd.read_csv('news_processed.csv', usecols=['short_description'])" ] }, { "cell_type": "code", "execution_count": 6, "id": "61629b79", "metadata": { "execution": { "iopub.execute_input": "2024-01-13T16:51:55.125758Z", "iopub.status.busy": "2024-01-13T16:51:55.124990Z", "iopub.status.idle": "2024-01-13T16:51:55.180470Z", "shell.execute_reply": "2024-01-13T16:51:55.179559Z", "shell.execute_reply.started": "2024-01-13T16:51:55.125716Z" } }, "outputs": [], "source": [ "data.dropna(inplace=True)" ] }, { "cell_type": "code", "execution_count": 7, "id": "14162e2c", "metadata": { "execution": { "iopub.execute_input": "2024-01-13T16:51:55.817535Z", "iopub.status.busy": "2024-01-13T16:51:55.816764Z", "iopub.status.idle": "2024-01-13T16:51:55.836578Z", "shell.execute_reply": "2024-01-13T16:51:55.835697Z", "shell.execute_reply.started": "2024-01-13T16:51:55.817499Z" } }, "outputs": [ { "data": { "text/plain": [ "'Experimental coronavirus vaccine prevents severe disease in mice'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data['short_description'][1000]" ] }, { "cell_type": "code", "execution_count": 14, "id": "8fd1a801", "metadata": { "execution": { "iopub.execute_input": "2024-01-13T17:08:09.906826Z", "iopub.status.busy": "2024-01-13T17:08:09.906182Z", "iopub.status.idle": "2024-01-13T17:08:09.914834Z", "shell.execute_reply": "2024-01-13T17:08:09.913737Z", "shell.execute_reply.started": "2024-01-13T17:08:09.906795Z" } }, "outputs": [], "source": [ "def finetune_model(data: pd.DataFrame, col_to_use: str='short_description', \n", " 
model_id: str=\"sentence-transformers/multi-qa-distilbert-cos-v1\", \n", " batch_size: int=8, epochs: int=2):\n", " \n", "# https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html\n", " \n", " word_embedding_model = models.Transformer(model_id)\n", " pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')\n", " model = SentenceTransformer(modules=[word_embedding_model, pooling_model])\n", " \n", " train_examples = data[col_to_use].tolist()\n", " train_dataset = datasets.DenoisingAutoEncoderDataset(train_examples)\n", " train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", " train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=\"sentence-transformers/paraphrase-distilroberta-base-v2\", tie_encoder_decoder=False)\n", "# train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_id, tie_encoder_decoder=True)\n", "\n", " \n", " model.fit(\n", " train_objectives=[(train_dataloader, train_loss)],\n", " epochs=epochs,\n", " weight_decay=0,\n", " scheduler='constantlr',\n", " optimizer_params={'lr': 3e-5},\n", " show_progress_bar=True\n", " )\n", " model_save_path = model_id + '_finetuned'\n", " model.save(model_save_path)\n", " return model_save_path" ] }, { "cell_type": "code", "execution_count": null, "id": "b17a9779", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# fine-tune sentence transformer\n", "finetuned_model_id = finetune_model(data=data)\n", "finetuned_model = SentenceTransformer(finetuned_model_id)" ] }, { "cell_type": "code", "execution_count": null, "id": "4d87c128", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kaggle": { "accelerator": "gpu", "dataSources": [ { "datasetId": 4298708, "sourceId": 7394110, "sourceType": "datasetVersion" } ], "dockerImageVersionId": 30636, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": 
"Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }