File size: 2,220 Bytes
8cc0aad
 
 
37293dc
ce34d64
 
 
 
4ea9a66
 
 
 
 
8cc0aad
ce34d64
 
 
4ea9a66
3a50377
 
4ac9e25
 
 
 
 
 
 
 
 
3a50377
8cc0aad
 
 
 
 
3a50377
 
 
 
 
 
 
4ac9e25
 
 
 
 
 
 
 
 
59bb219
4ac9e25
 
 
 
 
 
 
 
 
 
 
 
3a50377
 
8cc0aad
ce34d64
 
 
3a50377
4ac9e25
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""

from typing import Tuple

from axolotl.prompt_tokenizers import (
    AlpacaPromptTokenizingStrategy,
    InstructionPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle


def load(tokenizer, cfg):
    """Build the default Alpaca tokenizing strategy using the chat prompt style.

    Reads ``cfg.train_on_inputs`` and ``cfg.sequence_len`` from the config.
    """
    prompter = AlpacaPrompter(PromptStyle.CHAT.value)
    return AlpacaPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


class AlpacaConcisePrompter(AlpacaPrompter):
    """
    Alpaca Prompter whose system prompts ask the model for concise answers
    """

    # Overrides of the AlpacaPrompter system prompts; the wording adds a
    # "concisely" instruction to each variant.
    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"


class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for AlpacaQA-style rows with "question"/"answer" keys
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Map a QA example onto (instruction, input, output); no input field."""
        instruction = prompt["question"]
        response = prompt["answer"]
        return instruction, "", response


class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for CamelAI rows with "message_1"/"message_2" keys
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Map a CamelAI example onto (instruction, input, output); no input field."""
        instruction = prompt["message_1"]
        response = prompt["message_2"]
        return instruction, "", response


def load_concise(tokenizer, cfg):
    """Build an Alpaca tokenizing strategy that uses the concise system prompts.

    Same wiring as ``load`` but with ``AlpacaConcisePrompter``.
    """
    prompter = AlpacaConcisePrompter(PromptStyle.CHAT.value)
    return AlpacaPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_qa(tokenizer, cfg):
    """Build a tokenizing strategy for question/answer-keyed Alpaca datasets."""
    prompter = AlpacaPrompter(PromptStyle.CHAT.value)
    return AlpacaQAPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_camel_ai(tokenizer, cfg):
    """Build a tokenizing strategy for CamelAI message-keyed datasets."""
    prompter = AlpacaPrompter(PromptStyle.CHAT.value)
    return CamelAIPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )