# qwerrwe / src/axolotl/prompters.py
# Commit 6045345 (unverified) by winglian:
#   "WIP large refactor to make finetune script a little more manageable (#3)"
# File views: raw · history · blame — 7.82 kB
import copy
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any, Union
# Label id written into ``labels`` for positions that must not be trained on
# (ShareGPTPrompter masks every non-assistant token with it). -100 is the
# default ``ignore_index`` of PyTorch's CrossEntropyLoss — presumably the loss
# these labels feed; confirm against the training loop.
IGNORE_TOKEN_ID: int = -100
class AlpacaPrompter:
    """Build Alpaca-style instruction prompts and extract the model's reply.

    ``prompt_input`` is used when a non-empty input/context string is given and
    ``prompt_no_input`` otherwise. Both templates end with the
    ``response_split`` header, and the response text follows that header.
    """

    prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
    response_split = "### Response:"

    def build_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,  # name shadows the builtin; kept for API compatibility
        output: Union[None, str] = None,
    ) -> str:
        """Return the full prompt for ``instruction``.

        Args:
            instruction: the task description.
            input: optional context. None or "" selects the no-input template.
            output: optional label (the expected response). When non-empty it
                is appended directly after the ``### Response:`` header.

        Returns:
            The formatted prompt, followed by ``output`` if one was given.
        """
        if input:
            res = self.prompt_input.format(instruction=instruction, input=input)
        else:
            res = self.prompt_no_input.format(instruction=instruction)
        if output:
            res = f"{res}{output}"
        return res

    def get_response(self, output: str) -> str:
        """Return the stripped text after the first ``response_split`` marker.

        Raises:
            IndexError: if ``output`` does not contain ``response_split``.
        """
        # maxsplit=1: text after a repeated marker (e.g. the model emitting the
        # header again) is still part of the response and must not be dropped.
        return output.split(self.response_split, 1)[1].strip()
class GPTeacherPrompter(AlpacaPrompter):
    """Prompter named for the GPTeacher dataset.

    It defines nothing of its own: the templates, ``build_prompt`` and
    ``get_response`` are all inherited unchanged from AlpacaPrompter.
    """
class NomicGPT4AllPrompter(AlpacaPrompter):
    """Prompter named for Nomic's GPT4All data.

    It defines nothing of its own: the templates, ``build_prompt`` and
    ``get_response`` are all inherited unchanged from AlpacaPrompter.
    """
class ReflectAlpacaPrompter:
    """Alpaca-style prompter whose label has three parts.

    The label is a draft response, an "Agent Reflection" on it and a
    "Final Response", laid out by ``agent_label`` after the ``### Response:``
    header.
    """

    prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
    agent_label = "{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
    response_split = "### Response:"

    def build_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,  # name shadows the builtin; kept for API compatibility
        output: Union[None, str] = None,
        reflection: Union[None, str] = None,
        corrected: Union[None, str] = None,
    ) -> str:
        """Return the full prompt for ``instruction``.

        Args:
            instruction: the task description.
            input: optional context. None or "" selects the no-input template.
            output: draft response (first part of the label).
            reflection: critique of the draft (second part of the label).
            corrected: final response (third part of the label).

        Returns:
            The formatted prompt. The label is appended only when ``output``,
            ``reflection`` and ``corrected`` are all non-empty. A partial label
            is dropped, leaving the prompt ending at the response header.
        """
        if input:
            res = self.prompt_input.format(instruction=instruction, input=input)
        else:
            res = self.prompt_no_input.format(instruction=instruction)
        if output and reflection and corrected:
            label = self.agent_label.format(
                output=output, reflection=reflection, corrected=corrected
            )
            res = f"{res}{label}"
        return res

    def get_response(self, output: str) -> str:
        """Return the stripped text after the first ``response_split`` marker.

        Raises:
            IndexError: if ``output`` does not contain ``response_split``.
        """
        # maxsplit=1: text after a repeated marker is still part of the response.
        return output.split(self.response_split, 1)[1].strip()
class SeparatorStyle(Enum):
    """Enumerate the separator layouts a Conversation template can declare."""

    # Explicit values equal to what ``auto()`` assigns (1, 2, 3 in order).
    SINGLE = 1
    TWO = 2
    DOLLY = 3
# TODO: clean up the conversation-template code below.
@dataclasses.dataclass
class Conversation:
    """Chat-template state: a system prompt, role names and message turns.

    ``get_prompt`` renders everything into one prompt string. ``copy`` returns
    an independent instance, so a shared template can be filled in without
    being mutated.
    """

    system: str  # text placed before the first turn
    roles: List[str]  # role names, e.g. ["USER", "ASSISTANT"]
    messages: List[List[str]]  # [role, message] pairs; a falsy message leaves that turn open
    offset: int  # NOTE(review): carried along by copy() but unused when rendering — confirm callers need it
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE  # NOTE(review): stored, but get_prompt does not branch on it
    sep: str = "###"  # follows the system prompt and every even-indexed turn
    sep2: Union[None, str] = None  # follows odd-indexed turns; when None, ``sep`` is used instead

    def get_prompt(self) -> str:
        """Render the system prompt and all turns as a single string.

        The output starts with ``system`` followed by ``sep``. Each turn with a
        truthy message is rendered as ``"<role>: <message>"`` plus ``sep``
        (even index) or ``sep2`` (odd index). A turn with a falsy message is
        rendered as ``"<role>:"`` with no separator, leaving the prompt open for
        that role to continue.
        """
        # Fall back to ``sep`` when no second separator is configured. Without
        # this, the dataclass default (sep2=None) raised TypeError (str + None)
        # on the second turn.
        second_sep = self.sep if self.sep2 is None else self.sep2
        seps = (self.sep, second_sep)
        pieces = [self.system, seps[0]]
        for i, (role, message) in enumerate(self.messages):
            if message:
                pieces += [role, ": ", message, seps[i % 2]]
            else:
                pieces += [role, ":"]
        return "".join(pieces)

    def copy(self) -> "Conversation":
        """Return a copy with fresh ``roles`` and ``messages`` lists.

        Each [role, message] pair is rebuilt as well, so editing or appending
        to the copy never touches ``self`` (e.g. a module-level template).
        """
        return dataclasses.replace(
            self,
            roles=list(self.roles),
            messages=[[role, message] for role, message in self.messages],
        )

    def append_message(self, role, message):
        """Append one [role, message] turn; a falsy message leaves the turn open."""
        self.messages.append([role, message])
# Vicuna v1.1 chat template; ShareGPTPrompter renders conversations from a copy
# of it. Roles are "USER"/"ASSISTANT". In get_prompt, ``sep`` (" ") follows the
# system prompt and even-indexed turns, and ``sep2`` ("</s>") follows
# odd-indexed turns.
conv_vicuna_v1_1 = Conversation(
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    roles=["USER", "ASSISTANT"],
    offset=0,
    messages=[],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
)
class ShareGPTPrompter:
    """Tokenize ShareGPT-style conversations with the Vicuna v1.1 template.

    Only assistant ("gpt") replies stay un-masked in ``labels``. Every other
    position is set to ``IGNORE_TOKEN_ID``.
    """

    def build_prompt(self, source, tokenizer, max_length: int = 2048):
        """Render, tokenize and label one conversation.

        Args:
            source: sequence of turns, each a mapping with a ``"from"`` key
                (``"human"`` or ``"gpt"``) and a ``"value"`` key holding the text.
            tokenizer: called as ``tokenizer(text, ...)`` and must return a
                mapping with ``"input_ids"``; its ``pad_token_id`` is also read.
                NOTE(review): presumably a Hugging Face tokenizer — confirm.
            max_length: truncation length for the rendered conversation
                (previously hard-coded to 2048, which stays the default).

        Returns:
            dict with ``input_ids``, ``labels`` (same length as ``input_ids``)
            and ``attention_mask``.

        Raises:
            IndexError: if ``source`` has fewer than two turns.
            KeyError: if a turn's ``"from"`` is neither ``"human"`` nor ``"gpt"``.
            AssertionError: if the turns do not alternate human/gpt.
        """
        if len(source) < 2:
            # No back-and-forth to learn from; data splitting can also leave
            # empty conversations behind.
            raise IndexError(
                f"conversation has {len(source)} turn(s); at least 2 are required"
            )

        conv = conv_vicuna_v1_1.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Skip a leading turn that is not from the human (sometimes a bing or
        # system message). An unknown speaker maps to None here.
        if roles.get(source[0]["from"]) != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]  # KeyError for speakers other than human/gpt
            expected = conv.roles[j % 2]
            if role != expected:
                # Explicit raise instead of ``assert`` so the check survives
                # ``python -O``; the exception type stays the same.
                raise AssertionError(
                    f"turn {j} is from {role!r} but {expected!r} was expected"
                )
            conv.append_message(role, sentence["value"])
        conversation = conv.get_prompt()

        tokenized_result = tokenizer(
            conversation,
            truncation=True,
            max_length=max_length,
            padding=False,
            return_tensors=None,
        )
        input_ids = tokenized_result["input_ids"]
        target = copy.deepcopy(input_ids)
        self._mask_targets(conv, conversation, tokenizer, target)

        # NOTE(review): padding=False means no pad tokens were added, yet any
        # token equal to pad_token_id (e.g. if pad is aliased to EOS) still gets
        # 0 here. Kept as-is to preserve existing behaviour — confirm intent.
        attention_mask = [
            1 if token_id != tokenizer.pad_token_id else 0 for token_id in input_ids
        ]
        return {
            "input_ids": input_ids,
            "labels": target,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _mask_targets(conv, conversation, tokenizer, target):
        """Overwrite, in place, the non-assistant positions of ``target`` with IGNORE_TOKEN_ID."""
        # Text that introduces an assistant reply inside one round.
        sep = conv.sep + conv.roles[1] + ": "
        total_len = len(target)
        # Start after the first token — presumably a BOS the tokenizer prepends; confirm.
        cur_len = 1
        for round_text in conversation.split(conv.sep2):
            if not round_text or cur_len >= total_len:
                # Empty trailing segment, or everything further was truncated away.
                break
            parts = round_text.split(sep)
            if len(parts) != 2:
                # No complete user/assistant exchange in this segment.
                break
            instruction = parts[0] + sep
            round_len = len(tokenizer(round_text)["input_ids"])
            # NOTE(review): the "- 2" appears to come from FastChat's Vicuna
            # training script; confirm it matches the special tokens this
            # tokenizer adds.
            instruction_len = len(tokenizer(instruction)["input_ids"]) - 2
            # Clamp to the (possibly truncated) target. Assigning to a list slice
            # that runs past the end would grow the list and leave labels longer
            # than input_ids.
            end = min(cur_len + instruction_len, total_len)
            if end > cur_len:
                target[cur_len:end] = [IGNORE_TOKEN_ID] * (end - cur_len)
            cur_len += round_len
        # Mask whatever follows the last complete round (no-op once past the end).
        target[cur_len:] = [IGNORE_TOKEN_ID] * (total_len - cur_len)