# Hugging Face Space script (page-scrape header said: "Spaces: Runtime error" —
# i.e. the Space was in a Runtime error state when captured).
# Standard library
import datetime
import hashlib
import json
import os
import random
from io import BytesIO

# Third-party (Hugging Face stack)
import gradio as gr
from datasets import load_dataset
from huggingface_hub import upload_file
# Validation split of the pre-generated Cosmo dialog pairs to be rated.
# Requires the EB_TOKEN environment variable (Space secret) to authenticate.
# NOTE(review): `use_auth_token` is deprecated in newer `datasets` releases in
# favor of `token=` — confirm the pinned version before migrating.
dataset = load_dataset("edbeeching/rlhf_dialog_experiment_cosmo_dialog_generation",
                       use_auth_token=os.environ['EB_TOKEN'])["validation"]
def sample_to_markdown(sample, index):
    """Render conversation *index* of *sample* as alternating-alignment HTML.

    Even-numbered turns are left-aligned, odd-numbered turns are
    right-aligned and bolded.  A centered separator is inserted before the
    turn at which the prompt ends and model generation begins.
    """
    turns = sample["conversations"][index]
    # NOTE: "trucation_length" is the key exactly as spelled in the dataset —
    # do not correct the spelling here.
    generation_start = sample["trucation_length"]

    parts = []
    for turn_idx, utterance in enumerate(turns):
        if turn_idx == generation_start:
            # Everything after this marker was produced by the model.
            parts.append('<p style="text-align:center"> --- START OF DIALOG GENERATION --- </p><br>')
        if turn_idx % 2:
            parts.append(f'<div style="text-align: right"> <strong>{utterance}</strong> \n </div> <br>')
        else:
            parts.append(f'<div style="text-align: left"> {utterance} \n </div> <br>')
    return "".join(parts)
# Module-level "current sample" shown in the UI; rebound by on_next()
# (via `global sample`) each time the user submits a rating.
sample = None
def get_sample():
    """Return a uniformly random sample from the validation split.

    Reseeds the module-level `random` generator from the current wall-clock
    time on every call (the original author found the draw distribution
    "a bit off" otherwise, likely due to Gradio worker forking).
    """
    # Fixes the original "%Y%m%d_%H%M%s": lowercase %s (epoch seconds) is a
    # non-portable glibc extension — %S (seconds, 00-59) is the intended
    # field.  Compute the seed once instead of building the string twice.
    seed = abs(hash(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))) % (10 ** 8)
    random.seed(seed)
    return dataset[random.randint(0, len(dataset) - 1)]
def check_and_submit_preferences(sample, preferred_text, text_quality):
    """Upload one preference judgement as a JSON payload to the results dataset.

    Returns without uploading anything when either radio selection is None,
    so incomplete ratings are silently dropped rather than stored.
    """
    if preferred_text is None:
        print("not submitted due to unselected preferred text")
        return
    if text_quality is None:
        print("not submitted due to unselected text_quality")
        return
    # Payload stored on the Hub: the full source sample plus the user's choices.
    data = {
        "sample": sample,
        "preferred_text": preferred_text,
        "text_quality": text_quality,
        "date_time": datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S")
        # add other info like user etc?
    }
    # md5 of the prompt situation keeps the name stable per task; the
    # timestamp prefix keeps repeated submissions for the same task distinct.
    task_hash = hashlib.md5(sample["situation"].encode())
    # NOTE(review): lowercase "%s" is a non-portable glibc strftime extension
    # that expands to epoch seconds — probably a typo for "%S".  Left as-is
    # because already-uploaded paths use this format; confirm before changing.
    time_now = datetime.datetime.now().strftime("%Y%m%d_%H%M%s")
    # NOTE(review): the uploaded path carries no .json extension — verify that
    # downstream readers of the results repo expect extensionless names.
    task_directory = f"{time_now}_{task_hash.hexdigest()}"
    upload_file(
        path_or_fileobj=BytesIO(bytes(json.dumps(data), 'utf-8')),
        path_in_repo=task_directory,
        repo_id='edbeeching/rlhf_dialog_experiment_dataset',
        repo_type='dataset',
        token=os.environ['EB_TOKEN']
    )
# UI: two generated dialogs side by side; the user picks the better one and
# rates its quality, then Submit uploads the judgement and loads a new pair.
with gr.Blocks() as demo:
    gr.Markdown(
        """
    This Space is an experiment to model human preferences on dialog generated with the [Cosmo-XL](https://huggingface.co/allenai/cosmo-xl) model, prompted with parts of conversations from the [SODA](https://huggingface.co/datasets/allenai/soda) dataset.
    The following conversation was created with the following prompt:
    """
    )
    # Initial sample displayed when the page first loads.
    sample = get_sample()
    with gr.Column() as details_col:
        summary = gr.Markdown(f"## {sample['situation']}", label='Description')
    with gr.Row():
        with gr.Column():
            with gr.Box():
                dialog1 = gr.Markdown(sample_to_markdown(sample, 0), label='Dialog 1')
        with gr.Column():
            with gr.Box():
                dialog2 = gr.Markdown(sample_to_markdown(sample, 1), label='Dialog 2')
    with gr.Column():
        dialog_choice = gr.Radio(["Left dialog", "Right dialog"], label="Preferred text", interactive=True)
        quality_of_dialog = gr.Radio(["Terrible", "Poor", "Ok", "Good", "Excellent"], label="Quality of preferred text", interactive=True)
    next_button = gr.Button("Submit")

    def on_next(preferred_text, text_quality):
        # check and submit the current response (no-op if either radio is
        # unselected), then redraw every widget with a freshly drawn sample.
        global sample
        check_and_submit_preferences(sample, preferred_text, text_quality)
        sample = get_sample()
        return (
            gr.Markdown.update(f"## {sample['situation']}"),
            gr.Markdown.update(sample_to_markdown(sample, 0)),
            gr.Markdown.update(sample_to_markdown(sample, 1)),
            gr.Radio.update(value=None),  # clear both radios for the next rating
            gr.Radio.update(value=None)
        )

    # Outputs are listed in the same order as on_next's return tuple.
    next_button.click(on_next, inputs=[dialog_choice, quality_of_dialog], outputs=[summary, dialog1, dialog2, dialog_choice, quality_of_dialog])

if __name__ == "__main__":
    demo.launch()