"""Module for testing prompt tokenizers.""" import json import logging import unittest from pathlib import Path from transformers import AutoTokenizer from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter from axolotl.prompt_strategies.alpaca_w_system import ( InstructionWSystemPromptTokenizingStrategy, SystemDataPrompter, ) from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, ) from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter LOG = logging.getLogger("axolotl") class TestPromptTokenizationStrategies(unittest.TestCase): """ Test class for prompt tokenization strategies. """ def setUp(self) -> None: # pylint: disable=duplicate-code self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.add_special_tokens( { "bos_token": "", "eos_token": "", "unk_token": "", } ) def test_sharegpt_integration(self): with open( Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" ) as fin: data = fin.read() conversation = json.loads(data) with open( Path(__file__).parent / "fixtures/conversation.tokenized.json", encoding="utf-8", ) as fin: data = fin.read() tokenized_conversation = json.loads(data) prompter = ShareGPTPrompter("chat") strat = ShareGPTPromptTokenizingStrategy( prompter, self.tokenizer, False, 2048, ) example = strat.tokenize_prompt(conversation) for fields in ["input_ids", "attention_mask", "labels"]: self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) self.assertEqual(example[fields], tokenized_conversation[fields]) def test_no_sys_prompt(self): """ tests the interface between the user and assistant parts """ prompter = NoSystemPrompter() # pylint: disable=duplicate-code strat = AlpacaPromptTokenizingStrategy( prompter, self.tokenizer, False, 2048, ) sample = { "instruction": "hello cruel. lorem ipsum dolor sit amet.", "output": "world!", } example = strat.tokenize_prompt(sample) world_idx = example["input_ids"].index(3186) assert example["labels"][world_idx] == 3186 assert example["labels"][world_idx - 1] == -100 def test_alpaca(self): """ tests the interface between the user and assistant parts """ # pylint: disable=duplicate-code prompter = AlpacaPrompter() strat = AlpacaPromptTokenizingStrategy( prompter, self.tokenizer, False, 2048, ) sample = {"instruction": "hello!", "output": "Hi! How can I help?"} example = strat.tokenize_prompt(sample) world_idx = example["input_ids"].index(6324) assert example["labels"][world_idx] == 6324 assert example["labels"][world_idx - 1] == -100 class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): """ Test class for prompt tokenization strategies with sys prompt from the dataset """ def setUp(self) -> None: # pylint: disable=duplicate-code self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.add_special_tokens( { "bos_token": "", "eos_token": "", "unk_token": "", } ) def test_system_alpaca(self): prompter = SystemDataPrompter(PromptStyle.CHAT.value) strat = InstructionWSystemPromptTokenizingStrategy( prompter, self.tokenizer, False, 2048, ) sample = { "system": "use cot", "instruction": "hello!", "output": "Hi! How can I help?", } example = strat.tokenize_prompt(sample) assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "### System:" assert example["input_ids"][5:7] == [1509, 20118] # "use cot" assert example["input_ids"][9] == 11889 # USER if __name__ == "__main__": unittest.main()