Nanobit committed
Commit 8e46c0f
1 Parent(s): 1f3c3f5

Refactor duplicate code between Prompter and Pygmalion
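
The duplicated bookkeeping in the two tokenize_prompt implementations is hoisted into two module-level helpers in src/axolotl/prompt_tokenizers.py: tokenize_prompt_default(), which returns the empty input_ids/attention_mask/labels accumulator plus a zeroed length counter, and parse_tokenized_to_result(), which splices one tokenized conversation turn into that accumulator. A minimal sketch of the shared pattern, assuming the helpers are imported from the refactored module (accumulate_turns itself is a hypothetical wrapper, not part of this commit):

    from typing import Dict, List

    from axolotl.prompt_tokenizers import (
        parse_tokenized_to_result,
        tokenize_prompt_default,
    )


    def accumulate_turns(
        tokenized_turns: List[Dict[str, List[int]]],  # hypothetical: per-turn tokenizer output ("res")
        turn_labels: List[List[int]],                 # hypothetical: per-turn label lists
        pad_token_id: int,
    ) -> Dict[str, List[int]]:
        # start from empty input_ids / attention_mask / labels and offset 0
        result, current_len = tokenize_prompt_default()
        for res, labels in zip(tokenized_turns, turn_labels):
            # splice this turn's tokens, mask, and labels in at the current offset
            result, current_len = parse_tokenized_to_result(
                result, current_len, res, labels, pad_token_id=pad_token_id
            )
        return result

Both PygmalionPromptTokenizingStrategy.tokenize_prompt and ShareGPTPromptTokenizingStrategy.tokenize_prompt now follow this shape, as the hunks below show.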

src/axolotl/prompt_strategies/pygmalion.py CHANGED
@@ -5,7 +5,11 @@ import logging
 from collections import defaultdict
 from typing import Generator
 
-from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompt_tokenizers import (
+    PromptTokenizingStrategy,
+    parse_tokenized_to_result,
+    tokenize_prompt_default,
+)
 
 IGNORE_TOKEN_ID = -100
 
@@ -23,12 +27,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
         self.bot_prefix_token_ids = res["input_ids"]
 
     def tokenize_prompt(self, prompt):
-        result = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-        }
-        current_len = 0
+        result, current_len = tokenize_prompt_default()
         for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
             role, message = part
             if role == "system":
@@ -67,37 +66,15 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
             else:
                 logging.warning(f"unknown role in conversation: {role}")
                 res = defaultdict(lambda: [])
-            input_ids = res["input_ids"]
-            input_len = len(input_ids)
-            result["input_ids"][current_len : current_len + input_len] = input_ids
-            result["attention_mask"][current_len : current_len + input_len] = [
-                1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
-            ]
-            result["labels"][current_len : current_len + input_len] = labels
-            current_len += input_len
-        return result
-
-    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
-            max_length=self.sequence_len,
-            padding=False,
-            return_tensors=None,
-        )
-        if (
-            result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.sequence_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(self.tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
-            result["input_ids"] = result["input_ids"][1:]
-            result["attention_mask"] = result["attention_mask"][1:]
 
-        result["labels"] = result["input_ids"].copy()
+            # pylint: disable=duplicate-code
+            result, current_len = parse_tokenized_to_result(
+                result,
+                current_len,
+                res,
+                labels,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
         return result
 
 
src/axolotl/prompt_tokenizers.py CHANGED
@@ -4,7 +4,7 @@ import abc
 import copy
 import functools
 import logging
-from typing import Tuple
+from typing import Dict, List, Tuple
 
 from transformers import PreTrainedTokenizer
 
@@ -58,6 +58,29 @@ class PromptTokenizingStrategy(abc.ABC):
             return id_or_ids
         return False
 
+    def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
+
 
 class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     """
@@ -106,29 +129,6 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             )
         )
 
-    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
-            max_length=self.sequence_len,
-            padding=False,
-            return_tensors=None,
-        )
-        if (
-            result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.sequence_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(self.tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
-            result["input_ids"] = result["input_ids"][1:]
-            result["attention_mask"] = result["attention_mask"][1:]
-
-        result["labels"] = result["input_ids"].copy()
-        return result
-
 
 class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     """
@@ -295,7 +295,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
             )
         )
 
-    def _tokenize(self, prompt, add_eos_token=True):
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(
             prompt,
             truncation=True,
@@ -339,12 +339,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         return prompt["conversations"]
 
     def tokenize_prompt(self, prompt):
-        result = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-        }
-        current_len = 0
+        result, current_len = tokenize_prompt_default()
         user_token = self._get_user_token()
         assistant_token = self._get_assistant_token()
         try:
@@ -382,14 +377,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     )
                    # everything from this is masked out from the labels
                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                input_ids = res["input_ids"]
-                input_len = len(input_ids)
-                result["input_ids"][current_len : current_len + input_len] = input_ids
-                result["attention_mask"][current_len : current_len + input_len] = [
-                    1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
-                ]
-                result["labels"][current_len : current_len + input_len] = labels
-                current_len += input_len
+
+                # pylint: disable=duplicate-code
+                result, current_len = parse_tokenized_to_result(
+                    result,
+                    current_len,
+                    res,
+                    labels,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                )
             return result
         except (KeyError, AssertionError, IndexError) as err:
             raise InvalidDataException(str(err)) from err
@@ -416,3 +412,40 @@
 
         result["labels"] = result["input_ids"].copy()
         return result
+
+
+def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
+    """
+    Returns the default values for the tokenize prompt function
+    """
+
+    result = {
+        "input_ids": [],
+        "attention_mask": [],
+        "labels": [],
+    }
+    current_len = 0
+    return result, current_len
+
+
+def parse_tokenized_to_result(
+    result: Dict[str, List[int]],
+    current_len: int,
+    res: Dict[str, List[int]],
+    labels: list[int],
+    pad_token_id: int | None = None,
+) -> Tuple[Dict[str, List[int]], int]:
+    """
+    Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
+    """
+
+    input_ids = res["input_ids"]
+    input_len = len(input_ids)
+    result["input_ids"][current_len : current_len + input_len] = input_ids
+    result["attention_mask"][current_len : current_len + input_len] = [
+        1 if x != pad_token_id else 0 for x in input_ids
+    ]
+    result["labels"][current_len : current_len + input_len] = labels
+    current_len += input_len
+
+    return result, current_len
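
A quick, self-contained check of the two new helpers (not part of the commit; the token ids, labels, and pad id below are made-up values):

    from axolotl.prompt_tokenizers import (
        parse_tokenized_to_result,
        tokenize_prompt_default,
    )

    # empty accumulators: {"input_ids": [], "attention_mask": [], "labels": []}, offset 0
    result, current_len = tokenize_prompt_default()

    # pretend output for one tokenized turn; 0 stands in for the pad token id
    res = {"input_ids": [1, 50, 51, 0]}
    labels = [-100, 50, 51, 0]

    result, current_len = parse_tokenized_to_result(
        result, current_len, res, labels, pad_token_id=0
    )

    assert result["input_ids"] == [1, 50, 51, 0]
    assert result["attention_mask"] == [1, 1, 1, 0]  # pad positions masked to 0
    assert result["labels"] == [-100, 50, 51, 0]
    assert current_len == 4  # the next turn would be spliced in after this offset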