"""
Basic completion text
"""
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple

from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy


class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.

    Each dataset row contributes a single text field (``field``, "text" by
    default). The text is tokenized as-is and every tokenized feature is then
    split into consecutive chunks of at most ``sequence_len`` items.
    """

    # Name of the dataset column that holds the raw completion text.
    _field: str = "text"

    def __init__(self, *args, max_length=None, **kwargs):
        """
        Forward all positional/keyword arguments to the parent strategy.

        Args:
            max_length: optional override stored on ``self.max_length`` after
                the parent initializer has run. Left untouched when ``None``.
        """
        super().__init__(*args, **kwargs)
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        """Always ``True``: ``tokenize_prompt`` consumes a batch of rows."""
        return True

    @property
    def field(self) -> str:
        """Name of the column read by ``parse_instruction_fields``."""
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Extract the completion text from a single row.

        Returns:
            A 3-tuple whose first item is ``prompt[self.field]``; the other
            two slots are always empty strings.

        Raises:
            KeyError: if the row has no ``self.field`` key (for a dict row).
        """
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt) -> Dict[str, Any]:
        """
        Tokenize a batch of rows and split every feature into chunks.

        Args:
            prompt: column-oriented batch, i.e. a mapping from column name to
                a sequence holding one value per row.

        Returns:
            A plain dict mapping each tokenized feature name to a flat list
            of chunks, each at most ``self.sequence_len`` long. One input row
            may yield several chunks, so the output can have more entries
            than the batch had rows.
        """
        chunked = defaultdict(list)
        feature_names = list(prompt.keys())
        sequence_len = self.sequence_len

        # Transpose the column-oriented batch into one dict per row.
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            instruction, _, _ = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            # NOTE(review): `_tokenize` comes from the parent class; it is
            # assumed to return a mapping of feature name -> token list.
            tokenized_full_prompt = self._tokenize(full_prompt)

            for key, val in tokenized_full_prompt.items():
                # Explicit appends (rather than `extend`) so that a key whose
                # value is empty is never created in the result.
                for start in range(0, len(val), sequence_len):
                    chunked[key].append(val[start : start + sequence_len])

        return dict(chunked)

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        """Return the first prompt produced by ``self.prompter.build_prompt``."""
        return next(iter(self.prompter.build_prompt(instruction, input, response)))


class CompletionPrompter:
    """
    Prompter for completion
    """

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        """Yield ``instruction`` unchanged; ``input`` and ``output`` are ignored."""
        yield instruction


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build the completion tokenizing strategy from the given configuration.

    Args:
        tokenizer: tokenizer handed through to the strategy.
        cfg: config object providing ``train_on_inputs`` and ``sequence_len``.
        ds_cfg: optional per-dataset config; when it contains a ``"field"``
            key, that value replaces the default text column name.

    Returns:
        The configured ``CompletionPromptTokenizingStrategy``.
    """
    strat = CompletionPromptTokenizingStrategy(
        CompletionPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        # NOTE(review): presumably a generous cap so long documents survive
        # tokenization and can be chunked afterwards - confirm against the
        # parent class's use of `max_length`.
        max_length=cfg.sequence_len * 64,
    )
    if ds_cfg and "field" in ds_cfg:
        strat.field = ds_cfg["field"]

    return strat