"""
E2E tests for lora llama
"""

import logging
import os
import tempfile
import unittest

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.tests.e2e")
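# disable wandb so the test runs do not try to log to an external service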
os.environ["WANDB_DISABLED"] = "true"


class TestPhi(unittest.TestCase):
    """
    Test case for Llama models using LoRA
    """

    def test_ft(self):
        # pylint: disable=duplicate-code
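        # full fine-tune of phi-1_5 (no adapter), without sample packing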
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model_config": "microsoft/phi-1_5",
                "trust_remote_code": True,
                "model_type": "MixFormerSequentialForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 512,
                "sample_packing": False,
                "load_in_8bit": False,
                "adapter": None,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<|endoftext|>",
                    "bos_token": "<|endoftext|>",
                    "eos_token": "<|endoftext|>",
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": tempfile.mkdtemp(),
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "bf16": True,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    def test_ft_packed(self):
        # pylint: disable=duplicate-code
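        # same config as test_ft, but with sample_packing enabled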
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model_config": "microsoft/phi-1_5",
                "trust_remote_code": True,
                "model_type": "MixFormerSequentialForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 512,
                "sample_packing": True,
                "load_in_8bit": False,
                "adapter": None,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<|endoftext|>",
                    "bos_token": "<|endoftext|>",
                    "eos_token": "<|endoftext|>",
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": tempfile.mkdtemp(),
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "bf16": True,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)