import torch

from llava.model import *
from transformers import AutoConfig, StoppingCriteria


def auto_upgrade(config):
    """Interactively upgrade an old-codebase LLaVA v0 checkpoint in place.

    Args:
        config: Path (or repo id) of the checkpoint directory; passed straight
            to ``AutoConfig.from_pretrained`` and, on upgrade, used as the
            ``save_pretrained`` target.

    Side effects:
        Prompts on stdin for confirmation; rewrites the config on disk when
        confirmed; exits the process with status 1 when the user declines.
    """
    cfg = AutoConfig.from_pretrained(config)
    # Heuristic: the path mentions "llava" but the stored config predates the
    # "llava" model_type, so it must be an old-codebase checkpoint.
    if "llava" in config and "llava" not in cfg.model_type:
        assert cfg.model_type == "llama"  # NOTE(review): assert is stripped under -O; acceptable for an interactive tool
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            # Deliberately set on the CLASS: HF PretrainedConfig serializes
            # model_type from the class attribute, so an instance attribute
            # would not be written out by save_pretrained.
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = "LlavaLlamaForCausalLM"
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keyword strings is produced.

    Two complementary checks run on every step:
      1. a fast exact match of the most recent token id against keywords that
         tokenize to a single id, and
      2. a substring search of the decoded completion (tokens after the
         prompt), which also catches multi-token keywords.

    NOTE(review): check (1) only keeps keywords whose tokenization is exactly
    one id — tokenizers that prepend BOS will exclude them; such keywords are
    still caught by check (2).
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        # Precompute single-token keyword ids for the fast per-step check.
        self.keyword_ids = []
        for keyword in keywords:
            ids = tokenizer(keyword).input_ids
            if isinstance(ids, list) and len(ids) == 1:
                self.keyword_ids.append(ids[0])
        self.tokenizer = tokenizer
        # Fix: record the prompt length up front (input_ids is already known
        # here) so __call__ also checks the very first generated token; the
        # original lazily set this on the first call and skipped that check.
        self.start_len = input_ids.shape[1]
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop (a keyword was emitted)."""
        # Fast path: last emitted token is a single-token keyword.
        for keyword_id in self.keyword_ids:
            if output_ids[0, -1] == keyword_id:
                return True
        # Slow path: decode only the completion (batch entry 0) and scan it.
        outputs = self.tokenizer.batch_decode(
            output_ids[:, self.start_len:], skip_special_tokens=True
        )[0]
        return any(keyword in outputs for keyword in self.keywords)