Brian Fitzgerald, winglian committed on
Commit b7d8a7d
Parent: b0ee9ec

Add Glaive conversation format support (#1365)


* Add Glaive conversation format support

* fix black formatting errors

* Fix black and pylint formatting errors

* only set role_key_tool if provided in the dataset constructor

* Update src/axolotl/prompt_strategies/sharegpt.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* sharegpt test

* tokenizer test

* fix formatting

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/prompt_strategies/sharegpt.py CHANGED
@@ -1,10 +1,15 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
+
 from typing import Any, Dict, Optional
 
 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
 
 from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
 from axolotl.prompters import ShareGPTPrompterV2
+from axolotl.utils.tokenization import (
+    chatml_to_conversation,
+    merge_consecutive_messages,
+)
 
 
 def register_chatml_template(system_message=None):
@@ -19,6 +24,16 @@ def register_chatml_template(system_message=None):
             sep="<|im_end|>",
         )
     )
+    register_conv_template(
+        Conversation(
+            name="chatml_glaive",
+            system_template="<|im_start|>system\n{system_message}",
+            system_message=system_message,
+            roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
+            sep_style=SeparatorStyle.CHATML,
+            sep="<|im_end|>",
+        )
+    )
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -77,6 +92,20 @@ def load_guanaco(tokenizer, cfg):
     )
 
 
+def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"]
+        if ds_cfg and "conversation" in ds_cfg
+        else "chatml_glaive"
+    )
+    return GlaiveShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompterV2(conversation=conversation),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
 class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
     basic sharegpt strategy to grab conversations from the sample row
@@ -158,3 +187,15 @@ class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
             {"from": role_map[t["role"]], "value": t["content"]} for t in conversations
         ]
        return turns
+
+
+class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
+    """
+    sharegpt strategy that remaps glaive data to sharegpt format
+    """
+
+    def get_conversation_thread(self, prompt):
+        conversation = chatml_to_conversation(prompt)
+        conversation = merge_consecutive_messages(conversation)
+
+        return conversation
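The `chatml_glaive` template registered above is an ordinary fastchat `Conversation`, so one way to sanity-check what the prompter will render is to pull it back out of fastchat's registry. A minimal sketch (run in a fresh process so the name is not already registered; the exact rendered string depends on fastchat's CHATML separator handling):

from fastchat.conversation import get_conv_template

from axolotl.prompt_strategies.sharegpt import register_chatml_template

register_chatml_template("You are a helpful assistant.")

conv = get_conv_template("chatml_glaive")
conv.append_message(conv.roles[0], "Book me a flight?")       # <|im_start|>user
conv.append_message(conv.roles[1], "I can't book flights.")   # <|im_start|>assistant
print(conv.get_prompt())
# Expected shape of the output:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Book me a flight?<|im_end|>
# <|im_start|>assistant
# I can't book flights.<|im_end|>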
src/axolotl/prompt_tokenizers.py CHANGED
@@ -360,11 +360,19 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 LOG.warning(f"expected tuple, got {part}")
                 continue
 
-            user, assistant = conversation.roles
+            tool_role_label = None
+            if len(conversation.roles) == 3:
+                (
+                    user_role_label,
+                    assistant_role_label,
+                    tool_role_label,
+                ) = conversation.roles
+            else:
+                user_role_label, assistant_role_label = conversation.roles
             role, content = part
 
             # Uses "in" because role contains extra characters
-            if user in role:
+            if user_role_label in role:
                 role = (
                     role.replace(role_remap[0]["from"], role_remap[0]["to"])
                     if role_remap
@@ -384,7 +392,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 else:
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-            elif assistant in role:
+            elif assistant_role_label in role:
                 role = (
                     role.replace(role_remap[1]["from"], role_remap[1]["to"])
                     if role_remap
@@ -426,6 +434,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 else:
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+            elif tool_role_label and tool_role_label in role:
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
             else:
                 LOG.warning(f"unhandled role: {role}")
                 continue
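Net effect of these hunks: templates that define a third role get that role's turns masked out of the loss unconditionally, while two-role templates behave exactly as before. A simplified, self-contained sketch of the labeling rule (-100 matches axolotl's IGNORE_TOKEN_ID; the real method also handles role remapping and per-turn tokenization):

IGNORE_TOKEN_ID = -100


def labels_for_turn(role, input_ids, roles, train_on_inputs=False):
    # Unpack two- or three-role templates the way the hunk above does.
    tool_role_label = None
    if len(roles) == 3:
        user_role_label, assistant_role_label, tool_role_label = roles
    else:
        user_role_label, assistant_role_label = roles

    if user_role_label in role:
        # user turns train only when train_on_inputs is set
        return list(input_ids) if train_on_inputs else [IGNORE_TOKEN_ID] * len(input_ids)
    if assistant_role_label in role:
        return list(input_ids)  # assistant turns always contribute to the loss
    if tool_role_label and tool_role_label in role:
        return [IGNORE_TOKEN_ID] * len(input_ids)  # tool output is never trained on
    raise ValueError(f"unhandled role: {role}")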
src/axolotl/prompters.py CHANGED
@@ -267,6 +267,8 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
 
     role_key_human = "human"
     role_key_model = "gpt"
+    # Optional, only used for tool usage datasets.
+    role_key_tool = None
 
     def __init__(
         self,
@@ -274,6 +276,7 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        role_key_tool: Optional[str] = None,
     ):
         if conversation:
             if isinstance(conversation, Conversation):
@@ -286,6 +289,8 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
             self.role_key_human = role_key_human
         if role_key_model:
             self.role_key_model = role_key_model
+        if role_key_tool:
+            self.role_key_tool = role_key_tool
 
     def _build_result(self, source):
         if len(source) < 2:
@@ -303,6 +308,8 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
             source.pop(0)
 
         roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
+        if self.role_key_tool:
+            roles[self.role_key_tool] = conv.roles[2]
 
         try:
             # Apply prompt templates
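A template's third role only comes into play once the prompter knows which ShareGPT "from" key should map onto it. A hedged construction sketch (it assumes the V2 prompter forwards the new keyword to ShareGPTPrompter, which this diff does not show; "tool" is the key GLAIVE_TO_SHAREGPT_ROLE emits):

from axolotl.prompters import ShareGPTPrompterV2

prompter = ShareGPTPrompterV2(
    conversation="chatml_glaive",
    role_key_tool="tool",  # {"from": "tool", ...} turns map to conv.roles[2]
)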
src/axolotl/utils/tokenization.py CHANGED
@@ -2,6 +2,8 @@
 
 
 import logging
+import re
+from typing import Dict, List
 
 from termcolor import colored
 
@@ -36,3 +38,65 @@ def check_example_labels(example, tokenizer, text_only=False):
     LOG.info("\n\n\n")
 
     return " ".join(colored_tokens)
+
+
+GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
+GLAIVE_TO_SHAREGPT_ROLE = {
+    "SYSTEM": "system",
+    "USER": "human",
+    "ASSISTANT": "gpt",
+    "FUNCTION RESPONSE": "tool",
+}
+
+GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
+
+
+def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
+    """
+    Converts a ChatML formatted row to a list of messages in ShareGPT format.
+    Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
+    """
+
+    system_prompt = row.get("system")
+    if system_prompt:
+        system_prompt = system_prompt.removeprefix("SYSTEM: ")
+
+    chat_str = row["chat"]
+    chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
+
+    chat_msg_dicts = [
+        {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
+        for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
+    ]
+
+    if system_prompt:
+        chat_msg_dicts = [
+            {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
+        ] + chat_msg_dicts
+
+    return chat_msg_dicts
+
+
+def merge_consecutive_messages(messages):
+    """
+    Merge consecutive messages from the same sender into a single message.
+    This can be useful with datasets that contain multiple consecutive tool calls.
+    """
+
+    merged_messages = []
+    current_from = None
+    current_message = ""
+
+    for msg in messages:
+        if current_from == msg["from"]:
+            current_message += msg["value"]
+        else:
+            if current_from is not None:
+                merged_messages.append({"from": current_from, "value": current_message})
+            current_from = msg["from"]
+            current_message = msg["value"]
+
+    if current_from is not None:
+        merged_messages.append({"from": current_from, "value": current_message})
+
+    return merged_messages
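A worked example of the two helpers on a Glaive-style row (the strings are invented for illustration, but the output follows directly from the code above): GLAIVE_MSG_REGEX splits the flat chat string on the three role prefixes, each role/value pair is remapped to a ShareGPT "from" key, and the "SYSTEM: " prefix is stripped from the system prompt.

from axolotl.utils.tokenization import (
    chatml_to_conversation,
    merge_consecutive_messages,
)

row = {
    "system": "SYSTEM: You are a helpful assistant with access to functions.",
    "chat": (
        "USER: What is the weather in Boston? "
        'ASSISTANT: <functioncall> {"name": "get_weather"} <|endoftext|> '
        'FUNCTION RESPONSE: {"temp": "72F"} '
        "ASSISTANT: It is 72F in Boston. <|endoftext|>"
    ),
}

msgs = merge_consecutive_messages(chatml_to_conversation(row))
assert msgs == [
    {"from": "system", "value": "You are a helpful assistant with access to functions."},
    {"from": "human", "value": "What is the weather in Boston?"},
    {"from": "gpt", "value": '<functioncall> {"name": "get_weather"} <|endoftext|>'},
    {"from": "tool", "value": '{"temp": "72F"}'},
    {"from": "gpt", "value": "It is 72F in Boston. <|endoftext|>"},
]

Had the row contained two FUNCTION RESPONSE turns back to back, merge_consecutive_messages would concatenate them into a single "tool" message, which is why GlaiveShareGPTPromptTokenizingStrategy applies it after the conversion.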
tests/prompt_strategies/test_sharegpt.py CHANGED
@@ -1,6 +1,7 @@
 """
 Test module for sharegpt integration w chatml
 """
+
 import pytest
 from datasets import Dataset
 from tokenizers import AddedToken
@@ -8,6 +9,7 @@ from transformers import AutoTokenizer
 
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies.sharegpt import (
+    GlaiveShareGPTPromptTokenizingStrategy,
     SimpleShareGPTPromptTokenizingStrategy,
     register_chatml_template,
 )
@@ -48,6 +50,18 @@ def fixture_sharegpt_dataset():
     )
 
 
+@pytest.fixture(name="glaive_dataset")
+def fixture_sharegpt_glaive_dataset():
+    return Dataset.from_list(
+        [
+            {
+                "system": "SYSTEM: This is a system prompt",
+                "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
+            }
+        ]
+    )
+
+
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -156,3 +170,29 @@ class TestSharegpt:
             32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
         ]
         # fmt: on
+
+    def test_chatml_glaive(self, glaive_dataset, tokenizer):
+        strategy = GlaiveShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="chatml",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            tokenizer,
+            True,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, glaive_dataset, process_count=1
+        )
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            1,  # bos
+            32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13,  # system
+            32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13,  # human
+            32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13  # gpt
+        ]
+        # fmt: on
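Because train_on_inputs=True here, nothing in labels is masked, so a quick way to regenerate or check expectation lists like the one above is to decode the unmasked label ids (a debugging sketch, not part of the test):

print(tokenizer.decode([t for t in labels if t != -100]))
# should read back as the rendered chatml prompt, e.g.
# "<s> <|im_start|>system\nThis is a system prompt<|im_end|> ..."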
tests/test_prompt_tokenizers.py CHANGED
@@ -1,4 +1,5 @@
 """Module for testing prompt tokenizers."""
+
 import json
 import logging
 import unittest
@@ -18,6 +19,7 @@ from axolotl.prompt_strategies.llama2_chat import (
     Llama2ChatPrompter,
     LLama2ChatTokenizingStrategy,
 )
+from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     ShareGPTPromptTokenizingStrategy,
@@ -266,6 +268,23 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         idx = res["input_ids"].index(20255)  # assistant token
         assert res["labels"][idx] == -100
 
+    def test_glaive_tool_label_ignore(self):
+        conversation = {
+            "system": "SYSTEM: This is a system prompt",
+            "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
+        }
+        prompter = ShareGPTPrompterV2()
+        strat = GlaiveShareGPTPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        with self._caplog.at_level(logging.WARNING):
+            res = strat.tokenize_prompt(conversation)
+        idx = res["input_ids"].index(13566)  # assistant token
+        assert res["labels"][idx] == -100
+
     def test_no_sys_prompt(self):
         """
         tests the interface between the user and assistant parts