seungduk commited on
Commit
3a99495
1 Parent(s): 440c3ab

improve: Enhance code readability of prompt_tokenizers.py (#707)

Browse files
Files changed (1) hide show
  1. src/axolotl/prompt_tokenizers.py +80 -107
src/axolotl/prompt_tokenizers.py CHANGED
@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
45
  self.prompter = prompter
46
  self.tokenizer: PreTrainedTokenizer = tokenizer
47
  self.train_on_inputs = train_on_inputs
 
 
48
  self.sequence_len = sequence_len
49
  self.max_length = sequence_len
50
 
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
59
  def _tokenize(
60
  self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
61
  ) -> BatchEncoding:
62
- result: BatchEncoding
63
  if not prompt:
64
  LOG.warning("Empty text requested for tokenization.")
65
- result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
66
- else:
67
- result = self.tokenizer(
68
- prompt,
69
- truncation=True,
70
- max_length=self.max_length,
71
- padding=False,
72
- return_tensors=None,
73
- )
74
  if len(result["input_ids"]) == 0:
75
  LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
 
 
76
  if (
77
- len(result["input_ids"]) > 0
78
- and result["input_ids"][-1] != self.tokenizer.eos_token_id
79
  and len(result["input_ids"]) < self.max_length
80
  and add_eos_token
81
  ):
82
  result["input_ids"].append(self.tokenizer.eos_token_id)
83
  result["attention_mask"].append(1)
84
 
85
- if (
86
- len(result["input_ids"]) > 0
87
- and result["input_ids"][0] == self.tokenizer.bos_token_id
88
- and strip_bos_token
89
- ):
90
  result["input_ids"] = result["input_ids"][1:]
91
  result["attention_mask"] = result["attention_mask"][1:]
92
 
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
122
  if not self.train_on_inputs:
123
  user_prompt_len = len(tokenized_prompt["input_ids"])
124
  # TODO this could be sped up using numpy array slicing
125
- tokenized_prompt["labels"] = [-100] * user_prompt_len
126
  tokenized_res_prompt = self._tokenize(
127
  response, strip_bos_token=True, add_eos_token=True
128
  )
@@ -270,7 +269,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
270
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
271
  # TODO this could be sped up using numpy array slicing
272
  tokenized_full_prompt["labels"] = [
273
- -100
274
  ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
275
 
276
  return tokenized_full_prompt
@@ -334,6 +333,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
334
  return prompt["conversations"]
335
 
336
  def tokenize_prompt(self, prompt):
 
337
  result, current_len = tokenize_prompt_default()
338
  conversation: Conversation = (
339
  self.prompter._conversation.copy() # pylint: disable=protected-access
@@ -355,62 +355,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
355
  for _, part in enumerate(
356
  self.prompter.build_prompt(self.get_conversation_thread(prompt))
357
  ):
358
- if isinstance(part, tuple):
359
- if conversation.roles[0] in part[0]:
360
- role = (
361
- part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
362
- if role_remap
363
- else part[0]
364
- )
365
- turn = role + part[1]
366
- # this is still the user query, we should
367
- if not part[1].strip():
368
- LOG.warning(f"user turn has empty text: {prompt}")
369
- res = self._tokenize(
370
- turn,
371
- add_eos_token=False,
372
- strip_bos_token=True,
373
- )
374
- # everything from this is masked out from the labels
375
- labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
376
- elif conversation.roles[1] in part[0]:
377
- # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
378
- role = (
379
- part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
380
- if role_remap
381
- else part[0]
382
- )
383
- turn = role + part[1]
384
- # this should be the assistant response, should end with an eos token
385
- if not part[1].strip():
386
- LOG.warning(f"assistant turn has empty text: {prompt}")
387
- res = self._tokenize(
388
- turn,
389
- add_eos_token=True,
390
- strip_bos_token=True,
391
- )
392
- role_res = self._tokenize(
393
- role.rstrip(),
394
- add_eos_token=False,
395
- strip_bos_token=True,
396
- )
397
- # not masked out from labels
398
- labels = copy.deepcopy(res["input_ids"])
399
- len_role = len(role_res["input_ids"])
400
- labels[:len_role] = [IGNORE_TOKEN_ID] * min(
401
- len_role, len(labels)
402
- )
403
- elif part[0] == "":
404
- turn = part[1]
405
- # this is only ever the first part, should include the bos token and the user query
406
- res = self._tokenize(
407
- turn, add_eos_token=False, strip_bos_token=False
408
- )
409
- # everything from this is masked out from the labels
410
- labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
411
- else:
412
- LOG.warning(f"unhandled role: {part[0]}")
413
- continue
 
 
 
 
 
414
 
415
  # pylint: disable=duplicate-code
416
  result, current_len = parse_tokenized_to_result(
@@ -424,38 +429,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
424
  except (KeyError, AssertionError, IndexError) as err:
425
  raise InvalidDataException(str(err)) from err
426
 
427
- def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
428
- if not prompt.strip():
429
- LOG.warning("Empty text requested for tokenization.")
430
- result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
431
- else:
432
- result = self.tokenizer(
433
- prompt,
434
- truncation=True,
435
- max_length=self.sequence_len,
436
- padding=False,
437
- return_tensors=None,
438
- )
439
- if (
440
- len(result["input_ids"]) > 0
441
- and result["input_ids"][-1] != self.tokenizer.eos_token_id
442
- and len(result["input_ids"]) < self.sequence_len
443
- and add_eos_token
444
- ):
445
- result["input_ids"].append(self.tokenizer.eos_token_id)
446
- result["attention_mask"].append(1)
447
-
448
- if (
449
- len(result["input_ids"]) > 0
450
- and result["input_ids"][0] == self.tokenizer.bos_token_id
451
- and strip_bos_token
452
- ):
453
- result["input_ids"] = result["input_ids"][1:]
454
- result["attention_mask"] = result["attention_mask"][1:]
455
-
456
- result["labels"] = result["input_ids"].copy()
457
- return result
458
-
459
 
460
  def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
461
  """
 
45
  self.prompter = prompter
46
  self.tokenizer: PreTrainedTokenizer = tokenizer
47
  self.train_on_inputs = train_on_inputs
48
+ # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
49
+ # TODO: Document how they are different.
50
  self.sequence_len = sequence_len
51
  self.max_length = sequence_len
52
 
 
61
  def _tokenize(
62
  self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
63
  ) -> BatchEncoding:
64
+ empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
65
  if not prompt:
66
  LOG.warning("Empty text requested for tokenization.")
67
+ return empty
68
+
69
+ result = self.tokenizer(
70
+ prompt,
71
+ truncation=True,
72
+ max_length=self.max_length,
73
+ padding=False,
74
+ return_tensors=None,
75
+ )
76
  if len(result["input_ids"]) == 0:
77
  LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
78
+ return empty
79
+
80
  if (
81
+ result["input_ids"][-1] != self.tokenizer.eos_token_id
 
82
  and len(result["input_ids"]) < self.max_length
83
  and add_eos_token
84
  ):
85
  result["input_ids"].append(self.tokenizer.eos_token_id)
86
  result["attention_mask"].append(1)
87
 
88
+ if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
 
 
 
 
89
  result["input_ids"] = result["input_ids"][1:]
90
  result["attention_mask"] = result["attention_mask"][1:]
91
 
 
121
  if not self.train_on_inputs:
122
  user_prompt_len = len(tokenized_prompt["input_ids"])
123
  # TODO this could be sped up using numpy array slicing
124
+ tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
125
  tokenized_res_prompt = self._tokenize(
126
  response, strip_bos_token=True, add_eos_token=True
127
  )
 
269
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
270
  # TODO this could be sped up using numpy array slicing
271
  tokenized_full_prompt["labels"] = [
272
+ IGNORE_INDEX
273
  ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
274
 
275
  return tokenized_full_prompt
 
333
  return prompt["conversations"]
334
 
335
  def tokenize_prompt(self, prompt):
336
+ # Initial values. We will append to these as we go through the conversation.
337
  result, current_len = tokenize_prompt_default()
338
  conversation: Conversation = (
339
  self.prompter._conversation.copy() # pylint: disable=protected-access
 
355
  for _, part in enumerate(
356
  self.prompter.build_prompt(self.get_conversation_thread(prompt))
357
  ):
358
+ if not isinstance(part, tuple):
359
+ LOG.warning(f"expected tuple, got {part}")
360
+ continue
361
+
362
+ user, assistant = conversation.roles
363
+ role, content = part
364
+
365
+ # Uses "in" because role contains extra characters
366
+ if user in role:
367
+ role = (
368
+ role.replace(role_remap[0]["from"], role_remap[0]["to"])
369
+ if role_remap
370
+ else role
371
+ )
372
+ turn = role + content
373
+ # this is still the user query, we should
374
+ if not content.strip():
375
+ LOG.warning(f"user turn has empty text: {prompt}")
376
+ res = self._tokenize(
377
+ turn,
378
+ add_eos_token=False,
379
+ strip_bos_token=True,
380
+ )
381
+ # everything from this is masked out from the labels
382
+ labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
383
+ elif assistant in role:
384
+ # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
385
+ role = (
386
+ role.replace(role_remap[1]["from"], role_remap[1]["to"])
387
+ if role_remap
388
+ else role
389
+ )
390
+ turn = role + content
391
+ # this should be the assistant response, should end with an eos token
392
+ if not content.strip():
393
+ LOG.warning(f"assistant turn has empty text: {prompt}")
394
+ res = self._tokenize(
395
+ turn,
396
+ add_eos_token=True,
397
+ strip_bos_token=True,
398
+ )
399
+ role_res = self._tokenize(
400
+ role.rstrip(),
401
+ add_eos_token=False,
402
+ strip_bos_token=True,
403
+ )
404
+ # not masked out from labels
405
+ labels = copy.deepcopy(res["input_ids"])
406
+ len_role = len(role_res["input_ids"])
407
+ labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
408
+ elif role == "":
409
+ turn = content
410
+ # this is only ever the first part, should include the bos token and the user query
411
+ res = self._tokenize(
412
+ turn, add_eos_token=False, strip_bos_token=False
413
+ )
414
+ # everything from this is masked out from the labels
415
+ labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
416
+ else:
417
+ LOG.warning(f"unhandled role: {role}")
418
+ continue
419
 
420
  # pylint: disable=duplicate-code
421
  result, current_len = parse_tokenized_to_result(
 
429
  except (KeyError, AssertionError, IndexError) as err:
430
  raise InvalidDataException(str(err)) from err
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
  def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
434
  """