File size: 831 Bytes
ce34d64
 
 
 
4ea9a66
 
 
 
 
ce34d64
 
 
 
4ea9a66
3a50377
 
 
 
 
 
 
 
 
 
 
 
 
ce34d64
 
 
 
3a50377
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Tuple

from axolotl.prompt_tokenizers import (
    AlpacaPromptTokenizingStrategy,
    InstructionPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle


def load(tokenizer, cfg):
    """Create the Alpaca tokenizing strategy used by this module.

    The strategy is built around an ``AlpacaPrompter`` initialised with
    ``PromptStyle.chat.value``.

    Args:
        tokenizer: Tokenizer object handed straight to the strategy.
        cfg: Configuration object; its ``train_on_inputs`` and
            ``sequence_len`` attributes are forwarded to the strategy.

    Returns:
        A new ``AlpacaPromptTokenizingStrategy`` instance.
    """
    chat_prompter = AlpacaPrompter(PromptStyle.chat.value)
    strategy_args = (
        chat_prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return AlpacaPromptTokenizingStrategy(*strategy_args)


class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for records keyed by ``question`` and ``answer``."""

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Extract the three prompt fields from a question/answer record.

        Args:
            prompt: Mapping that must contain the keys ``"question"`` and
                ``"answer"``.

        Returns:
            A 3-tuple ``(prompt["question"], "", prompt["answer"])``. The
            middle element is always the empty string.
            NOTE(review): presumably the slots are (instruction, input,
            response) as expected by the base class -- confirm there.

        Raises:
            KeyError: If ``"question"`` or ``"answer"`` is missing from a
                dict-like ``prompt``.
        """
        return (
            prompt["question"],
            "",  # QA records supply no value for the middle field
            prompt["answer"],
        )


def load_qa(tokenizer, cfg):
    """Create the question/answer variant of the Alpaca tokenizing strategy.

    The strategy is built around an ``AlpacaPrompter`` initialised with
    ``PromptStyle.chat.value``.

    Args:
        tokenizer: Tokenizer object handed straight to the strategy.
        cfg: Configuration object; its ``train_on_inputs`` and
            ``sequence_len`` attributes are forwarded to the strategy.

    Returns:
        A new ``AlpacaQAPromptTokenizingStrategy`` instance.
    """
    chat_prompter = AlpacaPrompter(PromptStyle.chat.value)
    strategy_args = (
        chat_prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return AlpacaQAPromptTokenizingStrategy(*strategy_args)