winglian committed commit 74ef5cc • 2 parent(s): 94f310c c7dee56

Merge pull request #192 from OpenAccess-AI-Collective/sharegpt-custom-prompt
README.md CHANGED
@@ -219,6 +219,14 @@ Have dataset(s) in one of the following format (JSONL recommended):
  ```json
  {"conversations": [{"role": "...", "value": "..."}]}
  ```
+ - `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
+ ```json
+ {"conversations": [{"role": "...", "value": "..."}]}
+ ```
+ - `sharegpt_jokes`: creates a chat where the bot is asked to tell a joke, then explain why the joke is funny
+ ```json
+ {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
+ ```

  </details>

@@ -530,7 +538,7 @@ Try set `fp16: true`

  Try to turn off xformers.

- ## Need help? 🙋‍♂️
+ ## Need help? 🙋♂️

  Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
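
As a quick illustration of the two new formats, here is a minimal sketch that serializes one row of each; the field values are invented placeholders, not taken from this commit, and the shapes mirror the JSON snippets shown in the README hunk above:

```python
import json

# A sharegpt_simple.load_role row: each turn is keyed by "role"/"value"
# instead of the usual "from"/"value".
role_row = {
    "conversations": [
        {"role": "human", "value": "What does axolotl do?"},
        {"role": "gpt", "value": "It fine-tunes language models."},
    ]
}

# A sharegpt_jokes row, mirroring the JSON shape shown in the README above.
jokes_row = {
    "conversations": [
        {
            "title": "Classic:",
            "text": "Why did the chicken cross the road? To get to the other side.",
            "explanation": "It subverts the expectation of a clever punchline.",
        }
    ]
}

# JSONL is one JSON object per line.
for row in (role_row, jokes_row):
    print(json.dumps(row))
```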
src/axolotl/datasets.py CHANGED
@@ -33,12 +33,16 @@ class TokenizedPromptDataset(IterableDataset):

      def __iter__(self):
          iterator = iter(self.dataset)
+         count = 0
          # Loop through the entire dataset
          for example in iterator:
              try:
                  yield self.prompt_tokenizer.tokenize_prompt(example)
+                 count += 1
              except InvalidDataException:
                  pass
+         if count == 0:
+             raise RuntimeError("Expected at least one datapoint in dataset.")


  # TODO this isn't the best since it can't interleave datasets
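
The `count` guard turns an all-invalid (or empty) dataset into an immediate failure instead of a silently empty iterator. A standalone sketch of the same pattern, using hypothetical names rather than the repo's classes:

```python
class InvalidDataException(Exception):
    """Local stand-in for axolotl's exception for rows that fail tokenization."""


def iter_tokenized(rows, tokenize):
    """Yield tokenized rows, skipping bad ones, but refuse an empty result."""
    count = 0
    for row in rows:
        try:
            yield tokenize(row)
            count += 1
        except InvalidDataException:
            pass  # skip invalid rows, as the dataset wrapper above does
    if count == 0:
        # Nothing survived tokenization: fail loudly before training starts
        # on zero samples.
        raise RuntimeError("Expected at least one datapoint in dataset.")


def always_fail(_row):
    raise InvalidDataException()


# An all-invalid input raises once the generator is exhausted.
try:
    list(iter_tokenized([{"a": 1}], always_fail))
except RuntimeError as err:
    print(err)  # Expected at least one datapoint in dataset.
```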
src/axolotl/prompt_strategies/sharegpt_jokes.py ADDED
@@ -0,0 +1,28 @@
+ """Module for Jokes prompts using sharegpt style"""
+ from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+ from axolotl.prompters import PromptStyle, ShareGPTPrompter
+
+
+ def load(tokenizer, cfg):
+     return SimpleJokesShareGPTPromptTokenizingStrategy(
+         ShareGPTPrompter(PromptStyle.CHAT.value),
+         tokenizer,
+         cfg.train_on_inputs,
+         cfg.sequence_len,
+     )
+
+
+ class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+     """
+     Tokenization strategy for asking the bot to tell a joke and then explain why it's funny
+     """
+
+     # title, text, explanation
+     def get_conversation_thread(self, prompt):
+         title = "" if not prompt["title"] else prompt["title"] + " "
+         return [
+             {"from": "human", "value": "Tell me a joke."},
+             {"from": "gpt", "value": title + prompt["text"]},
+             {"from": "human", "value": "Why is that joke funny?"},
+             {"from": "gpt", "value": prompt["explanation"]},
+         ]
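
Given one data row, this strategy synthesizes a fixed four-turn thread. A small worked example of the expected output (placeholder joke; note the strategy reads `title`/`text`/`explanation` from the top level of the row, whereas the README snippet nests them under `conversations`):

```python
row = {
    "title": "Classic:",
    "text": "Why did the chicken cross the road? To get to the other side.",
    "explanation": "It subverts the expectation of a clever punchline.",
}

# What get_conversation_thread(row) returns for this row: the optional title
# is prepended to the joke text, and the explanation answers the follow-up.
expected = [
    {"from": "human", "value": "Tell me a joke."},
    {"from": "gpt", "value": "Classic: " + row["text"]},
    {"from": "human", "value": "Why is that joke funny?"},
    {"from": "gpt", "value": row["explanation"]},
]
print(expected)
```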
src/axolotl/prompt_strategies/sharegpt_simple.py CHANGED
@@ -13,6 +13,15 @@ def load(tokenizer, cfg):
      )


+ def load_role(tokenizer, cfg):
+     return SimpleRoleShareGPTPromptTokenizingStrategy(
+         ShareGPTPrompter(PromptStyle.CHAT.value),
+         tokenizer,
+         cfg.train_on_inputs,
+         cfg.sequence_len,
+     )
+
+
  def load_guanaco(tokenizer, cfg):
      return GuanacoShareGPTPromptTokenizingStrategy(
          ShareGPTPrompter(PromptStyle.CHAT.value),
@@ -31,6 +40,18 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
          return prompt["conversations"]


+ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+     """
+     basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
+     """
+
+     def get_conversation_thread(self, prompt):
+         conversations = prompt["conversations"]
+         # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
+         turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
+         return turns
+
+
  class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
      """
      sharegpt strategy that remaps oasst data to sharegpt format
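
The role strategy is a one-line remap over each turn. A standalone sketch of the transform on a sample row (the helper name is hypothetical):

```python
def remap_role_turns(row: dict) -> list:
    # Same transform as SimpleRoleShareGPTPromptTokenizingStrategy:
    # keep each turn's value, moving the speaker label from "role" to "from".
    return [{"from": t["role"], "value": t["value"]} for t in row["conversations"]]


row = {
    "conversations": [
        {"role": "human", "value": "hi"},
        {"role": "gpt", "value": "hello"},
    ]
}
print(remap_role_turns(row))
# [{'from': 'human', 'value': 'hi'}, {'from': 'gpt', 'value': 'hello'}]
```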
src/axolotl/prompters.py CHANGED
@@ -261,28 +261,33 @@ class Conversation:
          self.messages.append([role, message])


- conv_vicuna_v1_1 = Conversation(
-     system="A chat between a curious user and an artificial intelligence assistant. "
-     "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-     roles=["USER", "ASSISTANT"],
-     messages=[],
-     offset=0,
-     sep_style=SeparatorStyle.TWO,
-     sep=" ",
-     sep2=" ",
- )
-
-
  class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
      """
      A prompter that generates prompts for the ShareGPT
      """

-     def __init__(self, prompt_style=None):
+     def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
          if prompt_style != PromptStyle.CHAT.value:
              raise ValueError(
                  f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
              )
+         system: str = (
+             system_prompt
+             if system_prompt
+             else (
+                 "A chat between a curious user and an artificial intelligence assistant. "
+                 "The assistant gives helpful, detailed, and polite answers to the user's questions."
+             )
+         )
+         self._conversation = Conversation(
+             system=system,
+             roles=["USER", "ASSISTANT"],
+             messages=[],
+             offset=0,
+             sep_style=SeparatorStyle.TWO,
+             sep=" ",
+             sep2=" ",
+         )

      # def match_prompt_style(self):
      #     if self.prompt_style == PromptStyle.chat.value:
@@ -300,7 +305,7 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
              # also happens on the data splitting leaving empty conversations
              raise IndexError

-         conv = conv_vicuna_v1_1.copy()
+         conv = self._conversation.copy()
          roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

          try:
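
Building the `Conversation` inside `__init__`, instead of copying the module-level `conv_vicuna_v1_1`, is what makes the system prompt overridable per prompter. A minimal standalone sketch of just the fallback logic, with the repo's default text inlined:

```python
from typing import Optional

DEFAULT_SYSTEM = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)


def resolve_system_prompt(system_prompt: Optional[str] = None) -> str:
    # Mirrors the conditional in ShareGPTPrompter.__init__: an explicit
    # system_prompt wins; otherwise fall back to the old vicuna default.
    return system_prompt if system_prompt else DEFAULT_SYSTEM


print(resolve_system_prompt())  # the vicuna-style default
print(resolve_system_prompt("You are a terse assistant."))
```

Since each prompter now owns its `Conversation` template, prompters built with different system prompts should be able to coexist in the same run.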
src/axolotl/utils/data.py CHANGED
@@ -239,8 +239,15 @@ def load_tokenized_prepared_datasets(
              ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
              datasets.append(ds_wrapper)
          else:
-             logging.error(f"unhandled prompt tokenization strategy: {d.type}")
-             raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
+             suffix = ""
+             if ":load_" in d.type:
+                 suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
+             logging.error(
+                 f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
+             )
+             raise ValueError(
+                 f"unhandled prompt tokenization strategy: {d.type} {suffix}"
+             )
      logging.info("tokenizing, merging, and shuffling master dataset")

      samples: List[int] = []
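
The new error path guesses at one common misconfiguration: writing the strategy as `module:load_fn` rather than `module.load_fn`. A standalone check of just the string logic (the function name here is hypothetical):

```python
def strategy_error(ds_type: str) -> str:
    # Mirrors the added suffix logic: suggest the dotted loader form when
    # the type string used a colon separator instead.
    suffix = ""
    if ":load_" in ds_type:
        suffix = f" Did you mean {ds_type.replace(':load_', '.load_')}?"
    return f"unhandled prompt tokenization strategy: {ds_type}.{suffix}"


print(strategy_error("sharegpt_simple:load_role"))
# unhandled prompt tokenization strategy: sharegpt_simple:load_role. Did you mean sharegpt_simple.load_role?
```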