hezhihui commited on
Commit
b352d20
1 Parent(s): 9403e15

update chat msgs

Browse files
Files changed (1) hide show
  1. modeling_minicpmv.py +7 -7
modeling_minicpmv.py CHANGED
@@ -1,13 +1,10 @@
1
  import math
2
- from typing import List, Optional
3
  import json
4
  import torch
5
- import torchvision
6
  from threading import Thread
7
  from copy import deepcopy
8
- from PIL import Image
9
  from torchvision import transforms
10
- from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast, TextIteratorStreamer
11
  from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
12
  from transformers import AutoProcessor
13
 
@@ -91,7 +88,9 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
91
  img_cnt = []
92
  for pixel_values in pixel_values_list:
93
  img_cnt.append(len(pixel_values))
94
- all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values]) # exist image
 
 
95
  if all_pixel_values:
96
  tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
97
 
@@ -290,17 +289,18 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
290
  processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
291
  if isinstance(msgs, str):
292
  msgs = json.loads(msgs)
 
293
 
294
  assert len(msgs) > 0, 'msgs is empty'
295
  assert sampling or not stream, 'if use stream mode, make sure sampling=True'
296
 
297
  if image is not None and isinstance(msgs[0]['content'], str):
298
- msgs[0]['content'] = '(<image>./</image>)\n' + msgs[0]['content']
299
  if system_prompt:
300
  sys_msg = {'role': 'system', 'content': system_prompt}
301
  copy_msgs = [sys_msg] + copy_msgs
302
 
303
- prompt = processor.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
304
  inputs = processor(prompt, [image], return_tensors="pt", max_length=max_inp_length).to(self.device)
305
 
306
  if sampling:
 
1
  import math
 
2
  import json
3
  import torch
 
4
  from threading import Thread
5
  from copy import deepcopy
 
6
  from torchvision import transforms
7
+ from transformers import LlamaPreTrainedModel, LlamaForCausalLM, TextIteratorStreamer
8
  from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
9
  from transformers import AutoProcessor
10
 
 
88
  img_cnt = []
89
  for pixel_values in pixel_values_list:
90
  img_cnt.append(len(pixel_values))
91
+ all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
92
+
93
+ # exist image
94
  if all_pixel_values:
95
  tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
96
 
 
289
  processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
290
  if isinstance(msgs, str):
291
  msgs = json.loads(msgs)
292
+ copy_msgs = deepcopy(msgs)
293
 
294
  assert len(msgs) > 0, 'msgs is empty'
295
  assert sampling or not stream, 'if use stream mode, make sure sampling=True'
296
 
297
  if image is not None and isinstance(msgs[0]['content'], str):
298
+ copy_msgs[0]['content'] = '(<image>./</image>)\n' + copy_msgs[0]['content']
299
  if system_prompt:
300
  sys_msg = {'role': 'system', 'content': system_prompt}
301
  copy_msgs = [sys_msg] + copy_msgs
302
 
303
+ prompt = processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
304
  inputs = processor(prompt, [image], return_tensors="pt", max_length=max_inp_length).to(self.device)
305
 
306
  if sampling: