winglian commited on
Commit
59bb219
1 Parent(s): 9a02e7e

fix camel ai, add guanaco/oasst mapping for sharegpt

Browse files
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -49,7 +49,7 @@ class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
49
  return (
50
  prompt["message_1"],
51
  "",
52
- prompt["message_1"],
53
  )
54
 
55
 
 
49
  return (
50
  prompt["message_1"],
51
  "",
52
+ prompt["message_2"],
53
  )
54
 
55
 
src/axolotl/prompt_strategies/sharegpt_simple.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
2
+
3
+ from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
4
+ from axolotl.prompters import PromptStyle, ShareGPTPrompter
5
+
6
+
7
+ def load(tokenizer, cfg):
8
+ return SimpleShareGPTPromptTokenizingStrategy(
9
+ ShareGPTPrompter(PromptStyle.CHAT.value),
10
+ tokenizer,
11
+ cfg.train_on_inputs,
12
+ cfg.sequence_len,
13
+ )
14
+
15
+
16
+ def load_guanaco(tokenizer, cfg):
17
+ return GuanacoShareGPTPromptTokenizingStrategy(
18
+ ShareGPTPrompter(PromptStyle.CHAT.value),
19
+ tokenizer,
20
+ cfg.train_on_inputs,
21
+ cfg.sequence_len,
22
+ )
23
+
24
+
25
+ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
26
+ """
27
+ basic sharegpt strategy to grab conversations from the sample row
28
+ """
29
+
30
+ def get_conversation_thread(self, prompt):
31
+ return prompt["conversations"]
32
+
33
+
34
+ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
35
+ """
36
+ sharegpt strategy that remaps oasst data to sharegpt format
37
+ """
38
+
39
+ def get_conversation_thread(self, prompt):
40
+ conversations = prompt["conversations"]
41
+ # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
42
+ role_map = {"prompter": "human", "assistant": "gpt"}
43
+ turns = [
44
+ {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
45
+ ]
46
+ return turns