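"""Compare CLIP-Score implementations (torchmetrics, transformers, and OpenAI's
clip package) across devices (cpu vs. mps) on a slice of the promptbook dataset."""
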
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from torchmetrics.multimodal import CLIPScore
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_from_disk
import os

import clip  # OpenAI's reference package: pip install git+https://github.com/openai/CLIP.git

def compute_clip_score(promptbook, device, batch_size=200, drop_negative=False):
    # Manual CLIP-Score: 100 * cosine similarity between L2-normalised image and
    # text embeddings, computed batch-wise with torchmetrics' model weights.
    print('==> CLIP-Score computation started')
    clip_scores = []
    to_tensor = transforms.ToTensor()
    metric = CLIPScore(model_name_or_path='openai/clip-vit-large-patch14').to(device)
    for i in tqdm(range(0, len(promptbook), batch_size)):
        prompts = list(promptbook.prompt.values[i:i + batch_size])
        images = [to_tensor(image) for image in promptbook.image.values[i:i + batch_size]]
        with torch.no_grad():
            x = metric.processor(text=prompts, images=images, return_tensors='pt', padding=True)
            img_features = metric.model.get_image_features(x['pixel_values'].to(device))
            img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
            txt_features = metric.model.get_text_features(x['input_ids'].to(device), x['attention_mask'].to(device))
            txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
            # per-pair score: row i of the batch is image i paired with prompt i
            scores = 100 * (img_features * txt_features).sum(axis=-1).detach().cpu()
        if drop_negative:
            scores = torch.max(scores, torch.zeros_like(scores))
        clip_scores += [round(s.item(), 4) for s in scores]
    promptbook['clip_score'] = np.asarray(clip_scores)
    print('==> CLIP-Score computation completed')
    return promptbook
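

def clip_score_reference(promptbook, device='cpu'):
    # Sanity-check sketch (added here; not part of the original script):
    # torchmetrics can compute the batch-average CLIP-Score directly. PILToTensor
    # keeps pixel values as uint8 in [0, 255]; ToTensor's float [0, 1] output
    # risks being rescaled a second time by the underlying CLIPProcessor.
    metric = CLIPScore(model_name_or_path='openai/clip-vit-large-patch14').to(device)
    images = [transforms.PILToTensor()(image).to(device) for image in promptbook.image]
    # returns the score averaged over the batch (torchmetrics reduces internally)
    return metric(images, list(promptbook.prompt.values))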


def compute_clip_score_hmd(promptbook):
    # Compare the precomputed ('hmd') clip_score column against fresh torchmetrics
    # scores on cpu and mps, one image at a time. NB: ToTensor yields float values
    # in [0, 1], which CLIPScore's processor may rescale again; this is one
    # candidate explanation for gaps between the stored and recomputed scores.
    metric_cpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('cpu')
    metric_gpu = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to('mps')
    os.makedirs('./tmp', exist_ok=True)  # images are dumped here for manual inspection

    for idx in promptbook.index:
        clip_score_hm = promptbook.loc[idx, 'clip_score']

        with torch.no_grad():
            image = promptbook.loc[idx, 'image']
            image.save(f"./tmp/{promptbook.loc[idx, 'image_id']}.png")
            image = transforms.ToTensor()(image)
            image_cpu = torch.unsqueeze(image, dim=0).to('cpu')
            image_gpu = torch.unsqueeze(image, dim=0).to('mps')

            prompts = [promptbook.loc[idx, 'prompt']]
            clip_score_cpu = metric_cpu(image_cpu, prompts)
            clip_score_gpu = metric_gpu(image_gpu, prompts)

        print(f'==> clip_score_hm: {clip_score_hm:.4f}, clip_score_cpu: {clip_score_cpu:.4f}, '
              f'clip_score_gpu: {clip_score_gpu:.4f}')

def compute_clip_score_transformers(promptbook, device='cpu'):
    # Same score via a plain transformers CLIPModel; the model and inputs are
    # moved to `device` (the parameter was previously ignored).
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    with torch.no_grad():
        inputs = processor(text=promptbook.prompt.tolist(), images=promptbook.image.tolist(),
                           return_tensors="pt", padding=True).to(device)
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image

    # score each image against its own prompt (the diagonal), not just prompt 0
    promptbook.loc[:, 'clip_score'] = logits_per_image.diagonal().cpu().tolist()
    return promptbook
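

# The transformers logits above include the model's learned temperature. A small
# sketch (added here, not in the original) for converting them to the
# 100 * cosine convention used by compute_clip_score:
def logits_to_cosine_scores(logits_per_image, model):
    # CLIPModel computes logits_per_image = exp(logit_scale) * cosine_similarity,
    # so dividing the scale back out recovers the raw cosine values.
    return 100 * logits_per_image / model.logit_scale.exp()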

def compute_clip_score_clip(promptbook, device='cpu'):
    # Same comparison via OpenAI's clip package. Note this loads ViT-B/32, not
    # ViT-L/14, so the scores are not directly comparable to the functions above.
    model, preprocess = clip.load("ViT-B/32", device=device)
    with torch.no_grad():
        # preprocess handles one PIL image at a time, so build the batch manually
        image_inputs = torch.stack([preprocess(image) for image in promptbook.image]).to(device)
        text_inputs = clip.tokenize(promptbook.prompt.tolist()).to(device)

        image_features = model.encode_image(image_inputs)
        text_features = model.encode_text(text_inputs)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # score each image against its own prompt, matching compute_clip_score
        scores = 100 * (image_features * text_features).sum(dim=-1)

    promptbook.loc[:, 'clip_score'] = scores.cpu().tolist()
    return promptbook
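

# For reference, the clip package's own forward pass returns the scaled logits
# directly (as in the OpenAI CLIP README); the original softmax-probability
# variant of this function would look like:
#   logits_per_image, logits_per_text = model(image_inputs, text_inputs)
#   probs = logits_per_image.softmax(dim=-1).cpu().numpy()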


if __name__ == "__main__":
    BATCH_SIZE = 200
    # DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'

    print(torch.__version__)

    images_ds = load_from_disk(os.path.join(os.pardir, 'data', 'promptbook'))
    images_ds = images_ds.sort(['prompt_id', 'modelVersion_id'])
    print(images_ds)
    print(type(images_ds[0]['image']))
    # compare all implementations on the first 20 rows; promptbook_hmd keeps the
    # precomputed clip_score column, the others recompute it from scratch
    promptbook_hmd = pd.DataFrame(images_ds[:20])
    promptbook_new = promptbook_hmd.drop(columns=['clip_score'])
    promptbook_cpu = compute_clip_score(promptbook_new.copy(deep=True), device='cpu', batch_size=BATCH_SIZE)
    promptbook_mps = compute_clip_score(promptbook_new.copy(deep=True), device='mps', batch_size=BATCH_SIZE)
    promptbook_tra_cpu = compute_clip_score_transformers(promptbook_new.copy(deep=True))
    promptbook_tra_mps = compute_clip_score_transformers(promptbook_new.copy(deep=True), device='mps')
    for idx in promptbook_mps.index:
        print(
            'image id: ', promptbook_mps['image_id'][idx],
            'mps: ', promptbook_mps['clip_score'][idx],
            'cpu: ', promptbook_cpu['clip_score'][idx],
            'tra cpu: ', promptbook_tra_cpu['clip_score'][idx],
            'tra mps: ', promptbook_tra_mps['clip_score'][idx],
            'hmd: ', promptbook_hmd['clip_score'][idx]
        )
    # compute_clip_score_hmd(promptbook_hmd)