vqa-analysis / inference.py
dokster's picture
Upload 4 files
0d89394
raw
history blame
No virus
4.56 kB
import cv2
import torch
import streamlit as st
from PIL import Image
from torch.nn import functional as nnf
# @st.cache_data
def generate2(
    model,
    tokenizer,
    tokens=None,
    prompt='',
    embed=None,
    entry_count=1,
    entry_length=67,
    top_p=0.98,
    temperature=1,
    stop_token='.',
):
    """Greedily decode text from ``model.gpt`` and return the first entry.

    The prompt (or the supplied ``tokens``) is embedded with
    ``model.gpt.transformer.wte``; when ``embed`` is given it is concatenated
    in front of those embeddings along dimension 1. Tokens are then appended
    one at a time until ``stop_token`` is produced or ``entry_length`` steps
    have run. The decoded text is passed through ``filter_ngrams``.

    Args:
        model: Object exposing ``parameters()`` and a ``gpt`` attribute that
            accepts ``inputs_embeds`` and returns an object with ``logits``.
        tokenizer: Object exposing ``encode`` and ``decode``.
        tokens: Optional tensor of already-encoded prompt ids; it is used
            as-is (no extra batch dimension is added), so it is presumably
            expected to be ``(1, n)`` -- confirm against callers. When
            ``None`` the ``prompt`` string is encoded instead.
        prompt: Text to encode when ``tokens`` is ``None``.
        embed: Optional embedding tensor prepended to the prompt embeddings.
        entry_count: Number of entries to decode. Only the first one is
            returned.
        entry_length: Maximum number of tokens generated per entry.
        top_p: Cumulative-probability threshold for the nucleus filter.
        temperature: Logit divisor; non-positive values fall back to 1.0.
        stop_token: String whose first encoded id terminates generation.

    Returns:
        The decoded and filtered text of the first entry.

    Raises:
        IndexError: If ``entry_count`` is less than 1.
    """
    # model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device
    # Never divide by a zero or negative temperature.
    scale = temperature if temperature > 0 else 1.0
    generated_list = []

    with torch.no_grad():
        if tokens is None:
            # An explicit integer dtype keeps an empty prompt from producing
            # a float tensor, which the embedding layer would reject.
            prompt_tokens = torch.tensor(
                tokenizer.encode(prompt), dtype=torch.long
            ).unsqueeze(0)
        else:
            prompt_tokens = tokens
        prompt_tokens = prompt_tokens.to(device)

        prompt_embeds = model.gpt.transformer.wte(prompt_tokens)
        if embed is not None:
            initial_embeds = torch.cat((embed, prompt_embeds), dim=1)
        else:
            initial_embeds = prompt_embeds

        for _ in range(entry_count):
            # Every entry restarts from the prompt; state from a previous
            # entry must not leak into the next one.
            entry_tokens = prompt_tokens
            generated = initial_embeds

            for _ in range(entry_length):
                logits = model.gpt(inputs_embeds=generated).logits
                logits = logits[:, -1, :] / scale

                # Nucleus (top-p) filter. The highest-logit token is always
                # kept, so with the argmax selection below this step does not
                # change which token is chosen.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    nnf.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = False
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                entry_tokens = torch.cat((entry_tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            # Index the batch dimension instead of squeezing everything so a
            # single-token result still decodes as a list.
            output_text = tokenizer.decode(entry_tokens[0].tolist())
            generated_list.append(filter_ngrams(output_text))

    return generated_list[0]
def filter_ngrams(output_text):
    """Cut ``output_text`` just before its second ``' A:'`` marker.

    Text that contains fewer than two ``' A:'`` markers is returned
    unchanged. (Slicing with the ``-1`` that ``str.find`` returns on a miss
    would silently drop the last character.)

    Args:
        output_text: Decoded model output.

    Returns:
        The text up to, but not including, the second ``' A:'``; or the
        whole text when there is no second marker.
    """
    marker = ' A:'
    first = output_text.find(marker)
    if first == -1:
        return output_text
    second = output_text.find(marker, first + 1)
    if second == -1:
        return output_text
    return output_text[:second]
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into one ``rows`` x ``cols`` RGB image.

    The cell size is taken from the first image's ``size``. The number of
    images must equal ``rows * cols``.

    Args:
        imgs: Sequence of images, each exposing ``size``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.
    """
    assert len(imgs) == rows * cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(imgs):
        # Row-major placement: the column advances fastest.
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas
@st.cache_data
def read_video(path, transform=None, frames_num=9, window=30):
    """Sample up to ``frames_num`` evenly spaced thumbnails from a video.

    Frames are decoded sequentially. Starting at frame index 1, every
    ``frame_count // frames_num``-th frame (at least every frame) is
    converted from BGR to RGB, wrapped in a PIL image and shrunk in place to
    fit within 193x193.

    Args:
        path: Video location handed to ``cv2.VideoCapture``.
        transform: Not used by this function; kept for interface
            compatibility.
        frames_num: Maximum number of frames to return. Values below 1
            yield an empty list.
        window: Not used by this function; kept for interface compatibility.

    Returns:
        A list of PIL images holding at most ``frames_num`` thumbnails.
    """
    frames = []
    if frames_num <= 0:
        return frames

    thumb_size = (193, 193)
    cap = cv2.VideoCapture(path)
    try:
        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep the stride at 1 or more so clips shorter than frames_num
        # still advance past the first sampled frame.
        step = max(1, length // frames_num)
        target = 1
        for index in range(length):
            if len(frames) >= frames_num:
                # Enough samples collected; no need to decode the rest.
                break
            ret, frame = cap.read()
            if not ret or index != target:
                continue
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
            # filter under its current name.
            image.thumbnail(thumb_size, Image.LANCZOS)
            frames.append(image)
            target += step
    finally:
        # Release the capture even if decoding or conversion raises.
        cap.release()
    return frames
# @st.cache_data
def get_caption(model, tokenizer, prefix, prefix_length, prompt=''):
    """Project ``prefix`` with the model and generate text conditioned on it.

    ``prefix`` is moved to the device that holds the model's parameters (the
    same device ``generate2`` uses), passed through ``model.clip_project``
    and reshaped to ``(1, prefix_length, -1)``. The result is handed to
    ``generate2`` as the embedding prefix.

    Args:
        model: Object exposing ``parameters()`` and ``clip_project``.
        tokenizer: Tokenizer forwarded to ``generate2``.
        prefix: Tensor fed to ``model.clip_project``.
        prefix_length: Second dimension of the reshaped projection.
        prompt: Optional text prompt; a falsy value means an empty prompt.

    Returns:
        The generated text with newlines replaced by spaces.
    """
    # Follow the model's own device rather than guessing from CUDA
    # availability, so a CPU-resident model works on a CUDA-capable host.
    device = next(model.parameters()).device
    prefix = prefix.to(device)
    with torch.no_grad():
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
        generated_text_prefix = generate2(
            model, tokenizer, prompt=prompt or '', embed=prefix_embed
        )
    return generated_text_prefix.replace('\n', ' ')
# @st.cache_data
def get_ans(model, tokenizer, clip_emb, prefix_length, prompt):
    """Generate a caption for ``prompt`` and return the text that follows it.

    The first ``len(prompt)`` characters of the caption are dropped and the
    remainder is stripped of surrounding whitespace.

    Args:
        model: Model forwarded to ``get_caption``.
        tokenizer: Tokenizer forwarded to ``get_caption``.
        clip_emb: Prefix tensor forwarded to ``get_caption``.
        prefix_length: Prefix length forwarded to ``get_caption``.
        prompt: Prompt text; its length determines how much is cut off.

    Returns:
        A dict with the single key ``'answer'``.
    """
    caption = get_caption(model, tokenizer, clip_emb, prefix_length, prompt=prompt)
    prompt_chars = len(prompt)
    return {'answer': caption[prompt_chars:].strip()}