winglian committed
Commit e799e08
1 Parent(s): 0f77b8d

Falcon embeddings (#1149) [skip docker]


* also fix multipack for falcon and add smoke tests

* make sure to handle special tokens and added tokens for lora

* fix reference to model_type

* fix tests for falcon

* fix stray typo

* fixes for smoke tests
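
Taken together, these changes wire up packed flash attention and LoRA embedding handling for Falcon. As a quick orientation, here is a minimal, hedged sketch of the options involved, written in the DictDefault style the smoke tests below use; the module and token names come from those tests and the updated example configs, and the checkpoint name is only illustrative:

# Illustrative only: a trimmed Falcon LoRA config exercising the pieces this
# commit touches (multipack flash attention, embeddings saved alongside the
# LoRA adapter, and bos/eos/pad all set to Falcon's "<|endoftext|>").
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "tiiuae/falcon-7b",  # any Falcon checkpoint
        "flash_attention": True,
        "sample_packing": True,
        "adapter": "lora",
        "lora_target_linear": True,
        "lora_modules_to_save": ["word_embeddings", "lm_head"],
        "special_tokens": {
            "bos_token": "<|endoftext|>",
            "eos_token": "<|endoftext|>",
            "pad_token": "<|endoftext|>",
        },
    }
)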

examples/falcon/config-7b-lora.yml CHANGED
@@ -60,5 +60,5 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: "<|endoftext|>"
-  bos_token: ">>ABSTRACT<<"
+  bos_token: "<|endoftext|>"
   eos_token: "<|endoftext|>"
examples/falcon/config-7b-qlora.yml CHANGED
@@ -89,5 +89,5 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: "<|endoftext|>"
-  bos_token: ">>ABSTRACT<<"
+  bos_token: "<|endoftext|>"
   eos_token: "<|endoftext|>"
examples/falcon/config-7b.yml CHANGED
@@ -60,5 +60,5 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: "<|endoftext|>"
-  bos_token: ">>ABSTRACT<<"
+  bos_token: "<|endoftext|>"
   eos_token: "<|endoftext|>"
src/axolotl/monkeypatch/falcon/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""
+Patches to support multipack for falcon
+"""
+import transformers
+
+from axolotl.monkeypatch.utils import get_unpad_data
+
+
+def replace_falcon_attn_with_multipack_flash_attn():
+    transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
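
For context, the patch only needs to be installed before the Falcon model is instantiated; load_model does this automatically (see the models.py hunk below). A hand-rolled sketch of the same call order, with an illustrative checkpoint name:

# Apply the monkeypatch, then load a Falcon model with flash attention 2.
# After the call, transformers' Falcon flash-attention path goes through
# axolotl's shared get_unpad_data helper, which derives the unpadding
# indices and cumulative sequence lengths from the packed attention mask.
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.falcon import replace_falcon_attn_with_multipack_flash_attn

replace_falcon_attn_with_multipack_flash_attn()
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",  # illustrative; any Falcon checkpoint
    attn_implementation="flash_attention_2",
)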
src/axolotl/utils/lora_embeddings.py CHANGED
@@ -11,4 +11,6 @@ def get_linear_embedding_layers(model_type):
         return ["embd.wte", "lm_head.linear"]
     if model_type == "gpt_neox":
         return ["embed_in", "embed_out"]
+    if model_type == "falcon":
+        return ["word_embeddings", "lm_head"]
     return ["embed_tokens", "lm_head"]
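
The new falcon entry matches the modules the smoke tests pass as lora_modules_to_save. As a hedged sketch (not axolotl's exact wiring), this is roughly how those names would feed a PEFT config so that special and newly added tokens get trained embedding weights; the LoRA hyperparameters and target module names here are only illustrative:

from peft import LoraConfig

from axolotl.utils.lora_embeddings import get_linear_embedding_layers

embedding_modules = get_linear_embedding_layers("falcon")
# -> ["word_embeddings", "lm_head"]

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["query_key_value", "dense"],  # illustrative Falcon linears
    modules_to_save=embedding_modules,  # fully train the embeddings too
    task_type="CAUSAL_LM",
)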
src/axolotl/utils/models.py CHANGED
@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()
 
+    if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.falcon import (
+            replace_falcon_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching falcon with flash attention")
+        replace_falcon_attn_with_multipack_flash_attn()
+
     if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.qwen2 import (
             replace_qwen2_attn_with_multipack_flash_attn,
@@ -434,18 +442,13 @@ def load_model(
         if not cfg.sample_packing:
             if cfg.s2_attention:
                 pass
-            if (
-                cfg.is_llama_derived_model
-                or cfg.is_falcon_derived_model
-                or cfg.is_mistral_derived_model
-                or model_config.model_type in ["mixtral", "qwen2"]
-            ):
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "flash_attention_2"
-                )
+            # most other models support flash attention, we can define exceptions as they come up
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+            model_config._attn_implementation = (  # pylint: disable=protected-access
+                "flash_attention_2"
+            )
         else:
-            if model_config.model_type in ["mixtral", "qwen2"]:
+            if model_config.model_type in ["mixtral", "qwen2", "falcon"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
@@ -461,7 +464,11 @@ def load_model(
             model_config.fused_dense = True
 
     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            model_config.model_type == "llama"
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+        ):
             from transformers import LlamaForCausalLM
 
             model = LlamaForCausalLM.from_pretrained(
@@ -755,8 +762,10 @@ def find_all_linear_names(model):
         names = name.split(".")
         lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
-    if "lm_head" in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove("lm_head")
+    embedding_modules = get_linear_embedding_layers(model.config.model_type)
+    output_embedding = embedding_modules[1]
+    if output_embedding in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove(output_embedding)
 
     return list(lora_module_names)
 
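
The find_all_linear_names change above is what makes the output-embedding removal model-aware: the layer is no longer assumed to be called "lm_head". A small illustrative check using only the values defined in lora_embeddings.py:

from axolotl.utils.lora_embeddings import get_linear_embedding_layers

for model_type in ("falcon", "gpt_neox"):
    _input_emb, output_emb = get_linear_embedding_layers(model_type)
    print(f"{model_type}: output embedding is {output_emb!r}")
# falcon: output embedding is 'lm_head'
# gpt_neox: output embedding is 'embed_out'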
 
src/axolotl/utils/trainer.py CHANGED
@@ -124,6 +124,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     if eval_dataset:
         eval_dataset = eval_dataset.remove_columns("attention_mask")
 
+    if cfg.model_config_type == "falcon":
+        LOG.info("dropping token_type_ids column")
+        train_dataset = train_dataset.remove_columns("token_type_ids")
+        if eval_dataset:
+            eval_dataset = eval_dataset.remove_columns("token_type_ids")
+
     train_dataset = train_dataset.filter(
         drop_long,
         num_proc=cfg.dataset_processes,
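
Falcon's tokenizer emits a token_type_ids column, which the packing path now drops up front. The removal itself is the standard datasets API; a self-contained toy sketch (the data here is made up, not axolotl's preprocessing pipeline):

from datasets import Dataset

ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3], [4, 5]],
        "token_type_ids": [[0, 0, 0], [0, 0]],
    }
)
ds = ds.remove_columns("token_type_ids")
print(ds.column_names)  # ['input_ids']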
tests/e2e/patched/test_falcon_samplepack.py ADDED
@@ -0,0 +1,112 @@
+"""
+E2E tests for falcon
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestFalconPatched(unittest.TestCase):
+    """
+    Test case for Falcon models
+    """
+
+    @with_temp_dir
+    def test_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "illuin/tiny-random-FalconForCausalLM",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "lora_modules_to_save": ["word_embeddings", "lm_head"],
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "bos_token": "<|endoftext|>",
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "bf16": "auto",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_ft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "illuin/tiny-random-FalconForCausalLM",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "bos_token": "<|endoftext|>",
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "bf16": "auto",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/e2e/patched/test_mixtral_samplepack.py CHANGED
@@ -32,6 +32,7 @@ class TestMixtral(unittest.TestCase):
                 "base_model": "hf-internal-testing/Mixtral-tiny",
                 "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
+                "sample_packing": True,
                 "sequence_len": 2048,
                 "load_in_4bit": True,
                 "adapter": "qlora",
@@ -57,7 +58,6 @@ class TestMixtral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
-                "sample_packing": True,
                 "bf16": "auto",
             }
         )
@@ -76,6 +76,7 @@ class TestMixtral(unittest.TestCase):
                 "base_model": "hf-internal-testing/Mixtral-tiny",
                 "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
+                "sample_packing": True,
                 "sequence_len": 2048,
                 "val_set_size": 0.1,
                 "special_tokens": {},
@@ -95,7 +96,6 @@ class TestMixtral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
-                "sample_packing": True,
                 "bf16": "auto",
             }
         )
tests/e2e/test_falcon.py ADDED
@@ -0,0 +1,166 @@
+"""
+E2E tests for falcon
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestFalcon(unittest.TestCase):
+    """
+    Test case for falcon
+    """
+
+    @with_temp_dir
+    def test_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "illuin/tiny-random-FalconForCausalLM",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "lora_modules_to_save": [
+                    "word_embeddings",
+                    "lm_head",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "bos_token": "<|endoftext|>",
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "bf16": "auto",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_lora_added_vocab(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "illuin/tiny-random-FalconForCausalLM",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "lora_modules_to_save": [
+                    "word_embeddings",
+                    "lm_head",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "bos_token": "<|endoftext|>",
+                    "pad_token": "<|endoftext|>",
+                },
+                "tokens": [
+                    "<|im_start|>",
+                    "<|im_end|>",
+                ],
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "bf16": "auto",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_ft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "illuin/tiny-random-FalconForCausalLM",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "bos_token": "<|endoftext|>",
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "bf16": "auto",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()