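"""Compare CLIP-Score implementations across devices.

Recomputes CLIP scores for a small slice of the promptbook dataset with
torchmetrics' CLIPScore, a plain transformers CLIPModel, and the OpenAI
`clip` package, on CPU and on Apple's MPS backend, and prints them next
to the scores already stored in the dataset.
"""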
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from torchmetrics.multimodal import CLIPScore
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_from_disk
import os
import clip
def compute_clip_score(promptbook, device, drop_negative=False, batch_size=200):
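    """Compute CLIP-Score (100 * cosine similarity between CLIP image and
    text embeddings) for each (image, prompt) pair and store it in a
    `clip_score` column."""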
# if 'clip_score' in promptbook.columns:
# print('==> Skipping CLIP-Score computation')
# return
print('==> CLIP-Score computation started')
clip_scores = []
# metric = CLIPScore(model_name_or_path='openai/clip-vit-base-patch16').to(DEVICE)
metric = CLIPScore(model_name_or_path='openai/clip-vit-large-patch14').to(device)
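    # Drive the metric's underlying CLIP model/processor directly so batching
    # stays under our control instead of going through metric.update().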
    for i in tqdm(range(0, len(promptbook), batch_size)):
        prompts = list(promptbook.prompt.values[i:i+batch_size])
        # Pass PIL images straight to the CLIP processor; pre-converting them
        # to [0, 1] float tensors can get rescaled a second time by newer
        # transformers image processors, silently corrupting the scores.
        images = list(promptbook.image.values[i:i+batch_size])
with torch.no_grad():
x = metric.processor(text=prompts, images=images, return_tensors='pt', padding=True)
img_features = metric.model.get_image_features(x['pixel_values'].to(device))
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
txt_features = metric.model.get_text_features(x['input_ids'].to(device), x['attention_mask'].to(device))
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
            scores = 100 * (img_features * txt_features).sum(dim=-1).detach().cpu()
            if drop_negative:
                # clamp negative similarities to zero
                scores = scores.clamp(min=0)
clip_scores += [round(s.item(), 4) for s in scores]
promptbook['clip_score'] = np.asarray(clip_scores)
print('==> CLIP-Score computation completed')
return promptbook
def compute_clip_score_hmd(promptbook):
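    """Sanity check: compare the clip_score values already stored in the
    promptbook ('hm') against scores recomputed by torchmetrics' CLIPScore
    on CPU and on Apple's MPS backend."""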
metric_cpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('cpu')
metric_gpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('mps')
for idx in promptbook.index:
clip_score_hm = promptbook.loc[idx, 'clip_score']
with torch.no_grad():
image = promptbook.loc[idx, 'image']
            os.makedirs('./tmp', exist_ok=True)  # ensure the dump directory exists
            image.save(f"./tmp/{promptbook.loc[idx, 'image_id']}.png")
image = transforms.ToTensor()(image)
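            # CLIPScore expects a batched (N, C, H, W) image tensor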
image_cpu = torch.unsqueeze(image, dim=0).to('cpu')
image_gpu = torch.unsqueeze(image, dim=0).to('mps')
prompts = [promptbook.loc[idx, 'prompt']]
clip_score_cpu = metric_cpu(image_cpu, prompts)
clip_score_gpu = metric_gpu(image_gpu, prompts)
print(
f'==> clip_score_hm: {clip_score_hm:.4f}, clip_score_cpu: {clip_score_cpu:.4f}, clip_score_gpu: {clip_score_gpu:.4f}')
def compute_clip_score_transformers(promptbook, device='cpu'):
    """Cross-check: recompute CLIP scores with a plain transformers CLIPModel."""
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    with torch.no_grad():
        inputs = processor(text=promptbook.prompt.tolist(), images=promptbook.image.tolist(), return_tensors="pt", padding=True).to(device)
        outputs = model(**inputs)
    # logits_per_image has shape (num_images, num_texts); its diagonal pairs
    # each image with its own prompt, whereas [:, 0] would score every image
    # against the first prompt only.
    logits_per_image = outputs.logits_per_image
    promptbook.loc[:, 'clip_score'] = logits_per_image.diagonal().cpu().tolist()
    return promptbook
def compute_clip_score_clip(promptbook, device='cpu'):
    """Cross-check with the original OpenAI `clip` package (ViT-B/32).
    Note this backbone differs from the ViT-L/14 used above, so absolute
    scores are not directly comparable."""
    model, preprocess = clip.load("ViT-B/32", device=device)
    scores = []
    with torch.no_grad():
        for idx in promptbook.index:
            # clip's preprocess transform handles one PIL image at a time
            image_input = preprocess(promptbook.loc[idx, 'image']).unsqueeze(0).to(device)
            text_input = clip.tokenize([promptbook.loc[idx, 'prompt']], truncate=True).to(device)
            image_features = model.encode_image(image_input)
            text_features = model.encode_text(text_input)
            # normalise and take 100 * cosine similarity, as in compute_clip_score
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
            scores.append(round((100 * (image_features * text_features).sum(dim=-1)).item(), 4))
    promptbook.loc[:, 'clip_score'] = scores
    return promptbook
if __name__ == "__main__":
BATCH_SIZE = 200
    # DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(torch.__version__)
images_ds = load_from_disk(os.path.join(os.pardir, 'data', 'promptbook'))
images_ds = images_ds.sort(['prompt_id', 'modelVersion_id'])
print(images_ds)
print(type(images_ds[0]['image']))
promptbook_hmd = pd.DataFrame(images_ds[:20])
promptbook_new = promptbook_hmd.drop(columns=['clip_score'])
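    # Recompute the scores from scratch with each implementation/device combination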
    promptbook_cpu = compute_clip_score(promptbook_new.copy(deep=True), device='cpu', batch_size=BATCH_SIZE)
    promptbook_mps = compute_clip_score(promptbook_new.copy(deep=True), device='mps', batch_size=BATCH_SIZE)
promptbook_tra_cpu = compute_clip_score_transformers(promptbook_new.copy(deep=True))
promptbook_tra_mps = compute_clip_score_transformers(promptbook_new.copy(deep=True), device='mps')
#
for idx in promptbook_mps.index:
print(
'image id: ', promptbook_mps['image_id'][idx],
'mps: ', promptbook_mps['clip_score'][idx],
'cpu: ', promptbook_cpu['clip_score'][idx],
'tra cpu: ', promptbook_tra_cpu['clip_score'][idx],
'tra mps: ', promptbook_tra_mps['clip_score'][idx],
'hmd: ', promptbook_hmd['clip_score'][idx]
)
#
# compute_clip_score_hmd(promptbook_hmd)