winglian committed on
Commit
9a02e7e
2 Parent(s): 328c3bc 5b33e29

Merge pull request #155 from OpenAccess-AI-Collective/misc-fixes

Browse files

new prompters, misc fixes for output dir missing using fsdp, and changing max seq len

README.md CHANGED
@@ -165,10 +165,30 @@ Have dataset(s) in one of the following format (JSONL recommended):
165
  ```json
166
  {"article": "...", "summary": "..."}
167
  ```
 
 
 
 
168
  - `alpaca_chat.load_qa`: question and answer for alpaca chat
169
  ```json
170
  {"question": "...", "answer": "..."}
171
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  - `creative_acr.load_answer`: instruction and revision
173
  ```json
174
  {"instruction": "...", "revision": "..."}
 
165
  ```json
166
  {"article": "...", "summary": "..."}
167
  ```
168
+ - `alpaca_chat`: basic instruct for alpaca chat
169
+ ```json
170
+ {"instruction": "...", "input": "...", "response": "..."}
171
+ ```
172
  - `alpaca_chat.load_qa`: question and answer for alpaca chat
173
  ```json
174
  {"question": "...", "answer": "..."}
175
  ```
176
+ - `alpaca_chat.load_concise`: basic instruct for alpaca chat, prompting for concise answers
177
+ ```json
178
+ {"instruction": "...", "input": "...", "response": "..."}
179
+ ```
180
+ - `alpaca_chat.load_camel_ai`: question and answer for alpaca chat, for CamelAI datasets
181
+ ```json
182
+ {"message_1": "...", "message_2": "..."}
183
+ ```
184
+ - `context_qa`: in context question answering from an article
185
+ ```json
186
+ {"article": "...", "question": "...", "answer": "..."}
187
+ ```
188
+ - `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
189
+ ```json
190
+ {"article": "...", "unanswerable_question": "..."}
191
+ ```
192
  - `creative_acr.load_answer`: instruction and revision
193
  ```json
194
  {"instruction": "...", "revision": "..."}
scripts/finetune.py CHANGED
@@ -279,6 +279,9 @@ def train(
279
  logging.info(
280
  f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
281
  )
 
 
 
282
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
283
 
284
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
279
  logging.info(
280
  f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
281
  )
282
+
283
+ if not Path(cfg.output_dir).is_dir():
284
+ os.makedirs(cfg.output_dir, exist_ok=True)
285
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
286
 
287
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -18,6 +18,15 @@ def load(tokenizer, cfg):
18
  )
19
 
20
 
 
 
 
 
 
 
 
 
 
21
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
22
  """
23
  Tokenizing strategy for AlpacaQA
@@ -31,6 +40,28 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
31
  )
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def load_qa(tokenizer, cfg):
35
  return AlpacaQAPromptTokenizingStrategy(
36
  AlpacaPrompter(PromptStyle.CHAT.value),
@@ -38,3 +69,12 @@ def load_qa(tokenizer, cfg):
38
  cfg.train_on_inputs,
39
  cfg.sequence_len,
40
  )
 
 
 
 
 
 
 
 
 
 
18
  )
19
 
20
 
21
class AlpacaConcisePrompter(AlpacaPrompter):
    """
    AlpacaPrompter variant whose system prompts additionally ask for concise responses
    """

    # Wording for prompts that pair the instruction with an input.
    system_prompt = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that concisely and appropriately completes the request.\n\n"
    )
    # Wording for prompts that carry an instruction only.
    system_no_input_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately and concisely completes the request.\n\n"
    )
28
+
29
+
30
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
31
  """
32
  Tokenizing strategy for AlpacaQA
 
40
  )
41
 
42
 
43
class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for CamelAI datasets
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Map a CamelAI row onto an (instruction, input, response) tuple.

        ``message_1`` becomes the instruction and ``message_2`` the response;
        the input slot is always the empty string.
        """
        return (
            prompt["message_1"],
            "",
            # The reply lives in message_2; returning message_1 here would make
            # the training target simply echo the instruction.
            prompt["message_2"],
        )
54
+
55
+
56
def load_concise(tokenizer, cfg):
    """Build the Alpaca tokenizing strategy driven by the concise-answer chat prompter."""
    concise_prompter = AlpacaConcisePrompter(PromptStyle.CHAT.value)
    return AlpacaPromptTokenizingStrategy(
        concise_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
63
+
64
+
65
  def load_qa(tokenizer, cfg):
66
  return AlpacaQAPromptTokenizingStrategy(
67
  AlpacaPrompter(PromptStyle.CHAT.value),
 
69
  cfg.train_on_inputs,
70
  cfg.sequence_len,
71
  )
72
+
73
+
74
def load_camel_ai(tokenizer, cfg):
    """Build the CamelAI tokenizing strategy on top of the chat-style Alpaca prompter."""
    chat_prompter = AlpacaPrompter(PromptStyle.CHAT.value)
    return CamelAIPromptTokenizingStrategy(
        chat_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
src/axolotl/prompt_strategies/context_qa.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing the classes for Context QA Prompt Tokenization Strategies"""
2
+ from typing import Tuple
3
+
4
+ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
5
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
6
+
7
+
8
+ # dataset fields read by these loaders: article + unanswerable_question (load_404); article + question + answer (load)
9
def load_404(tokenizer, cfg):
    """Build the strategy that pairs an article with a question the article cannot answer."""
    context_prompter = AlpacaContextPrompter(PromptStyle.CHAT.value)
    return AlpacaMissingInfoContextPromptTokenizingStrategy(
        context_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
16
+
17
+
18
def load(tokenizer, cfg):
    """Build the in-context question answering strategy with the chat-style context prompter."""
    context_prompter = AlpacaContextPrompter(PromptStyle.CHAT.value)
    return AlpacaContextPromptTokenizingStrategy(
        context_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
25
+
26
+
27
class AlpacaContextPrompter(AlpacaPrompter):
    """
    Alpaca prompter whose system prompt asks for a concise answer based on supplied context
    """

    # Both prompt variants share exactly the same wording.
    system_prompt = "Use the following contextual information to concisely answer the question.\n"
    system_no_input_prompt = system_prompt
38
+
39
+
40
class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenization strategy that puts the article ahead of its question and pairs them with the answer
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Article and question are fused into the instruction, separated by a "===" line.
        instruction = prompt["article"] + "\n===\n" + prompt["question"]
        response = prompt["answer"]
        # The middle (input) slot is intentionally left empty.
        return instruction, "", response
51
+
52
+
53
class AlpacaMissingInfoContextPromptTokenizingStrategy(
    InstructionPromptTokenizingStrategy
):
    """
    Tokenization strategy for an article paired with a question the article does not answer;
    the response is always the same fixed "unable to answer" message
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Article and unanswerable question are fused into the instruction, separated by "===".
        instruction = prompt["article"] + "\n===\n" + prompt["unanswerable_question"]
        # Fixed reply used for every row, regardless of its content.
        fallback_response = (
            "The context provided does not contain any information about your inquiry. "
            "Therefore, I'm unable to answer your question based on the given context."
        )
        return instruction, "", fallback_response
src/axolotl/utils/models.py CHANGED
@@ -234,6 +234,10 @@ def load_model(
234
  base_model,
235
  trust_remote_code=cfg.trust_remote_code or False,
236
  )
 
 
 
 
237
  model = AutoModelForCausalLM.from_pretrained(
238
  base_model,
239
  config=config,
 
234
  base_model,
235
  trust_remote_code=cfg.trust_remote_code or False,
236
  )
237
+ # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
238
+ # when training starts
239
+ if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
240
+ config.max_seq_len = cfg.sequence_len
241
  model = AutoModelForCausalLM.from_pretrained(
242
  base_model,
243
  config=config,