winglian committed on
Commit bcc78d8
1 Parent(s): 74532dd

bump transformers and update attention class map name (#1023)


* bump transformers and update attention class map name

* also run the tests in docker

* add mixtral e2e smoke test

* fix base name for docker image in test

* mixtral lora doesn't seem to work, at least check qlora

* add testcase for mixtral w sample packing

* check monkeypatch for flash attn multipack

* also run the e2e tests in docker

* use all gpus to run tests in docker ci

* use privileged mode too for docker w gpus

* rename the docker e2e actions for gh ci

* set privileged mode for docker and update mixtral model self attn check

* use fp16/bf16 for mixtral w fa2

* skip e2e tests on docker w gpus for now

* tests to validate mistral and mixtral patches

* fix rel import

.github/workflows/tests-docker.yml ADDED
@@ -0,0 +1,62 @@
+name: e2e-docker-tests
+
+on:
+  pull_request:
+    paths:
+      - '**.py'
+      - 'requirements.txt'
+  workflow_dispatch:
+
+jobs:
+  build-axolotl:
+    if: github.repository_owner == 'OpenAccess-AI-Collective'
+    # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 118
+            cuda_version: 11.8.0
+            python_version: "3.10"
+            pytorch: 2.0.1
+            axolotl_extras:
+            is_latest: true
+          - cuda: 121
+            cuda_version: 12.1.0
+            python_version: "3.10"
+            pytorch: 2.1.1
+            axolotl_extras:
+    runs-on: [self-hosted, gpu, docker]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Docker metadata
+        id: metadata
+        uses: docker/metadata-action@v5
+        with:
+          images: winglian/axolotl
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Login to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
+      - name: Build and export to Docker
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          load: true
+          build-args: |
+            BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            CUDA=${{ matrix.cuda }}
+            PYTORCH_VERSION=${{ matrix.pytorch }}
+          file: ./docker/Dockerfile
+          tags: |
+            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
+          labels: ${{ steps.metadata.outputs.labels }}
+      - name: Unit Tests
+        run: |
+          docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
requirements.txt CHANGED
@@ -2,7 +2,7 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers==4.36.2
+transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.24.1
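The pin above moves transformers from the 4.36.2 release to a commit on main. A minimal sketch of how to confirm the installed build carries the renamed attention class map this PR patches (the attribute name is taken from the hunk below):

```python
# Sanity check after `pip install -r requirements.txt`: the pinned transformers
# commit should expose the renamed Mixtral attention class map.
from transformers.models.mixtral import modeling_mixtral

assert hasattr(modeling_mixtral, "MIXTRAL_ATTENTION_CLASSES"), (
    "installed transformers predates the MISTRAL_ -> MIXTRAL_ATTENTION_CLASSES rename"
)
assert "flash_attention_2" in modeling_mixtral.MIXTRAL_ATTENTION_CLASSES
```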
src/axolotl/monkeypatch/mixtral/__init__.py CHANGED
@@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn():
     transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
         mixtral_model_forward
     )
-    transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
+    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
         "flash_attention_2"
     ] = MixtralMultipackFlashAttention2
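For context, the one-line change above is the whole fix for the newer transformers naming: the patch overwrites the `"flash_attention_2"` entry in the module-level class map, so any Mixtral layer constructed afterwards resolves to the multipack-aware attention class. A condensed sketch of the patched function, with the relative import assumed from the module names in this diff:

```python
import transformers

# assumed relative import; the real symbols live in
# axolotl.monkeypatch.mixtral.modeling_mixtral (also patched in this PR)
from .modeling_mixtral import MixtralMultipackFlashAttention2, mixtral_model_forward


def replace_mixtral_attn_with_multipack_flash_attn():
    # swap in the patched forward, which understands packed attention masks
    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
        mixtral_model_forward
    )
    # register the multipack-aware class under the key that the decoder layers
    # look up via config._attn_implementation
    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
        "flash_attention_2"
    ] = MixtralMultipackFlashAttention2
```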
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py CHANGED
@@ -261,7 +261,11 @@ def mixtral_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+    if (
+        attention_mask is not None
+        and self._attn_implementation == "flash_attention_2"
+        and use_cache
+    ):
         is_padding_right = attention_mask[:, -1].sum().item() != batch_size
         if is_padding_right:
             raise ValueError(
@@ -270,7 +274,7 @@ def mixtral_model_forward(
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )
 
-    if self._use_flash_attention_2:
+    if self._attn_implementation == "flash_attention_2":
         # 2d mask is passed through the layers
         attention_mask = (
             attention_mask
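These two hunks track an upstream transformers change: the per-model boolean `_use_flash_attention_2` was superseded by a string attribute, `_attn_implementation`, naming the selected backend. If similar patch code needs to run against both conventions, a small compatibility check could look like this (hypothetical helper, not part of this PR):

```python
def uses_flash_attention_2(module) -> bool:
    """Return True if the module resolved its attention backend to FlashAttention-2.

    Covers both the newer string attribute and the older boolean flag.
    """
    if getattr(module, "_attn_implementation", None) == "flash_attention_2":
        return True
    return bool(getattr(module, "_use_flash_attention_2", False))
```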
src/axolotl/utils/models.py CHANGED
@@ -332,15 +332,18 @@ def load_model(
         or cfg.is_mistral_derived_model
         or model_config.model_type == "mixtral"
     ):
+        model_kwargs["attn_implementation"] = "flash_attention_2"
         model_config._attn_implementation = (  # pylint: disable=protected-access
             "flash_attention_2"
         )
     else:
         if model_config.model_type == "mixtral":
+            model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"
             )
         else:
+            model_kwargs["attn_implementation"] = "eager"
            model_config._attn_implementation = (  # pylint: disable=protected-access
                 "eager"
             )
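The new `model_kwargs["attn_implementation"]` entries mirror the existing `model_config._attn_implementation` assignments; recent transformers releases accept `attn_implementation` directly in `from_pretrained`, presumably so the backend choice holds regardless of how the loader handles the config object. A minimal sketch of how such a kwarg reaches the loader (illustrative only, not the full `load_model` logic; requires flash-attn and a bf16-capable GPU):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative only; axolotl's load_model() assembles many more kwargs.
model_kwargs = {"attn_implementation": "flash_attention_2"}

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/Mixtral-tiny",  # placeholder model id, borrowed from the tests below
    torch_dtype=torch.bfloat16,
    **model_kwargs,
)
print(model.config._attn_implementation)  # expected: "flash_attention_2"
```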
tests/e2e/test_mixtral.py ADDED
@@ -0,0 +1,109 @@
+"""
+E2E tests for mixtral
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMixtral(unittest.TestCase):
+    """
+    Test case for Mixtral models
+    """
+
+    @with_temp_dir
+    def test_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_ft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/e2e/test_mixtral_samplepack.py ADDED
@@ -0,0 +1,123 @@
+"""
+E2E tests for mixtral
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMixtral(unittest.TestCase):
+    """
+    Test case for Mixtral models with sample packing
+    """
+
+    @with_temp_dir
+    def test_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "sample_packing": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_ft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "sample_packing": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            "axolotl.monkeypatch.mixtral.modeling_mixtral"
+            in model.model.layers[0].self_attn.__class__.__module__
+        )
+        assert (
+            "MixtralMultipackFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/e2e/test_model_patches.py ADDED
@@ -0,0 +1,99 @@
+"""
+E2E smoke tests to check that the monkeypatches are in place for certain configurations
+"""
+
+import unittest
+
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
+from .utils import with_temp_dir
+
+
+class TestModelPatches(unittest.TestCase):
+    """
+    TestCases for the multipack monkey patches
+    """
+
+    @with_temp_dir
+    def test_mixtral_multipack(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        tokenizer = load_tokenizer(cfg)
+        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+        assert (
+            "axolotl.monkeypatch.mixtral.modeling_mixtral"
+            in model.model.layers[0].self_attn.__class__.__module__
+        )
+        assert (
+            "MixtralMultipackFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
+
+    @with_temp_dir
+    def test_mistral_multipack(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        tokenizer = load_tokenizer(cfg)
+        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+        assert (
+            "axolotl.monkeypatch.mistral_attn_hijack_flash"
+            in model.model.layers[0].self_attn.forward.__module__
+        )