winglian committed on
Commit
f150c02
2 Parent(s): 5c39c00 7b57ed7

Merge pull request #224 from OpenAccess-AI-Collective/system-prompt-data

Browse files
src/axolotl/datasets.py CHANGED
@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
126
  buffer_len = 0
127
 
128
  if example:
 
129
  # just going to drop data points that are too long
130
  if len(example["input_ids"]) <= self.seq_length:
131
  input_ids = example["input_ids"]
 
126
  buffer_len = 0
127
 
128
  if example:
129
+ # FIXME
130
  # just going to drop data points that are too long
131
  if len(example["input_ids"]) <= self.seq_length:
132
  input_ids = example["input_ids"]
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter):
45
  Null Prompter with no system prompts
46
  """
47
 
48
- prompt_input = "{instruction} {input} "
49
- prompt_no_input = "{instruction} "
 
 
50
 
51
  def __init__(self): # pylint: disable=super-init-not-called
52
  pass
 
45
  Null Prompter with no system prompts
46
  """
47
 
48
+ system_prompt = ""
49
+ system_no_input_prompt = ""
50
+ turn_format = "{instruction} {input} "
51
+ turn_no_input_format = "{instruction} "
52
 
53
  def __init__(self): # pylint: disable=super-init-not-called
54
  pass
src/axolotl/prompt_strategies/alpaca_w_system.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt strategies loader for alpaca instruction datasets with system prompts
3
+ """
4
+ from typing import Generator, Tuple, Union
5
+
6
+ from axolotl.prompt_tokenizers import PromptTokenizingStrategy
7
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
8
+
9
+
10
+ class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
11
+ """
12
+ Tokenizing strategy for instruction-based prompts.
13
+ """
14
+
15
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
16
+ return (
17
+ prompt["instruction"],
18
+ prompt["input"] if "input" in prompt else "",
19
+ prompt["output"],
20
+ prompt["system"],
21
+ )
22
+
23
+ def tokenize_prompt(self, prompt):
24
+ # pylint: disable=duplicate-code
25
+ (
26
+ instruction,
27
+ input, # pylint: disable=redefined-builtin
28
+ response,
29
+ system,
30
+ ) = self.parse_instruction_fields(prompt)
31
+ user_prompt = next(
32
+ iter(
33
+ self.prompter.build_prompt_w_system(
34
+ system,
35
+ instruction,
36
+ input,
37
+ )
38
+ )
39
+ )
40
+ tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
41
+ if not self.train_on_inputs:
42
+ user_prompt_len = len(tokenized_prompt["input_ids"])
43
+ # TODO this could be sped up using numpy array slicing
44
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
45
+ tokenized_res_prompt = self._tokenize(
46
+ response, strip_bos_token=True, add_eos_token=True
47
+ )
48
+ tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
49
+ tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
50
+ tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
51
+
52
+ return tokenized_prompt
53
+
54
+
55
+ class SystemDataPrompter(AlpacaPrompter):
56
+ """
57
+ Alpaca Style Prompter that uses system prompts from the dataset
58
+ """
59
+
60
+ def build_prompt_w_system(
61
+ self,
62
+ system: str,
63
+ instruction: str,
64
+ input: Union[None, str] = None, # pylint: disable=redefined-builtin
65
+ output: Union[None, str] = None,
66
+ ) -> Generator[str, None, None]:
67
+ # returns the full prompt from instruction and optional input
68
+ # if a label (=response, =output) is provided, it's also appended.
69
+ if input:
70
+ res = system + self.turn_format.format(instruction=instruction, input=input)
71
+ else:
72
+ res = system + self.turn_no_input_format.format(instruction=instruction)
73
+ if output:
74
+ res = f"{res}{output}"
75
+ yield res
76
+
77
+
78
+ def load(tokenizer, cfg):
79
+ return InstructionWSystemPromptTokenizingStrategy(
80
+ SystemDataPrompter(PromptStyle.CHAT.value),
81
+ tokenizer,
82
+ cfg.train_on_inputs,
83
+ cfg.sequence_len,
84
+ )
src/axolotl/prompt_tokenizers.py CHANGED
@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
87
  Tokenizing strategy for instruction-based prompts.
88
  """
89
 
90
- def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
 
 
91
  raise NotImplementedError
92
 
93
  def tokenize_prompt(self, prompt):
 
87
  Tokenizing strategy for instruction-based prompts.
88
  """
89
 
90
+ def parse_instruction_fields(
91
+ self, prompt
92
+ ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
93
  raise NotImplementedError
94
 
95
  def tokenize_prompt(self, prompt):
src/axolotl/prompters.py CHANGED
@@ -24,6 +24,8 @@ class AlpacaPrompter:
24
 
25
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
26
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
 
 
27
  prompt_style: Optional[PromptStyle] = None
28
 
29
  def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
32
 
33
  def match_prompt_style(self):
34
  if self.prompt_style == PromptStyle.INSTRUCT.value:
35
- self.prompt_input = (
36
- self.system_prompt
37
- + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
38
- )
39
- self.prompt_no_input = (
40
- self.system_no_input_prompt
41
- + "### Instruction:\n{instruction}\n\n### Response:\n"
42
  )
43
- self.response_split = "### Response:"
44
  if self.prompt_style == PromptStyle.CHAT.value:
45
- self.prompt_input = (
46
- self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
47
- )
48
- self.prompt_no_input = (
49
- self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
50
- )
51
- self.response_split = "ASSISTANT:"
52
 
53
  def build_prompt(
54
  self,
@@ -59,16 +51,17 @@ class AlpacaPrompter:
59
  # returns the full prompt from instruction and optional input
60
  # if a label (=response, =output) is provided, it's also appended.
61
  if input:
62
- res = self.prompt_input.format(instruction=instruction, input=input)
 
 
63
  else:
64
- res = self.prompt_no_input.format(instruction=instruction)
 
 
65
  if output:
66
  res = f"{res}{output}"
67
  yield res
68
 
69
- def get_response(self, output: str) -> str:
70
- return output.split(self.response_split)[1].strip()
71
-
72
 
73
  class UnpromptedPrompter(AlpacaPrompter):
74
  """
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
93
  """
94
 
95
  system_prompt = (
96
- "Choose the answer that best answers the question. Explain your reasoning."
 
 
 
97
  )
98
 
99
 
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
102
  Prompter for multiple choice concise
103
  """
104
 
105
- prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
 
 
 
 
 
106
 
107
 
108
  class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
110
  Prompter for summarize TLDR
111
  """
112
 
113
- prompt_no_input = (
114
- "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
115
- )
 
 
 
116
 
117
 
118
  class CompletionPrompter:
@@ -128,9 +132,6 @@ class CompletionPrompter:
128
  ) -> Generator[str, None, None]:
129
  yield instruction
130
 
131
- def get_response(self, output: str) -> str:
132
- return output.strip()
133
-
134
 
135
  class GPTeacherPrompter(AlpacaPrompter):
136
  """
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
210
  res = f"{res}{label}"
211
  yield res
212
 
213
- def get_response(self, output: str) -> str:
214
- return output.split(self.response_split)[1].strip()
215
-
216
 
217
  class SeparatorStyle(Enum):
218
  """Different separator style."""
@@ -289,12 +287,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
289
  sep2=" ",
290
  )
291
 
292
- # def match_prompt_style(self):
293
- # if self.prompt_style == PromptStyle.chat.value:
294
- # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
295
- # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
296
- # self.response_split = "ASSISTANT:"
297
-
298
  def build_prompt(self, source) -> Generator[str, None, None]:
299
  # ignore the system prompt if provided
300
  if source[0]["from"] == "system":
 
24
 
25
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
26
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
27
+ turn_format: str
28
+ turn_no_input_format: str
29
  prompt_style: Optional[PromptStyle] = None
30
 
31
  def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
 
34
 
35
  def match_prompt_style(self):
36
  if self.prompt_style == PromptStyle.INSTRUCT.value:
37
+ self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
38
+ self.turn_no_input_format = (
39
+ "### Instruction:\n{instruction}\n\n### Response:\n"
 
 
 
 
40
  )
 
41
  if self.prompt_style == PromptStyle.CHAT.value:
42
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
43
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
 
 
 
 
 
44
 
45
  def build_prompt(
46
  self,
 
51
  # returns the full prompt from instruction and optional input
52
  # if a label (=response, =output) is provided, it's also appended.
53
  if input:
54
+ res = self.system_prompt + self.turn_format.format(
55
+ instruction=instruction, input=input
56
+ )
57
  else:
58
+ res = self.system_no_input_prompt + self.turn_no_input_format.format(
59
+ instruction=instruction
60
+ )
61
  if output:
62
  res = f"{res}{output}"
63
  yield res
64
 
 
 
 
65
 
66
  class UnpromptedPrompter(AlpacaPrompter):
67
  """
 
86
  """
87
 
88
  system_prompt = (
89
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
90
+ )
91
+ system_no_input_prompt = (
92
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
93
  )
94
 
95
 
 
98
  Prompter for multiple choice concise
99
  """
100
 
101
+ system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
102
+ system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
103
+
104
+ def match_prompt_style(self):
105
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
106
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
107
 
108
 
109
  class SummarizeTLDRPrompter(AlpacaPrompter):
 
111
  Prompter for summarize TLDR
112
  """
113
 
114
+ system_prompt = ""
115
+ system_no_input_prompt = ""
116
+
117
+ def match_prompt_style(self):
118
+ self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
119
+ self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
120
 
121
 
122
  class CompletionPrompter:
 
132
  ) -> Generator[str, None, None]:
133
  yield instruction
134
 
 
 
 
135
 
136
  class GPTeacherPrompter(AlpacaPrompter):
137
  """
 
211
  res = f"{res}{label}"
212
  yield res
213
 
 
 
 
214
 
215
  class SeparatorStyle(Enum):
216
  """Different separator style."""
 
287
  sep2=" ",
288
  )
289
 
 
 
 
 
 
 
290
  def build_prompt(self, source) -> Generator[str, None, None]:
291
  # ignore the system prompt if provided
292
  if source[0]["from"] == "system":
src/axolotl/utils/tokenization.py CHANGED
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
34
 
35
  logging.info(" ".join(colored_tokens))
36
  logging.info("\n\n\n")
 
 
 
34
 
35
  logging.info(" ".join(colored_tokens))
36
  logging.info("\n\n\n")
37
+
38
+ return " ".join(colored_tokens)
tests/test_prompt_tokenizers.py CHANGED
@@ -7,11 +7,15 @@ from pathlib import Path
7
  from transformers import AutoTokenizer
8
 
9
  from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
 
 
 
 
10
  from axolotl.prompt_tokenizers import (
11
  AlpacaPromptTokenizingStrategy,
12
  ShareGPTPromptTokenizingStrategy,
13
  )
14
- from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
15
 
16
  logging.basicConfig(level="INFO")
17
 
@@ -96,5 +100,39 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
96
  assert example["labels"][world_idx - 1] == -100
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  if __name__ == "__main__":
100
  unittest.main()
 
7
  from transformers import AutoTokenizer
8
 
9
  from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
10
+ from axolotl.prompt_strategies.alpaca_w_system import (
11
+ InstructionWSystemPromptTokenizingStrategy,
12
+ SystemDataPrompter,
13
+ )
14
  from axolotl.prompt_tokenizers import (
15
  AlpacaPromptTokenizingStrategy,
16
  ShareGPTPromptTokenizingStrategy,
17
  )
18
+ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
19
 
20
  logging.basicConfig(level="INFO")
21
 
 
100
  assert example["labels"][world_idx - 1] == -100
101
 
102
 
103
+ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
104
+ """
105
+ Test class for prompt tokenization strategies with sys prompt from the dataset
106
+ """
107
+
108
+ def setUp(self) -> None:
109
+ # pylint: disable=duplicate-code
110
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
111
+ self.tokenizer.add_special_tokens(
112
+ {
113
+ "bos_token": "<s>",
114
+ "eos_token": "</s>",
115
+ "unk_token": "<unk>",
116
+ }
117
+ )
118
+
119
+ def test_system_alpaca(self):
120
+ prompter = SystemDataPrompter(PromptStyle.CHAT.value)
121
+ strat = InstructionWSystemPromptTokenizingStrategy(
122
+ prompter,
123
+ self.tokenizer,
124
+ False,
125
+ 2048,
126
+ )
127
+ sample = {
128
+ "system": "use cot",
129
+ "instruction": "hello!",
130
+ "output": "Hi! How can I help?",
131
+ }
132
+ example = strat.tokenize_prompt(sample)
133
+ assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
134
+ assert example["input_ids"][3] == 11889 # USER
135
+
136
+
137
  if __name__ == "__main__":
138
  unittest.main()
tests/test_prompters.py CHANGED
@@ -2,7 +2,13 @@
2
 
3
  import unittest
4
 
5
- from axolotl.prompters import AlpacaPrompter, PromptStyle
 
 
 
 
 
 
6
 
7
 
8
  class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
55
  assert "### Response:" not in res
56
  assert "USER:" in res
57
  assert "ASSISTANT:" in res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import unittest
4
 
5
+ from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
6
+ from axolotl.prompters import (
7
+ AlpacaPrompter,
8
+ MultipleChoiceExplainPrompter,
9
+ PromptStyle,
10
+ UnpromptedPrompter,
11
+ )
12
 
13
 
14
  class AlpacaPrompterTest(unittest.TestCase):
 
61
  assert "### Response:" not in res
62
  assert "USER:" in res
63
  assert "ASSISTANT:" in res
64
+
65
+ def test_system_prompt(self):
66
+ prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
67
+ res = next(
68
+ prompter.build_prompt_w_system(
69
+ "use cot", "tell me a joke about the following", "alpacas"
70
+ )
71
+ )
72
+ assert "use cot" in res
73
+ assert res.startswith("use cot")
74
+ assert "### Instruction:" not in res
75
+ assert "### Input:" not in res
76
+ assert "alpacas" in res
77
+ assert "### Response:" not in res
78
+ assert "USER:" in res
79
+ assert "ASSISTANT:" in res
80
+
81
+
82
+ class UnpromptedPrompterTest(unittest.TestCase):
83
+ """
84
+ Test class for UnpromptedPrompter with no system prompts
85
+ """
86
+
87
+ def test_prompt_style_w_none(self):
88
+ prompter = UnpromptedPrompter(prompt_style=None)
89
+ res = next(prompter.build_prompt("tell me a joke"))
90
+ assert "### Instruction:" in res
91
+ assert "tell me a joke" in res
92
+ assert res.startswith("###")
93
+
94
+ def test_prompt_style_w_instruct(self):
95
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
96
+ res = next(
97
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
98
+ )
99
+ assert "### Instruction:" in res
100
+ assert "tell me a joke" in res
101
+ assert res.startswith("###")
102
+
103
+ def test_prompt_style_w_chat(self):
104
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
105
+ res = next(
106
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
107
+ )
108
+ assert "USER:" in res
109
+ assert "tell me a joke" in res
110
+ assert res.startswith("USER:")
111
+
112
+
113
+ class MultipleChoiceExplainPrompterTest(unittest.TestCase):
114
+ """
115
+ Test class for MultipleChoiceExplainPrompter
116
+ """
117
+
118
+ def test_prompt_style_w_chat(self):
119
+ prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
120
+ res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
121
+ assert "USER:" in res
122
+ assert "choose one" in res
123
+ assert "Choose the answer that best answers the question." in res
124
+ assert "- A\n- B\n- C" in res