"""Gradio chat demo for the FROMAGe model.

The per-session state held in ``gr.State`` is a three-element list::

    [conversation, chat_history, input_image]

* ``conversation`` -- list of ``(user_message, bot_message)`` tuples that is
  rendered by the ``gr.Chatbot`` component.
* ``chat_history`` -- the running text prompt (``'Q: ...\\nA: ...\\n'``) that
  is passed to the model on every turn.
* ``input_image`` -- the most recently uploaded image that has not yet been
  consumed by a prompt, or ``None``.
"""
import os
# Must be set before huggingface_hub is imported so the setting is picked up.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from fromage import models
from fromage import utils
import gradio as gr
import huggingface_hub
import tempfile
import uuid


# Download model from HF Hub.
ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
args_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
model = models.load_fromage('./', args_path, ckpt_path)


def _new_state():
    """Return a fresh, empty session state: [conversation, chat_history, input_image]."""
    return [[], '', None]


def reset():
    """Clear the session.

    Returns:
        A ``(state, chatbot_value)`` pair: a fresh empty state and an empty
        conversation for the chatbot component.
    """
    return _new_state(), []


def upload_image(state, image_input):
    """Record an uploaded image in the session.

    Args:
        state: Session state ``[conversation, chat_history, input_image]``.
        image_input: The uploaded file object; only its ``.name`` (a path on
            disk) is used.

    Returns:
        A ``(state, chatbot_value)`` pair. The conversation gains one entry
        showing the uploaded image, and the resized RGB image is stored so
        the next prompt can use it.
    """
    conversation, chat_history, _ = state
    conversation = conversation + [
        (f"![](/file={image_input.name})", "(Image received. Type or ask something to continue.)")
    ]
    input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
    return [conversation, chat_history, input_image], conversation


def save_image_to_local(image: Image.Image):
    """Save ``image`` as a uniquely named PNG in the working directory.

    Returns:
        The (relative) filename the image was written to.
    """
    # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
    # uuid4 gives a collision-resistant name without relying on private
    # tempfile internals.
    filename = uuid.uuid4().hex + '.png'
    image.save(filename)
    return filename


def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
    """Run the model on the user's prompt and append the reply to the chat.

    Args:
        input_text: The text the user typed.
        state: Session state ``[conversation, chat_history, input_image]``.
        ret_scale_factor: Forwarded to the model as ``ret_scale_factor``.
        max_nm_rets: Forwarded to the model as ``max_num_rets``.
        num_words: Forwarded to the model as ``num_words``.
        temperature: Sampling temperature; ``0.0`` keeps ``top_p`` at 1.0,
            any other value uses ``top_p=0.95``.

    Returns:
        A ``(state, chatbot_value)`` pair. The pending input image in the
        returned state is reset to ``None``.
    """
    conversation, chat_history, input_image = state
    chat_history += 'Q: ' + input_text + '\nA:'
    print('Generating for', chat_history, flush=True)

    # If an image was uploaded, prepend it to the model.
    if input_image is not None:
        model_inputs = [input_image, chat_history]
    else:
        model_inputs = [chat_history]

    top_p = 0.95 if temperature != 0.0 else 1.0

    print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
    model_outputs = model.generate_for_images_and_texts(
        model_inputs,
        # gr.Number may deliver floats; these two are counts, so pass ints.
        num_words=int(num_words),
        ret_scale_factor=ret_scale_factor,
        top_p=top_p,
        temperature=temperature,
        max_num_rets=int(max_nm_rets))
    print('model_outputs', model_outputs, flush=True)

    response = ''
    text_outputs = []
    for output in model_outputs:
        if isinstance(output, str):
            text_outputs.append(output)
            response += output
        elif isinstance(output, list):
            # A list output is treated as a group of images.
            for image in output:
                filename = save_image_to_local(image)
                response += f'<img src="/file={filename}">'
        elif isinstance(output, Image.Image):
            filename = save_image_to_local(output)
            response += f'<img src="/file={filename}">'

    # TODO(jykoh): Persist image inputs.
    # Only the text part of the reply is fed back into the running prompt.
    chat_history += ' '.join(text_outputs)
    if not chat_history.endswith('\n'):
        chat_history += '\n'

    conversation = conversation + [(input_text, response)]

    # Set input image to None.
    new_state = [conversation, chat_history, None]
    print('state', new_state, flush=True)
    return new_state, conversation


with gr.Blocks() as demo:
    gr.Markdown(
        '### Grounding Language Models to Images for Multimodal Generation'
    )

    chatbot = gr.Chatbot()
    gr_state = gr.State(_new_state())  # conversation, chat_history, input_image

    with gr.Row():
        with gr.Column(scale=0.3, min_width=0):
            ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
            max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
            gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
            gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)

        with gr.Column(scale=0.7, min_width=0):
            image_btn = gr.UploadButton("Image Input", file_types=["image"])
            text_input = gr.Textbox(label="Text Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
            clear_btn = gr.Button("Clear History")

    text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])

demo.launch(share=False, debug=True, server_name="0.0.0.0")