import abc

from transformers import PreTrainedTokenizer

# Default ignore_index of torch.nn.CrossEntropyLoss; label positions set to this
# value are excluded from the loss.
IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"


class InvalidDataException(Exception):
    """Raised when a dataset record is missing fields required by its prompter."""


class PromptTokenizingStrategy(abc.ABC):
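    """Abstract base class for converting a raw dataset record into tokenized
    input_ids / attention_mask / labels for causal LM fine-tuning."""
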
    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs: bool = False,
        sequence_len: int = 2048,
    ):
        self.prompter = prompter
        self.tokenizer: PreTrainedTokenizer = tokenizer
        self.train_on_inputs = train_on_inputs
        self.sequence_len = sequence_len

    @abc.abstractmethod
    def tokenize_prompt(self, prompt):
        pass


class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
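    """Tokenizes Alpaca-style instruction/input/output records, masking the prompt
    portion of the labels unless train_on_inputs is set."""
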
    def tokenize_prompt(self, prompt):
        full_prompt = self._tokenize_full_prompt(prompt)
        tokenized_full_prompt = self._tokenize(full_prompt)
        if not self.train_on_inputs:
            user_prompt = self.prompter.build_prompt(
                prompt["instruction"],
                prompt.get("input", ""),
            )
            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # TODO this could be sped up using numpy array slicing
            tokenized_full_prompt["labels"] = [
                IGNORE_INDEX
            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]

        return tokenized_full_prompt

    def _tokenize_full_prompt(self, prompt):
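        """Render the full prompt text (instruction, input, output); tokenization
        itself happens in _tokenize()."""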
        return self.prompter.build_prompt(
            prompt["instruction"],
            prompt.get("input", ""),
            prompt["output"],
        )

    def _tokenize(self, prompt, add_eos_token=True):
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
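        # Append EOS when the tokenizer did not already emit one and there is
        # still room under sequence_len.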
        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # Labels start as a copy of input_ids; callers may later mask the prompt
        # portion with IGNORE_INDEX.
        result["labels"] = result["input_ids"].copy()
        return result


class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
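    """Alpaca-style tokenization for GPTeacher records, whose answer field is
    named 'response' instead of 'output'."""
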
    def _tokenize_full_prompt(self, prompt):
        return self.prompter.build_prompt(
            prompt["instruction"],
            prompt["input"],
            prompt["response"],
        )


class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
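    """Delegates multi-turn 'conversations' tokenization to the prompter and wraps
    malformed records in InvalidDataException."""
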
    def tokenize_prompt(self, prompt):
        try:
            return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
        except (KeyError, AssertionError, IndexError) as e:
            raise InvalidDataException(str(e)) from e