winglian committed on
Commit
9218ebe
1 Parent(s): 2284209

e2e testing (#574)

Browse files
.github/workflows/e2e.yml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
name: E2E
# Manually-triggered end-to-end test suite; runs the GPU-backed tests
# under tests/e2e/ that are excluded from the regular CI test job.
on:
  workflow_dispatch:

jobs:
  e2e-test:
    # A GPU-equipped self-hosted runner is required for these tests.
    runs-on: [self-hosted, gpu]
    strategy:
      fail-fast: false
      matrix:
        python_version: ["3.10"]
    timeout-minutes: 10

    steps:
      - name: Check out repository code
        uses: actions/checkout@v3

      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python_version }}
          cache: 'pip' # caching pip dependencies

      - name: Install dependencies
        run: |
          pip3 install -e .
          pip3 install -r requirements-tests.txt

      - name: Run e2e tests
        run: |
          pytest tests/e2e/
.github/workflows/main.yml CHANGED
@@ -23,7 +23,7 @@ jobs:
23
  python_version: "3.10"
24
  pytorch: 2.0.1
25
  axolotl_extras:
26
- runs-on: self-hosted
27
  steps:
28
  - name: Checkout
29
  uses: actions/checkout@v3
@@ -68,7 +68,7 @@ jobs:
68
  pytorch: 2.0.1
69
  axolotl_extras:
70
  is_latest: true
71
- runs-on: self-hosted
72
  steps:
73
  - name: Checkout
74
  uses: actions/checkout@v3
 
23
  python_version: "3.10"
24
  pytorch: 2.0.1
25
  axolotl_extras:
26
+ runs-on: [self-hosted, gpu, docker]
27
  steps:
28
  - name: Checkout
29
  uses: actions/checkout@v3
 
68
  pytorch: 2.0.1
69
  axolotl_extras:
70
  is_latest: true
71
+ runs-on: [self-hosted, gpu, docker]
72
  steps:
73
  - name: Checkout
74
  uses: actions/checkout@v3
.github/workflows/tests.yml CHANGED
@@ -29,4 +29,4 @@ jobs:
29
 
30
  - name: Run tests
31
  run: |
32
- pytest tests/
 
29
 
30
  - name: Run tests
31
  run: |
32
+ pytest --ignore=tests/e2e/ tests/
tests/e2e/test_lora_llama.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ E2E tests for lora llama
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import tempfile
8
+ import unittest
9
+
10
+ from axolotl.common.cli import TrainerCliArgs
11
+ from axolotl.train import TrainDatasetMeta, train
12
+ from axolotl.utils.config import normalize_config
13
+ from axolotl.utils.data import prepare_dataset
14
+ from axolotl.utils.dict import DictDefault
15
+ from axolotl.utils.models import load_tokenizer
16
+
17
+ LOG = logging.getLogger("axolotl.tests.e2e")
18
+ os.environ["WANDB_DISABLED"] = "true"
19
+
20
+
21
def load_datasets(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,  # pylint:disable=unused-argument
) -> TrainDatasetMeta:
    """Build tokenized train/eval splits for *cfg*.

    Loads the tokenizer described by the config, runs dataset
    preparation, and bundles the resulting splits plus the computed
    step count into a ``TrainDatasetMeta`` for ``train``.
    """
    tokenizer = load_tokenizer(cfg)
    train_split, eval_split, num_steps = prepare_dataset(cfg, tokenizer)
    return TrainDatasetMeta(
        train_dataset=train_split,
        eval_dataset=eval_split,
        total_num_steps=num_steps,
    )
35
+
36
+
37
class TestLoraLlama(unittest.TestCase):
    """
    End-to-end check that a small Llama model trains with a LoRA adapter.
    """

    def test_lora(self):
        """Run a short LoRA fine-tune of a 68M-parameter Llama and expect no errors."""
        # Minimal training config: tiny base model, 8-bit load, LoRA on
        # all linear layers, 2 epochs over a 2k-sample alpaca dataset.
        config = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "base_model_config": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "sequence_len": 1024,
                "load_in_8bit": True,
                "adapter": "lora",
                "lora_r": 32,
                "lora_alpha": 64,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<unk>",
                    "bos_token": "<s>",
                    "eos_token": "</s>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                # throwaway directory so runs don't clobber each other
                "output_dir": tempfile.mkdtemp(),
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
            }
        )
        normalize_config(config)
        args = TrainerCliArgs()
        meta = load_datasets(cfg=config, cli_args=args)

        # The assertion is implicit: training must complete without raising.
        train(cfg=config, cli_args=args, dataset_meta=meta)