# hot-ones-trivia / preprocessing.py
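"""Preprocessing script for the Hot Ones trivia dataset.

Reads episode URLs from links.csv, fetches video titles via the YouTube oEmbed
API, splits each transcript in transcripts/<i>.vtt into token-limited chunks,
embeds the chunks with the OpenAI API, and pickles the resulting Dataset.
"""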
import pandas as pd
from tqdm.auto import tqdm
import requests
import tiktoken
from typarse import BaseParser
from openai import OpenAI
import dotenv
import pickle
from core import get_batch_embeddings, Chunk, Dataset


class Parser(BaseParser):
chunk_size: int = 4000
save_path: str = "dataset.pkl"
_abbrev = {
"chunk_size": "c",
"save_path": "s",
}
_help = {
"chunk_size": "The maximum number of tokens per chunk",
"save_path": "The path to save the dataset",
}
def get_youtube_title(url: str) -> str | None:
    """
    Fetch the title of a YouTube video via the oEmbed API.

    Returns None if the request fails (e.g. the video is unavailable).
    """
    video_id = url.split("v=")[-1]
    api_url = f"https://www.youtube.com/oembed?url=http://www.youtube.com/watch?v={video_id}&format=json"
    response = requests.get(api_url, timeout=10)  # avoid hanging on a slow response
    if response.status_code == 200:
        data = response.json()
        return data["title"]
    else:
        return None
def num_tokens_from_string(string: str, encoding_name: str) -> int:
"""
Calculate the number of tokens in a string
"""
encoding = tiktoken.get_encoding(encoding_name)
num_tokens = len(encoding.encode(string))
return num_tokens
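# Quick check (counts depend on the tiktoken encoding; with "cl100k_base",
# "hello world" typically encodes to 2 tokens):
#   num_tokens_from_string("hello world", "cl100k_base")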
def required_chunks(
text: str, max_tokens: int = 8191, encoding_name: str = "cl100k_base"
) -> int:
"""
Calculate the number of chunks required to split a text into chunks of a maximum number of tokens.
"""
num_tokens = num_tokens_from_string(text, encoding_name)
num_chunks = num_tokens // max_tokens
if num_tokens % max_tokens != 0:
num_chunks += 1
return num_chunks
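# Worked example: a 20,000-token transcript with max_tokens=8191 gives
# 20000 // 8191 == 2 with a nonzero remainder, so required_chunks returns 3.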
def split_in_chunks(
text: str, max_tokens: int = 8191, encoding_name: str = "cl100k_base"
) -> list[str]:
"""
Split a long text into chunks of a maximum number of tokens
"""
encoding = tiktoken.get_encoding(encoding_name)
tokens = encoding.encode(text)
chunks: list[str] = []
current_chunk: list[int] = []
current_chunk_size = 0
    for token in tokens:
        # Flush the current chunk once adding another token would exceed max_tokens.
        if current_chunk_size + 1 > max_tokens:
            chunks.append(encoding.decode(current_chunk))
            current_chunk = []
            current_chunk_size = 0
        current_chunk.append(token)
        current_chunk_size += 1
if current_chunk:
chunks.append(encoding.decode(current_chunk))
return chunks
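# Usage sketch (long_text is a hypothetical variable): each returned chunk holds
# at most max_tokens token ids, and decoding those ids is lossless, so the chunks
# concatenate back to (approximately) the original text.
#   chunks = split_in_chunks(long_text, max_tokens=4000)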
if __name__ == "__main__":
dotenv.load_dotenv()
client = OpenAI()
args = Parser()
chunk_size = args.chunk_size
    # Episode URLs come from links.csv; titles are fetched via the YouTube oEmbed API.
    links = pd.read_csv("links.csv").URL.tolist()
    titles = [get_youtube_title(link) for link in tqdm(links)]

    # Read the 17 episode transcripts (transcripts/0.vtt ... transcripts/16.vtt).
    episodes = []
    for i in range(17):
        filename = f"transcripts/{i}.vtt"
        with open(filename, "r") as file:
            data = file.read()
        episodes.append(data)
    # Split each episode transcript into chunks of at most chunk_size tokens.
    episode_chunks = [
split_in_chunks(episode, max_tokens=chunk_size) for episode in episodes
]
    # Attach title, episode index, and source link metadata to every chunk.
    chunk_metadata = [
Chunk(
title=titles[i],
video_idx=i,
text=episode_chunks[i][j],
link=links[i],
)
for i in range(17)
for j in range(len(episode_chunks[i]))
]
    # Embed all chunk texts in a batch with the OpenAI client.
    chunk_texts = [chunk.text for chunk in chunk_metadata]
    embeddings = get_batch_embeddings(client, chunk_texts)
    # Bundle the chunk metadata and embeddings and pickle the dataset to disk.
    dataset = Dataset(chunks=chunk_metadata, embeddings=embeddings)
    with open(args.save_path, "wb") as file:
        pickle.dump(dataset, file)
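# Loading the result elsewhere (a sketch; core.Dataset must be importable so
# pickle can resolve the class):
#   with open("dataset.pkl", "rb") as f:
#       dataset = pickle.load(f)
#   # dataset.chunks and dataset.embeddings are then available for retrieval.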