"""Module for testing prompt tokenizers.""" import json import logging import unittest from pathlib import Path from transformers import AutoTokenizer from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompters import ShareGPTPrompter logging.basicConfig(level="INFO") class TestPromptTokenizationStrategies(unittest.TestCase): """ Test class for prompt tokenization strategies. """ def setUp(self) -> None: # pylint: disable=duplicate-code self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.add_special_tokens( { "bos_token": "", "eos_token": "", "unk_token": "", } ) def test_sharegpt_integration(self): print(Path(__file__).parent) with open( Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" ) as fin: data = fin.read() conversation = json.loads(data) with open( Path(__file__).parent / "fixtures/conversation.tokenized.json", encoding="utf-8", ) as fin: data = fin.read() tokenized_conversation = json.loads(data) prompter = ShareGPTPrompter("chat") strat = ShareGPTPromptTokenizingStrategy( prompter, self.tokenizer, False, 2048, ) example = strat.tokenize_prompt(conversation) for fields in ["input_ids", "attention_mask", "labels"]: self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) self.assertEqual(example[fields], tokenized_conversation[fields]) if __name__ == "__main__": unittest.main()