import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from llava import LlavaLlamaForCausalLM
from llava.conversation import conv_templates
from llava.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from PIL import Image

import random
import math


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so every item lands in a chunk
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


# Special tokens LLaVA uses to splice image features into the text prompt.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def patch_config(config):
    patch_dict = {
        "use_mm_proj": True,
        "mm_vision_tower": "openai/clip-vit-large-patch14",
        "mm_hidden_size": 1024
    }

    cfg = AutoConfig.from_pretrained(config)
    if not hasattr(cfg, "mm_vision_tower"):
        print(f'`mm_vision_tower` not found in `{config}`, applying patch and saving to disk.')
        for k, v in patch_dict.items():
            setattr(cfg, k, v)
        cfg.save_pretrained(config)


def eval_model(args):
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)
    if 'lora' in model_name.lower():
        lora_cfg_pretrained = AutoConfig.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    if args.mm_projector is None:
        patch_config(model_name)
        if 'lora' in model_name.lower():
            # Load the base LLaMA weights first, then apply the extra trained
            # (non-LoRA) modules and finally merge the LoRA deltas.
            print('Loading LLaVA from base model...')
            llama_state_dict = AutoModelForCausalLM.from_pretrained(args.base_model_path, torch_dtype=torch.float16).state_dict()
            model = LlavaLlamaForCausalLM.from_pretrained(args.base_model_path, config=lora_cfg_pretrained, state_dict=llama_state_dict, torch_dtype=torch.float16, ignore_mismatched_sizes=True)

            print('Loading additional LLaVA weights...')
            if os.path.exists(os.path.join(model_name, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_name, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')

                non_lora_trainables = load_from_hf(model_name, 'non_lora_trainables.bin')

            # Strip the 'base_model.' / 'model.' prefixes added during LoRA training.
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.embed_tokens') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            non_lora_trainables = {k: v.to(torch.float16) for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_name)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Moving to CUDA...')
            model = model.cuda()
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()

        image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)

        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

        vision_tower = model.model.vision_tower[0]
        vision_tower.to(device='cuda', dtype=torch.float16)
        vision_config = vision_tower.config
        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
        vision_config.use_im_start_end = mm_use_im_start_end
        if mm_use_im_start_end:
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
        image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
    else:
        # In case of using a pretrained model with only MLP projector weights:
        # load the language model and vision tower separately, then attach the projector.
        model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()

        vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).cuda()
        image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16)

        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

        vision_config = vision_tower.config
        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
        vision_config.use_im_start_end = mm_use_im_start_end
        if mm_use_im_start_end:
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

        image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

        mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size)
        mm_projector_weights = torch.load(args.mm_projector, map_location='cpu')
        mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})

        model.model.mm_projector = mm_projector.cuda().half()
        model.model.vision_tower = [vision_tower]

    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
    ans_file = open(answers_file, "w")
    for i, line in enumerate(tqdm(questions)):
        idx = line["question_id"]
        image_file = line["image"]
        qs = line["text"]
        cur_prompt = qs
        # Append the image placeholder tokens after the question text.
        if mm_use_im_start_end:
            qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
        else:
            qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        if args.conv_mode != 'simple':
            conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])

        image = Image.open(os.path.join(args.image_folder, image_file))
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        # new stopping implementation
        class KeywordsStoppingCriteria(StoppingCriteria):
            def __init__(self, keywords, tokenizer, input_ids):
                self.keywords = keywords
                self.tokenizer = tokenizer
                self.start_len = None
                self.input_ids = input_ids

            def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                if self.start_len is None:
                    self.start_len = self.input_ids.shape[1]
                else:
                    outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
                    for keyword in self.keywords:
                        if keyword in outputs:
                            return True
                return False

        if args.conv_mode == 'simple':
            keywords = ['###']
        else:
            keywords = [conv.sep2]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                do_sample=True,
                temperature=0.7,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0].strip()

        if args.conv_mode == 'simple':
            # Strip leading role markers, then cut the answer at the first separator.
            while True:
                cur_len = len(outputs)
                outputs = outputs.strip()
                for pattern in ['###', 'Assistant:', 'Response:']:
                    if outputs.startswith(pattern):
                        outputs = outputs[len(pattern):].strip()
                if len(outputs) == cur_len:
                    break

            try:
                index = outputs.index(conv.sep)
            except ValueError:
                outputs += conv.sep
                index = outputs.index(conv.sep)

            outputs = outputs[:index].strip()
        else:
            outputs = outputs.strip()

        ans_id = shortuuid.uuid()
        ans_file.write(json.dumps({"question_id": idx,
                                   "prompt": cur_prompt,
                                   "text": outputs,
                                   "answer_id": ans_id,
                                   "model_id": model_name,
                                   "metadata": {}}) + "\n")
        ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--base-model-path", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--mm-projector", type=str, default=None)
    parser.add_argument("--vision-tower", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="simple")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    args = parser.parse_args()

    eval_model(args)
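# Example invocation (a sketch, not from the source): the script filename and all
# checkpoint, question, and image paths below are placeholders and must exist locally.
# --num-chunks / --chunk-idx split the question file so several workers can each
# evaluate one shard in parallel (see get_chunk above); answers are written as JSONL.
#
#   python model_vqa.py \
#       --model-name ./checkpoints/llava-13b-lora \
#       --base-model-path ./checkpoints/llama-13b \
#       --question-file tables/question.jsonl \
#       --image-folder ./images \
#       --answers-file answers/answer-chunk0.jsonl \
#       --conv-mode simple \
#       --num-chunks 4 \
#       --chunk-idx 0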