JohanWork Nanobit committed
Commit 5439707
1 Parent(s): 6840381

Feat(test): Add tests for alpaca chatml prompt tokenizer (#1088)

* draft for adding test for tokenizer

* clean up

* clean up

* fix pre-commit

* fix pylint

* Revert "fix pylint"

This reverts commit cd2cda3cdae6f31f6d038a0673c2c7abd8e8e46a.

* add pylint exception for pytest fixture

* update comments

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* update spelling and import PromptStyle

* rename, restructure

* clean up

* add fmt:on

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
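
For reference, the new test module can be run on its own. Below is a minimal sketch (not part of the commit) using pytest's Python API; it is equivalent to the CLI call pytest tests/prompt_strategies/test_alpaca.py -v:

import sys

import pytest

if __name__ == "__main__":
    # Run only the new alpaca/chatml tokenizer tests, verbosely.
    sys.exit(pytest.main(["tests/prompt_strategies/test_alpaca.py", "-v"]))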

tests/prompt_strategies/test_alpaca.py ADDED
@@ -0,0 +1,116 @@
+"""
+Test module for alpaca integration w chatml
+"""
+import pytest
+from datasets import Dataset
+from tokenizers import AddedToken
+from transformers import AutoTokenizer
+
+from axolotl.datasets import TokenizedPromptDataset
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+@pytest.fixture(name="alpaca_dataset")
+def fixture_alpaca_dataset():
+    return Dataset.from_list(
+        [
+            {
+                "instruction": "Evaluate this sentence for spelling and grammar mistakes",
+                "input": "He finnished his meal and left the resturant",
+                "output": "He finished his meal and left the restaurant.",
+            }
+        ]
+    )
+
+
+@pytest.fixture(name="tokenizer")
+def fixture_tokenizer():
+    # pylint: disable=all
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    tokenizer.add_special_tokens(
+        {
+            "eos_token": AddedToken(
+                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+            )
+        }
+    )
+    tokenizer.add_tokens(
+        [
+            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
+        ]
+    )
+
+    return tokenizer
+
+
+class TestAlpacaChatml:
+    """
+    Test class for alpaca prompter
+    """
+
+    def test_no_double_im_end(self, alpaca_dataset, tokenizer):
+        strategy = AlpacaPromptTokenizingStrategy(
+            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
+            tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, alpaca_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            1,  # Bos
+            32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13,  # instruction
+            32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13,  # input
+            32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # output
+        ]
+        # fmt: on
+
+    def test_no_train_on_input(self, alpaca_dataset, tokenizer):
+        strategy = AlpacaPromptTokenizingStrategy(
+            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
+            tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, alpaca_dataset, process_count=1
+        )
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            -100,  # bos
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # instruction
+            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # input
+            -100, -100, -100, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # Output
+        ]
+        # fmt: on
+
+    def test_w_train_on_input(self, alpaca_dataset, tokenizer):
+        strategy = AlpacaPromptTokenizingStrategy(
+            AlpacaPrompter(prompt_style=PromptStyle.CHATML.value),
+            tokenizer,
+            True,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, alpaca_dataset, process_count=1
+        )
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            1,  # Bos
+            32001, 1587, 13, 20548, 336, 349, 396, 13126, 369, 13966, 264, 3638, 28725, 5881, 1360, 395, 396, 2787, 369, 5312, 3629, 2758, 28723, 12018, 264, 2899, 369, 6582, 1999, 2691, 274, 272, 2159, 28723, 32000, 28705, 13,  # instruction
+            32001, 2188, 13, 16627, 11931, 456, 12271, 354, 668, 3572, 304, 18756, 3479, 17179, 13, 2428, 854, 28711, 1497, 516, 11314, 304, 1749, 272, 1846, 324, 440, 32000, 28705, 13,  # input
+            32001, 13892, 13, 650, 5967, 516, 11314, 304, 1749, 272, 9926, 28723, 32000,  # output
+        ]
+        # fmt: on
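
A note on the asserted ids and labels above: 32000 and 32001 are the ids of the newly added <|im_end|> and <|im_start|> tokens (Mistral-7B-v0.1's base vocabulary ends at 31999), and -100 is the label index that PyTorch's cross-entropy loss ignores. With train_on_inputs=False, everything up to and including the <|im_start|>assistant header is masked to -100, so only the assistant response tokens contribute to the loss. A minimal, standalone sketch of that masking pattern (illustrative only, not axolotl's actual implementation):

IGNORE_INDEX = -100  # label value ignored by torch.nn.CrossEntropyLoss


def mask_prompt_tokens(input_ids, prompt_len):
    """Return labels equal to input_ids, with the first prompt_len tokens masked."""
    labels = list(input_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels


# Toy example: a six-token sequence whose first four tokens are prompt/header tokens.
assert mask_prompt_tokens([1, 10, 11, 12, 20, 21], 4) == [-100, -100, -100, -100, 20, 21]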