File size: 4,391 Bytes
7eb33a7
 
d199d6c
 
3a38271
8d20e0a
 
 
 
 
 
d199d6c
 
 
7eb33a7
 
 
 
d199d6c
 
 
 
 
 
 
7eb33a7
392dfd9
 
 
d199d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eb33a7
392dfd9
 
 
d199d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
8d20e0a
 
 
 
 
 
 
 
 
2bb0b78
8d20e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Module testing prompters"""

import unittest

from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
from axolotl.prompters import (
    AlpacaPrompter,
    MultipleChoiceExplainPrompter,
    PromptStyle,
    UnpromptedPrompter,
)


class AlpacaPrompterTest(unittest.TestCase):
    """
    Test AlpacaPrompter (and SystemDataPrompter) output for each prompt style.

    Checks use ``TestCase`` assertion methods instead of bare ``assert`` so
    they still execute under ``python -O`` and report the offending prompt
    text when they fail.
    """

    def _assert_markers(self, res, present=(), absent=()):
        """Assert every string in ``present`` occurs in ``res`` and none in ``absent`` does."""
        for marker in present:
            self.assertIn(marker, res)
        for marker in absent:
            self.assertNotIn(marker, res)

    def test_prompt_style_w_none(self):
        prompter = AlpacaPrompter(prompt_style=None)
        res = next(prompter.build_prompt("tell me a joke"))
        # just testing that it uses instruct style
        self.assertIn("### Instruction:", res)

    def test_prompt_style_w_instruct(self):
        prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)

        # instruction plus input: the input section and its text are rendered
        res = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        self._assert_markers(
            res,
            present=(
                "Below is an instruction",
                "### Instruction:",
                "### Input:",
                "alpacas",
                "### Response:",
            ),
            absent=("USER:", "ASSISTANT:"),
        )

        # instruction only: no input section
        res = next(prompter.build_prompt("tell me a joke about the following"))
        self._assert_markers(
            res,
            present=("Below is an instruction", "### Instruction:", "### Response:"),
            absent=("### Input:", "USER:", "ASSISTANT:"),
        )

    def test_prompt_style_w_chat(self):
        prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)

        # instruction plus input: chat turn markers, no instruct headers
        res = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        self._assert_markers(
            res,
            present=("Below is an instruction", "alpacas", "USER:", "ASSISTANT:"),
            absent=("### Instruction:", "### Input:", "### Response:"),
        )

        # instruction only
        res = next(prompter.build_prompt("tell me a joke about the following"))
        self._assert_markers(
            res,
            present=("Below is an instruction", "USER:", "ASSISTANT:"),
            absent=("### Instruction:", "### Input:", "### Response:"),
        )

    def test_system_prompt(self):
        prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
        res = next(
            prompter.build_prompt_w_system(
                "use cot", "tell me a joke about the following", "alpacas"
            )
        )
        # the per-example system prompt must lead the rendered text
        self.assertTrue(res.startswith("SYSTEM:"), res)
        self._assert_markers(
            res,
            present=("use cot", "alpacas", "USER:", "ASSISTANT:"),
            absent=("### Instruction:", "### Input:", "### Response:"),
        )


class UnpromptedPrompterTest(unittest.TestCase):
    """
    Test class for UnpromptedPrompter with no system prompts.

    Because there is no system prompt, the rendered text is expected to begin
    directly with the first turn marker. Checks use ``TestCase`` assertion
    methods so they are not stripped under ``python -O``.
    """

    def test_prompt_style_w_none(self):
        prompter = UnpromptedPrompter(prompt_style=None)
        res = next(prompter.build_prompt("tell me a joke"))
        self.assertIn("### Instruction:", res)
        self.assertIn("tell me a joke", res)
        self.assertTrue(res.startswith("###"), res)

    def test_prompt_style_w_instruct(self):
        prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
        res = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        self.assertIn("### Instruction:", res)
        self.assertIn("tell me a joke", res)
        # the supplied input must reach the prompt, as in AlpacaPrompterTest
        self.assertIn("alpacas", res)
        self.assertTrue(res.startswith("###"), res)

    def test_prompt_style_w_chat(self):
        prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
        res = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        self.assertIn("USER:", res)
        self.assertIn("tell me a joke", res)
        # the supplied input must reach the prompt, as in AlpacaPrompterTest
        self.assertIn("alpacas", res)
        self.assertTrue(res.startswith("USER:"), res)


class MultipleChoiceExplainPrompterTest(unittest.TestCase):
    """
    Test class for MultipleChoiceExplainPrompter
    """

    def test_prompt_style_w_chat(self):
        prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
        res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
        assert "USER:" in res
        assert "choose one" in res
        assert "Choose the answer that best answers the question." in res
        assert "- A\n- B\n- C" in res