""" User Defined prompts with configuration from the YML config """ from dataclasses import dataclass from functools import partial from typing import Optional, Tuple from axolotl.prompt_strategies.alpaca_w_system import ( InstructionWSystemPromptTokenizingStrategy, SystemDataPrompter, ) @dataclass class UserDefinedDatasetConfig: """ dataclass configuration representing a userdefined dataset type """ system_prompt: str = "" field_system: str = "system" field_instruction: str = "instruction" field_input: str = "input" field_output: str = "output" format: str = "{instruction} {input} " no_input_format: str = "{instruction} " system_format: str = "{system}" def __getitem__(self, item): return getattr(self, item) class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy): """ Prompt Tokenization Strategy for user defined prompts """ def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None): if not ds_cfg: raise ValueError("Missing dataset prompt configuration") system_prompt = "" if ds_cfg.system_prompt: system_prompt = ds_cfg.system_prompt def parse_instruction_fields( field_instruction, field_input, field_output, field_system, system_prompt, prompt, ) -> Tuple[str, str, str, str]: return ( prompt[field_instruction], prompt[field_input] if field_input in prompt else "", prompt[field_output] if field_output in prompt else "", prompt[field_system] if field_system in prompt else system_prompt, ) turn_format = ds_cfg.format turn_no_input_format = ds_cfg.no_input_format system_format = ds_cfg.system_format class UserDefinedPrompter(SystemDataPrompter): """ Prompter for user defined prompts """ def match_prompt_style(self): self.turn_format = turn_format self.turn_no_input_format = turn_no_input_format self.system_format = system_format prompter = UserDefinedPrompter() strat = UserDefinedPromptTokenizationStrategy( prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len, ) setattr( strat, "parse_instruction_fields", partial( parse_instruction_fields, ds_cfg.field_instruction, ds_cfg.field_input, ds_cfg.field_output, ds_cfg.field_system, system_prompt, ), ) return strat