import os

import clip  # OpenAI CLIP (pip install git+https://github.com/openai/CLIP.git)
import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from torchmetrics.multimodal import CLIPScore
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

BATCH_SIZE = 200  # batch size for the torchmetrics-based computation


def compute_clip_score(promptbook, device, drop_negative=False):
    """Score each image against its own prompt with torchmetrics' CLIPScore model."""
    # if 'clip_score' in promptbook.columns:
    #     print('==> Skipping CLIP-Score computation')
    #     return
    print('==> CLIP-Score computation started')
    clip_scores = []
    to_tensor = transforms.ToTensor()
    # metric = CLIPScore(model_name_or_path='openai/clip-vit-base-patch16').to(device)
    metric = CLIPScore(model_name_or_path='openai/clip-vit-large-patch14').to(device)
    for i in tqdm(range(0, len(promptbook), BATCH_SIZE)):
        prompts = list(promptbook.prompt.values[i:i + BATCH_SIZE])
        images = [to_tensor(image) for image in promptbook.image.values[i:i + BATCH_SIZE]]
        with torch.no_grad():
            x = metric.processor(text=prompts, images=images, return_tensors='pt', padding=True)
            # L2-normalize both embeddings so the dot product below is a cosine similarity
            img_features = metric.model.get_image_features(x['pixel_values'].to(device))
            img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
            txt_features = metric.model.get_text_features(
                x['input_ids'].to(device), x['attention_mask'].to(device))
            txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
            # CLIPScore: 100 * cos(image embedding, text embedding), one score per row
            scores = 100 * (img_features * txt_features).sum(dim=-1).detach().cpu()
            if drop_negative:
                scores = torch.max(scores, torch.zeros_like(scores))
            clip_scores += [round(s.item(), 4) for s in scores]
    promptbook['clip_score'] = np.asarray(clip_scores)
    print('==> CLIP-Score computation completed')
    return promptbook


def compute_clip_score_hmd(promptbook):
    """Compare the precomputed 'clip_score' column against fresh CPU and MPS runs."""
    metric_cpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('cpu')
    metric_gpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('mps')
    os.makedirs('./tmp', exist_ok=True)
    for idx in promptbook.index:
        clip_score_hm = promptbook.loc[idx, 'clip_score']
        with torch.no_grad():
            image = promptbook.loc[idx, 'image']
            image.save(f"./tmp/{promptbook.loc[idx, 'image_id']}.png")
            image = transforms.ToTensor()(image)
            image_cpu = torch.unsqueeze(image, dim=0).to('cpu')
            image_gpu = torch.unsqueeze(image, dim=0).to('mps')
            prompts = [promptbook.loc[idx, 'prompt']]
            clip_score_cpu = metric_cpu(image_cpu, prompts)
            clip_score_gpu = metric_gpu(image_gpu, prompts)
            print(
                f'==> clip_score_hm: {clip_score_hm:.4f}, '
                f'clip_score_cpu: {clip_score_cpu:.4f}, '
                f'clip_score_gpu: {clip_score_gpu:.4f}')


def compute_clip_score_transformers(promptbook, device='cpu'):
    """Recompute the scores with a plain transformers CLIPModel on the given device."""
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    with torch.no_grad():
        inputs = processor(text=promptbook.prompt.tolist(), images=promptbook.image.tolist(),
                           return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        # logits_per_image holds cosine similarities scaled by the learned temperature
        # (exp(logit_scale), roughly 100 for the released OpenAI checkpoints), so the
        # diagonal (image i vs. its own prompt) is comparable to the CLIPScore values
        logits_per_image = outputs.logits_per_image
    promptbook.loc[:, 'clip_score'] = logits_per_image.diagonal().cpu().tolist()
    return promptbook


def compute_clip_score_clip(promptbook, device='cpu'):
    """Recompute the scores with the original OpenAI clip package (note: ViT-B/32)."""
    model, preprocess = clip.load("ViT-B/32", device=device)
    with torch.no_grad():
        # preprocess handles one PIL image at a time; stack the results into a batch
        image_inputs = torch.stack(
            [preprocess(image) for image in promptbook.image.tolist()]).to(device)
        text_inputs = clip.tokenize(promptbook.prompt.tolist()).to(device)
        logits_per_image, logits_per_text = model(image_inputs, text_inputs)
        # softmax over the prompt batch: a matching probability, not a 0-100 score
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    promptbook.loc[:, 'clip_score'] = probs.diagonal().tolist()
    return promptbook
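

# Hypothetical helper (an addition, not part of the original experiment): rather than
# eyeballing the printed scores, summarize how far two implementations diverge. It only
# assumes both frames share an index and a 'clip_score' column, which every function
# above guarantees.
def summarize_score_diffs(pb_a, pb_b, label_a='a', label_b='b'):
    """Print and return the absolute per-image score differences between two runs."""
    diffs = (pb_a['clip_score'] - pb_b['clip_score']).abs()
    print(f'==> |{label_a} - {label_b}|: max={diffs.max():.6f}, mean={diffs.mean():.6f}')
    return diffs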


if __name__ == "__main__":
    # DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'
    print(torch.__version__)

    images_ds = load_from_disk(os.path.join(os.pardir, 'data', 'promptbook'))
    images_ds = images_ds.sort(['prompt_id', 'modelVersion_id'])
    print(images_ds)
    print(type(images_ds[0]['image']))

    # first 20 rows, with and without the precomputed clip_score column
    promptbook_hmd = pd.DataFrame(images_ds[:20])
    promptbook_new = promptbook_hmd.drop(columns=['clip_score'])

    promptbook_cpu = compute_clip_score(promptbook_new.copy(deep=True), device='cpu')
    promptbook_mps = compute_clip_score(promptbook_new.copy(deep=True), device='mps')
    promptbook_tra_cpu = compute_clip_score_transformers(promptbook_new.copy(deep=True))
    promptbook_tra_mps = compute_clip_score_transformers(promptbook_new.copy(deep=True), device='mps')

    for idx in promptbook_mps.index:
        print(
            'image id: ', promptbook_mps['image_id'][idx],
            'mps: ', promptbook_mps['clip_score'][idx],
            'cpu: ', promptbook_cpu['clip_score'][idx],
            'tra cpu: ', promptbook_tra_cpu['clip_score'][idx],
            'tra mps: ', promptbook_tra_mps['clip_score'][idx],
            'hmd: ', promptbook_hmd['clip_score'][idx],
        )

    # compute_clip_score_hmd(promptbook_hmd)
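
    # Hypothetical usage of summarize_score_diffs (defined above) to quantify the drift:
    # summarize_score_diffs(promptbook_cpu, promptbook_mps, 'metric cpu', 'metric mps')
    # summarize_score_diffs(promptbook_cpu, promptbook_tra_cpu, 'metric cpu', 'transformers cpu')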