winglian commited on
Commit
5e37144
1 Parent(s): bdbca8f

fix prompters, especially the sharegpt prompter

Browse files
src/axolotl/prompt_tokenizers.py CHANGED
@@ -1,7 +1,10 @@
1
  import abc
 
2
 
3
  from transformers import PreTrainedTokenizer
4
 
 
 
5
  IGNORE_INDEX = -100
6
  LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
7
  LLAMA_DEFAULT_EOS_TOKEN = "</s>"
@@ -40,10 +43,10 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
40
  full_prompt = self._build_full_prompt(instruction, input, response)
41
  tokenized_full_prompt = self._tokenize(full_prompt)
42
  if not self.train_on_inputs:
43
- user_prompt = self.prompter.build_prompt(
44
  instruction,
45
  input,
46
- )
47
  tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
48
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
49
  # TODO this could be sped up using numpy array slicing
@@ -54,11 +57,11 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
54
  return tokenized_full_prompt
55
 
56
  def _build_full_prompt(self, instruction, input, response):
57
- return self.prompter.build_prompt(
58
  instruction,
59
  input,
60
  response,
61
- )
62
 
63
  def _tokenize(self, prompt, add_eos_token=True):
64
  result = self.tokenizer(
@@ -131,13 +134,13 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
131
 
132
  def tokenize_prompt(self, prompt):
133
  instruction = self.parse_instruction_fields(prompt)
134
- full_prompt = self._build_full_prompt(instruction)
135
  tokenized_full_prompt = self._tokenize(full_prompt)
136
 
137
  return tokenized_full_prompt
138
 
139
- def _build_full_prompt(self, instruction):
140
- return self.prompter.build_prompt(instruction)
141
 
142
 
143
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -157,10 +160,10 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
157
  )
158
  tokenized_full_prompt = self._tokenize(full_prompt)
159
  if not self.train_on_inputs:
160
- user_prompt = self.prompter.build_prompt(
161
  instruction,
162
  input,
163
- )
164
  tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
165
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
166
  # TODO this could be sped up using numpy array slicing
@@ -171,13 +174,13 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
171
  return tokenized_full_prompt
172
 
173
  def _build_full_prompt(self, instruction, input, output, reflection, corrected):
174
- return self.prompter.build_prompt(
175
  instruction,
176
  input,
177
  output,
178
  reflection,
179
  corrected,
180
- )
181
 
182
  def _tokenize(self, prompt, add_eos_token=True):
183
  result = self.tokenizer(
@@ -212,7 +215,64 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
212
 
213
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
214
  def tokenize_prompt(self, prompt):
 
 
 
 
 
 
215
  try:
216
- return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  except (KeyError, AssertionError, IndexError) as e:
218
  raise InvalidDataException(str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import abc
2
+ import copy
3
 
4
  from transformers import PreTrainedTokenizer
5
 
6
+ from axolotl.prompters import IGNORE_TOKEN_ID
7
+
8
  IGNORE_INDEX = -100
9
  LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
10
  LLAMA_DEFAULT_EOS_TOKEN = "</s>"
 
43
  full_prompt = self._build_full_prompt(instruction, input, response)
44
  tokenized_full_prompt = self._tokenize(full_prompt)
45
  if not self.train_on_inputs:
46
+ user_prompt = next(iter(self.prompter.build_prompt(
47
  instruction,
48
  input,
49
+ )))
50
  tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
51
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
52
  # TODO this could be sped up using numpy array slicing
 
57
  return tokenized_full_prompt
58
 
59
  def _build_full_prompt(self, instruction, input, response):
60
+ return next(iter(self.prompter.build_prompt(
61
  instruction,
62
  input,
63
  response,
64
+ )))
65
 
66
  def _tokenize(self, prompt, add_eos_token=True):
67
  result = self.tokenizer(
 
134
 
135
  def tokenize_prompt(self, prompt):
136
  instruction = self.parse_instruction_fields(prompt)
137
+ full_prompt = self._build_full_prompt(instruction, None, None)
138
  tokenized_full_prompt = self._tokenize(full_prompt)
139
 
140
  return tokenized_full_prompt
141
 
142
+ def _build_full_prompt(self, instruction, input, response):
143
+ return next(iter(self.prompter.build_prompt(instruction)))
144
 
145
 
146
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
 
160
  )
161
  tokenized_full_prompt = self._tokenize(full_prompt)
162
  if not self.train_on_inputs:
163
+ user_prompt = next(iter(self.prompter.build_prompt(
164
  instruction,
165
  input,
166
+ )))
167
  tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
168
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
169
  # TODO this could be sped up using numpy array slicing
 
174
  return tokenized_full_prompt
175
 
176
  def _build_full_prompt(self, instruction, input, output, reflection, corrected):
177
+ return next(iter(self.prompter.build_prompt(
178
  instruction,
179
  input,
180
  output,
181
  reflection,
182
  corrected,
183
+ )))
184
 
185
  def _tokenize(self, prompt, add_eos_token=True):
186
  result = self.tokenizer(
 
215
 
216
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
217
  def tokenize_prompt(self, prompt):
218
+ result = {
219
+ "input_ids": [],
220
+ "attention_mask": [],
221
+ "labels": [],
222
+ }
223
+ current_len = 0
224
  try:
225
+ for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"], self.tokenizer)):
226
+ if i == 0:
227
+ # this is only ever the first part, should include the bos token and the user query
228
+ res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False)
229
+ # everything from this is masked out from the labels
230
+ labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
231
+ elif i % 2 == 0:
232
+ # this is still the user query, we should
233
+ res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
234
+ # everything from this is masked out from the labels
235
+ labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
236
+ else:
237
+ # this should be the assistent response, should end with an eos token
238
+ res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
239
+ # not masked out from labels
240
+ labels = copy.deepcopy(res["input_ids"])
241
+ input_ids = res["input_ids"]
242
+ input_len = len(input_ids)
243
+ result["input_ids"][current_len : current_len + input_len] = input_ids
244
+ result["attention_mask"][current_len : current_len + input_len] = [
245
+ 1 if x != self.tokenizer.pad_token_id else 0
246
+ for x in input_ids
247
+ ]
248
+ result["labels"][current_len : current_len + input_len] = labels
249
+ current_len += input_len
250
+ return result
251
  except (KeyError, AssertionError, IndexError) as e:
252
  raise InvalidDataException(str(e))
253
+
254
+ def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
255
+ result = self.tokenizer(
256
+ prompt,
257
+ truncation=True,
258
+ max_length=self.sequence_len,
259
+ padding=False,
260
+ return_tensors=None,
261
+ )
262
+ if (
263
+ result["input_ids"][-1] != self.tokenizer.eos_token_id
264
+ and len(result["input_ids"]) < self.sequence_len
265
+ and add_eos_token
266
+ ):
267
+ result["input_ids"].append(self.tokenizer.eos_token_id)
268
+ result["attention_mask"].append(1)
269
+
270
+ if (
271
+ result["input_ids"][0] == self.tokenizer.bos_token_id
272
+ and strip_bos_token
273
+ ):
274
+ result["input_ids"] = result["input_ids"][1:]
275
+ result["attention_mask"] = result["attention_mask"][1:]
276
+
277
+ result["labels"] = result["input_ids"].copy()
278
+ return result
src/axolotl/prompters.py CHANGED
@@ -1,7 +1,7 @@
1
  import copy
2
  import dataclasses
3
  from enum import auto, Enum
4
- from typing import List, Tuple, Any, Union
5
 
6
  IGNORE_TOKEN_ID = -100
7
 
@@ -16,7 +16,7 @@ class AlpacaPrompter:
16
  instruction: str,
17
  input: Union[None, str] = None,
18
  output: Union[None, str] = None,
19
- ) -> str:
20
  # returns the full prompt from instruction and optional input
21
  # if a label (=response, =output) is provided, it's also appended.
22
  if input:
@@ -25,7 +25,7 @@ class AlpacaPrompter:
25
  res = self.prompt_no_input.format(instruction=instruction)
26
  if output:
27
  res = f"{res}{output}"
28
- return res
29
 
30
  def get_response(self, output: str) -> str:
31
  return output.split(self.response_split)[1].strip()
@@ -36,8 +36,8 @@ class JeopardyPrompter(AlpacaPrompter):
36
 
37
 
38
  class CompletionPrompter(AlpacaPrompter):
39
- def build_prompt(self, instruction: str) -> str:
40
- return instruction
41
 
42
  def get_response(self, output: str) -> str:
43
  return output.strip()
@@ -64,7 +64,7 @@ class ReflectAlpacaPrompter:
64
  output: Union[None, str] = None,
65
  reflection: Union[None, str] = None,
66
  corrected: Union[None, str] = None,
67
- ) -> str:
68
  # returns the full prompt from instruction and optional input
69
  # if a label (=response, =output) is provided, it's also appended.
70
  if input:
@@ -76,7 +76,7 @@ class ReflectAlpacaPrompter:
76
  output=output, reflection=reflection, corrected=corrected
77
  )
78
  res = f"{res}{label}"
79
- return res
80
 
81
  def get_response(self, output: str) -> str:
82
  return output.split(self.response_split)[1].strip()
@@ -103,15 +103,16 @@ class Conversation:
103
  sep: str = "###"
104
  sep2: str = None
105
 
106
- def get_prompt(self):
107
  seps = [self.sep, self.sep2]
108
- ret = self.system + seps[0]
109
  for i, (role, message) in enumerate(self.messages):
110
  if message:
111
- ret += role + ": " + message + seps[i % 2]
112
  else:
113
- ret += role + ":"
114
- return ret
 
115
 
116
  def copy(self):
117
  return Conversation(
@@ -136,12 +137,12 @@ conv_vicuna_v1_1 = Conversation(
136
  offset=0,
137
  sep_style=SeparatorStyle.TWO,
138
  sep=" ",
139
- sep2="</s>",
140
  )
141
 
142
 
143
  class ShareGPTPrompter:
144
- def build_prompt(self, source, tokenizer, sequence_len=2048):
145
  # ignore the system prompt if provided
146
  if source[0]["from"] == "system":
147
  source.pop(0)
@@ -171,61 +172,6 @@ class ShareGPTPrompter:
171
  role = roles[sentence["from"]]
172
  assert role == conv.roles[j % 2]
173
  conv.append_message(role, sentence["value"])
174
- # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up
175
- conversation = conv.get_prompt()
176
-
177
- # Tokenize conversations
178
- tokenized_result = tokenizer(
179
- conversation,
180
- truncation=True,
181
- max_length=sequence_len, # FIXME
182
- padding=False,
183
- return_tensors=None,
184
- )
185
- target = copy.deepcopy(tokenized_result["input_ids"])
186
-
187
- # Mask targets
188
- sep = conv.sep + conv.roles[1] + ": "
189
-
190
- rounds = conversation.split(conv.sep2)
191
- rounds = [r + conv.sep2 for r in rounds]
192
- cur_len = 1
193
- target[0] = IGNORE_TOKEN_ID # mask out the bos
194
- for i, rou in enumerate(rounds):
195
- if rou == "":
196
- break
197
-
198
- parts = rou.split(sep)
199
- if len(parts) != 2:
200
- break
201
- parts[0] += sep
202
- round_len = (
203
- len(tokenizer(rou)["input_ids"]) - 1
204
- ) # -1 ignores the bos_token generated for this
205
- # we have to strip the initial part, any dangling whitespace creates an additional ghost token
206
- instruction_len = (
207
- len(tokenizer(parts[0].strip())["input_ids"]) - 1
208
- ) # -1 ignores the bos_token generated for this
209
- target[cur_len : cur_len + instruction_len] = [
210
- IGNORE_TOKEN_ID
211
- ] * instruction_len
212
-
213
- cur_len += round_len
214
- if cur_len >= sequence_len:
215
- break
216
-
217
- # Fix: Truncate the target to have the same length as input_ids
218
- target = target[: len(tokenized_result["input_ids"])]
219
- # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
220
-
221
- attention_mask = [
222
- 1 if x != tokenizer.pad_token_id else 0
223
- for x in tokenized_result["input_ids"]
224
- ]
225
-
226
- # TODO truncate len to sequence_len
227
- return dict(
228
- input_ids=tokenized_result["input_ids"],
229
- labels=target,
230
- attention_mask=attention_mask,
231
- )
 
1
  import copy
2
  import dataclasses
3
  from enum import auto, Enum
4
+ from typing import List, Tuple, Any, Union, Generator
5
 
6
  IGNORE_TOKEN_ID = -100
7
 
 
16
  instruction: str,
17
  input: Union[None, str] = None,
18
  output: Union[None, str] = None,
19
+ ) -> Generator[str, None, None]:
20
  # returns the full prompt from instruction and optional input
21
  # if a label (=response, =output) is provided, it's also appended.
22
  if input:
 
25
  res = self.prompt_no_input.format(instruction=instruction)
26
  if output:
27
  res = f"{res}{output}"
28
+ yield res
29
 
30
  def get_response(self, output: str) -> str:
31
  return output.split(self.response_split)[1].strip()
 
36
 
37
 
38
  class CompletionPrompter(AlpacaPrompter):
39
+ def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
40
+ yield instruction
41
 
42
  def get_response(self, output: str) -> str:
43
  return output.strip()
 
64
  output: Union[None, str] = None,
65
  reflection: Union[None, str] = None,
66
  corrected: Union[None, str] = None,
67
+ ) -> Generator[str, None, None]:
68
  # returns the full prompt from instruction and optional input
69
  # if a label (=response, =output) is provided, it's also appended.
70
  if input:
 
76
  output=output, reflection=reflection, corrected=corrected
77
  )
78
  res = f"{res}{label}"
79
+ yield res
80
 
81
  def get_response(self, output: str) -> str:
82
  return output.split(self.response_split)[1].strip()
 
103
  sep: str = "###"
104
  sep2: str = None
105
 
106
+ def get_prompt(self) -> Generator[str, None, None]:
107
  seps = [self.sep, self.sep2]
108
+ preamble = self.system + seps[0]
109
  for i, (role, message) in enumerate(self.messages):
110
  if message:
111
+ yield preamble + role + ": " + message + seps[i % 2]
112
  else:
113
+ yield role + ":"
114
+ if i == 0:
115
+ preamble = ""
116
 
117
  def copy(self):
118
  return Conversation(
 
137
  offset=0,
138
  sep_style=SeparatorStyle.TWO,
139
  sep=" ",
140
+ sep2=" ",
141
  )
142
 
143
 
144
  class ShareGPTPrompter:
145
+ def build_prompt(self, source, tokenizer, sequence_len=2048) -> Generator[str, None, None]:
146
  # ignore the system prompt if provided
147
  if source[0]["from"] == "system":
148
  source.pop(0)
 
172
  role = roles[sentence["from"]]
173
  assert role == conv.roles[j % 2]
174
  conv.append_message(role, sentence["value"])
175
+
176
+ for part in conv.get_prompt():
177
+ yield part