qwerrwe / src /axolotl /prompt_strategies /alpaca_w_system.py
winglian's picture
pylint for duplicated code for system prompts
7b57ed7
raw
history blame
2.86 kB
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
prompt["system"],
)
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
(
instruction,
input, # pylint: disable=redefined-builtin
response,
system,
) = self.parse_instruction_fields(prompt)
user_prompt = next(
iter(
self.prompter.build_prompt_w_system(
system,
instruction,
input,
)
)
)
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_prompt
class SystemDataPrompter(AlpacaPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset
"""
def build_prompt_w_system(
self,
system: str,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = system + self.turn_format.format(instruction=instruction, input=input)
else:
res = system + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
def load(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)