### TSDAE: Fine-tune sentence transformers using unsupervised learning with PyTorch
https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html

In [2]:
# !pip install sentence_transformers==2.2.2

In [4]:
import pandas as pd
import numpy as np
import string
from tqdm import tqdm
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer, LoggingHandler
from sentence_transformers import models, util, datasets, evaluation, losses
from torch.utils.data import DataLoader

In [6]:
# import nltk
# nltk.download('punkt')

In [5]:
# Load only the text column; it is the one the fine-tuning step reads by default.
data = pd.read_csv(
    'news_processed.csv',
    usecols=['short_description'],
)

In [6]:
# Drop every row that contains a missing value, modifying `data` in place.
data.dropna(axis=0, how='any', inplace=True)

In [7]:
# Peek at the description stored under row label 1000 as a quick sanity check.
data.loc[1000, 'short_description']

'Experimental coronavirus vaccine prevents severe disease in mice'

In [14]:
def finetune_model(data: pd.DataFrame, col_to_use: str = 'short_description',
                   model_id: str = "sentence-transformers/multi-qa-distilbert-cos-v1",
                   batch_size: int = 8, epochs: int = 2, *,
                   decoder_name_or_path: str = "sentence-transformers/paraphrase-distilroberta-base-v2",
                   tie_encoder_decoder: bool = False,
                   learning_rate: float = 3e-5) -> str:
    """Fine-tune a sentence transformer on unlabeled text with TSDAE.

    Follows the recipe from
    https://www.sbert.net/examples/unsupervised_learning/TSDAE/README.html:
    a Transformer encoder with CLS pooling is trained with
    ``DenoisingAutoEncoderLoss`` on sentences wrapped in
    ``DenoisingAutoEncoderDataset``.

    Args:
        data: Frame holding the training sentences.
        col_to_use: Name of the column in ``data`` that contains the text.
            Missing values in this column are skipped.
        model_id: Name or path of the encoder checkpoint to fine-tune.
        batch_size: Number of sentences per training batch.
        epochs: Number of passes over the training sentences.
        decoder_name_or_path: Checkpoint used as the TSDAE decoder.
        tie_encoder_decoder: Whether encoder and decoder share weights.
            To reproduce the reference recipe, pass
            ``decoder_name_or_path=model_id`` and ``tie_encoder_decoder=True``.
        learning_rate: Learning rate handed to the optimizer.

    Returns:
        The path the fine-tuned model was saved to, i.e.
        ``model_id + '_finetuned'``. It is a relative path; because the
        default ``model_id`` contains a ``/``, the default result is nested
        one directory deep.

    Raises:
        KeyError: If ``col_to_use`` is not a column of ``data``.
        ValueError: If the column holds no non-missing sentences.
    """
    # Validate the input before any checkpoint is loaded so that bad input
    # fails immediately instead of after the (slow) model construction.
    train_examples = data[col_to_use].dropna().astype(str).tolist()
    if not train_examples:
        raise ValueError(
            f"Column {col_to_use!r} contains no sentences to train on."
        )

    # Encoder: transformer backbone followed by CLS-token pooling.
    word_embedding_model = models.Transformer(model_id)
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension(), 'cls'
    )
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    train_dataset = datasets.DenoisingAutoEncoderDataset(train_examples)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    train_loss = losses.DenoisingAutoEncoderLoss(
        model,
        decoder_name_or_path=decoder_name_or_path,
        tie_encoder_decoder=tie_encoder_decoder,
    )

    # Constant learning rate and no weight decay, as in the TSDAE reference recipe.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=epochs,
        weight_decay=0,
        scheduler='constantlr',
        optimizer_params={'lr': learning_rate},
        show_progress_bar=True
    )

    model_save_path = model_id + '_finetuned'
    model.save(model_save_path)
    return model_save_path

In [None]:
# Run TSDAE fine-tuning with the function's default settings, then reload the
# saved checkpoint from the path that the function returns.
finetuned_model_id = finetune_model(data=data)
finetuned_model = SentenceTransformer(model_name_or_path=finetuned_model_id)