# Author: Edward Beeching — "fixing seed issues" (commit df0a622)
# (Hugging Face Space app file; header above recovered from the web viewer.)
import datetime
import hashlib
import json
import os
import random
from io import BytesIO
import gradio as gr
from datasets import load_dataset
from huggingface_hub import upload_file
# Load the validation split of the dialog-generation dataset at import time.
# NOTE(review): requires the EB_TOKEN environment variable (HF auth token) —
# import fails with KeyError if it is unset.
dataset = load_dataset("edbeeching/rlhf_dialog_experiment_cosmo_dialog_generation",
                       use_auth_token=os.environ['EB_TOKEN'])["validation"]
def sample_to_markdown(sample, index):
    """Render one of a sample's conversations as an HTML/markdown string.

    Turns alternate left (plain) and right (bold) alignment; a centered
    "START OF DIALOG GENERATION" banner is inserted before the turn at
    ``sample["trucation_length"]`` to mark where the prompt ends and the
    generated continuation begins.

    Args:
        sample: dataset row with "conversations" (list of turn lists) and
            "trucation_length" (int) — the key's misspelling is the dataset's.
        index: which conversation of the sample to render (0 or 1).

    Returns:
        A single HTML string.
    """
    dialog = sample["conversations"][index]
    cutoff = sample["trucation_length"]
    pieces = []
    for turn_idx, turn in enumerate(dialog):
        if turn_idx == cutoff:
            pieces.append('<p style="text-align:center"> --- START OF DIALOG GENERATION --- </p><br>')
        if turn_idx % 2:
            side, open_tag, close_tag = "right", "<strong>", "</strong>"
        else:
            side, open_tag, close_tag = "left", "", ""
        pieces.append(f'<div style="text-align: {side}"> {open_tag}{turn}{close_tag} \n </div> <br>')
    return "".join(pieces)
# Module-level holder for the sample currently shown in the UI; rebound
# inside on_next via `global sample`.
sample = None
def get_sample(ds=None):
    """Draw one random sample from the dataset.

    Args:
        ds: indexable dataset to sample from; defaults to the module-level
            ``dataset`` loaded at import time.

    Returns:
        One randomly chosen row of ``ds``.
    """
    if ds is None:
        ds = dataset
    # I set the seed here as the randomness was a bit off otherwise.
    # BUG FIX: compute the seed once — the original hashed the timestamp twice
    # (once for print, once for seed), which could print a different value
    # than was actually used if the clock ticked between the two calls.
    # BUG FIX: '%S' (seconds of the minute); the original lowercase '%s' is a
    # non-portable glibc extension (epoch seconds) and is unsupported on
    # Windows strftime.
    seed = abs(hash(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))) % (10 ** 8)
    print(seed)
    random.seed(seed)
    return ds[random.randint(0, len(ds) - 1)]
def check_and_submit_preferences(sample, preferred_text, text_quality):
    """Validate the two ratings and upload them to the results dataset repo.

    If either rating radio was left unset, logs the reason and returns
    without uploading. Otherwise serializes the sample plus ratings to JSON
    and uploads it to the ``edbeeching/rlhf_dialog_experiment_dataset`` repo
    under a ``<timestamp>_<md5-of-situation>`` path.

    Args:
        sample: the dataset row that was rated (must have a "situation" key).
        preferred_text: "Left dialog" / "Right dialog", or None if unset.
        text_quality: quality rating string, or None if unset.

    Returns:
        None.
    """
    if preferred_text is None:
        print("not submitted due to unselected preferred text")
        return
    if text_quality is None:
        print("not submitted due to unselected text_quality")
        return
    data = {
        "sample": sample,
        "preferred_text": preferred_text,
        "text_quality": text_quality,
        "date_time": datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S")
        # add other info like user etc?
    }
    # Hash the situation text so repeated ratings of the same sample at the
    # same second still land in distinct-enough paths.
    task_hash = hashlib.md5(sample["situation"].encode())
    # BUG FIX: '%S' (seconds of the minute); the original lowercase '%s' is a
    # non-portable glibc extension (epoch seconds) and is unsupported on
    # Windows strftime.
    time_now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    task_directory = f"{time_now}_{task_hash.hexdigest()}"
    upload_file(
        path_or_fileobj=BytesIO(bytes(json.dumps(data), 'utf-8')),
        path_in_repo=task_directory,
        repo_id='edbeeching/rlhf_dialog_experiment_dataset',
        repo_type='dataset',
        token=os.environ['EB_TOKEN']
    )
# --- Gradio UI ---------------------------------------------------------------
# Layout: instructions header, the sampled "situation" summary, the two
# candidate dialogs side by side, then the two rating radios and a Submit
# button. Submitting uploads the rating and swaps in a fresh sample.
with gr.Blocks() as demo:
    gr.Markdown(
        """
This Space is an experiment to model human preferences on dialog generated with the [Cosmo-XL](https://huggingface.co/allenai/cosmo-xl) model, prompted with parts of conversations from the [SODA](https://huggingface.co/datasets/allenai/soda) dataset.
The following conversation was created with the following prompt:
"""
    )
    # Initial sample shown on first page load (module-level state, rebound
    # in on_next below).
    sample = get_sample()
    with gr.Column() as details_col:
        # The situation/prompt the two candidate dialogs were generated from.
        summary = gr.Markdown(f"## {sample['situation']}", label='Description')
    with gr.Row():
        with gr.Column():
            with gr.Box():
                dialog1 = gr.Markdown(sample_to_markdown(sample, 0), label='Dialog 1')
        with gr.Column():
            with gr.Box():
                dialog2 = gr.Markdown(sample_to_markdown(sample, 1), label='Dialog 2')
    with gr.Column():
        dialog_choice = gr.Radio(["Left dialog", "Right dialog"], label="Preferred text", interactive=True)
        quality_of_dialog = gr.Radio(["Terrible", "Poor", "Ok", "Good", "Excellent"], label="Quality of preferred text", interactive=True)
        next_button = gr.Button("Submit")

    def on_next(preferred_text, text_quality):
        # check and submit the current response
        # (no-op upload if either radio is unset), then draw a new sample and
        # refresh all five output components, clearing both radios.
        # NOTE(review): `sample` is shared module state — concurrent users
        # will overwrite each other's current sample; confirm single-user use.
        global sample
        check_and_submit_preferences(sample, preferred_text, text_quality)
        sample = get_sample()
        return (
            gr.Markdown.update(f"## {sample['situation']}"),
            gr.Markdown.update(sample_to_markdown(sample, 0)),
            gr.Markdown.update(sample_to_markdown(sample, 1)),
            gr.Radio.update(value=None),
            gr.Radio.update(value=None)
        )

    next_button.click(on_next, inputs=[dialog_choice, quality_of_dialog], outputs=[summary, dialog1, dialog2, dialog_choice, quality_of_dialog])

if __name__ == "__main__":
    demo.launch()