winglian committed · Commit 54d2ac1 · Parent: af02430

Mixtral fixes 20240124 (#1192) [skip ci]

* Mixtral NCCL fixes
* make sure to patch for ZeRO-3

README.md CHANGED
@@ -861,7 +861,7 @@ tokens:
 fsdp:
 fsdp_config:
 
-# Deepspeed config path. e.g., deepspeed/zero3.json
+# Deepspeed config path. e.g., deepspeed_configs/zero3.json
 deepspeed:
 
 # Advanced DDP Arguments
@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
 We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
 
 ```yaml
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
 ```
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
 ```
 
 ##### FSDP
{deepspeed β†’ deepspeed_configs}/zero1.json RENAMED
File without changes
{deepspeed β†’ deepspeed_configs}/zero2.json RENAMED
File without changes
{deepspeed β†’ deepspeed_configs}/zero3.json RENAMED
File without changes
{deepspeed β†’ deepspeed_configs}/zero3_bf16.json RENAMED
File without changes
examples/llama-2/fft_optimized.yml CHANGED
@@ -62,7 +62,7 @@ evals_per_epoch: 4
 eval_table_size:
 saves_per_epoch: 1
 debug:
-deepspeed: #deepspeed/zero2.json # multi-gpu only
+deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
 weight_decay: 0.1
 fsdp:
 fsdp_config:
examples/mistral/Mistral-7b-example/code.ipynb CHANGED
@@ -942,7 +942,7 @@
 "not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n",
 "For more information read axolotl's readme\n",
 "\"\"\"\n",
-"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed/zero3_bf16.json"
+"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
 ]
 }
 ],
examples/mistral/Mistral-7b-example/config.yml CHANGED
@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
examples/mistral/README.md CHANGED
@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml
 
 If you run into CUDA OOM, use deepspeed with config zero2.json:
 ```shell
-accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
 ```
examples/mistral/mixtral.yml CHANGED
@@ -84,7 +84,7 @@ eval_table_size:
 eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
-deepspeed: deepspeed/zero2.json
+deepspeed: deepspeed_configs/zero2.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
examples/phi/README.md CHANGED
@@ -3,7 +3,7 @@
 Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json
 
 # OR
 
src/axolotl/monkeypatch/mixtral/__init__.py CHANGED
@@ -1,12 +1,61 @@
 """
 Patches to support multipack for mixtral
 """
+import torch
 import transformers
 
 from axolotl.monkeypatch.utils import get_unpad_data
 
 
-def replace_mixtral_attn_with_multipack_flash_attn():
+def patch_mixtral_moe_forward_zero3() -> None:
+    import torch.nn.functional as F
+
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(
+            routing_weights, self.top_k, dim=-1, sorted=False
+        )
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states) # pylint: disable=invalid-name
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+        y = ( # pylint: disable=invalid-name
+            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
+        ).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+    from transformers.models.mixtral.modeling_mixtral import (
+        MixtralBLockSparseTop2MLP,
+        MixtralSparseMoeBlock,
+    )
+
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
+
+
+def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
     transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
         get_unpad_data
     )
+    if for_zero3:
+        patch_mixtral_moe_forward_zero3()
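The ZeRO-3 branch appears to address the hang behind the "NCCL fixes" in the commit message: the stock Mixtral sparse-MoE forward skips experts that received no tokens on a given rank, but under DeepSpeed ZeRO-3 expert parameters are gathered collectively, so ranks that disagree on which experts to run can deadlock. The deepseek-style `moe_forward` above loops over every expert on every rank, keeping those gathers in step. A minimal sketch of applying the patch by hand outside axolotl, assuming transformers, deepspeed, and axolotl are installed:

```python
# Hedged sketch, not part of this commit: apply the multipack/flash-attention
# patch and opt into the ZeRO-3-safe MoE forward when ZeRO-3 is active.
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

# The patch swaps `forward` on the transformers Mixtral classes themselves,
# so it applies to every MixtralSparseMoeBlock, whether the model object is
# created before or after this call.
replace_mixtral_attn_with_multipack_flash_attn(
    for_zero3=is_deepspeed_zero3_enabled()
)
```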
src/axolotl/train.py CHANGED
@@ -15,7 +15,7 @@ from optimum.bettertransformer import BetterTransformer
 from peft import PeftModel
 from pkg_resources import get_distribution # type: ignore
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
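`is_deepspeed_zero3_enabled` moved from `transformers.deepspeed` to `transformers.integrations.deepspeed` when transformers reorganized its integrations, which is what this import change (and the identical one in src/axolotl/utils/models.py below) tracks. If older transformers releases also had to be supported, a guarded import along these lines would be one option (a hedged sketch, not something this commit does):

```python
# Compatibility sketch (assumption, not part of the commit): prefer the new
# module path and fall back to the pre-reorganization one.
try:
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
    # Older transformers releases exported the helper from transformers.deepspeed.
    from transformers.deepspeed import is_deepspeed_zero3_enabled
```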
src/axolotl/utils/models.py CHANGED
@@ -21,7 +21,7 @@ from transformers import ( # noqa: F401
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -333,7 +333,10 @@ def load_model(
         )
 
         LOG.info("patching mixtral with flash attention")
-        replace_mixtral_attn_with_multipack_flash_attn()
+        mixtral_patch_kwargs = {}
+        if is_deepspeed_zero3_enabled():
+            mixtral_patch_kwargs["for_zero3"] = True
+        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
 
     if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.falcon import (
@@ -646,6 +649,12 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
         skip_prepare_model_for_kbit_training = True
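`set_z3_leaf_modules` marks each `MixtralSparseMoeBlock` as a single leaf for ZeRO-3 parameter partitioning, so DeepSpeed gathers a block's expert weights as one unit instead of hooking each expert's forward separately; together with the MoE forward patch this is presumably the rest of the NCCL fix. A standalone sketch of the same registration for a model loaded outside axolotl; the checkpoint name is an illustrative assumption:

```python
# Hedged sketch, not part of this commit: register Mixtral's MoE blocks as
# ZeRO-3 leaf modules on a plain transformers model.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Treat every sparse-MoE block as one unit for ZeRO-3 gathering so ranks stay
# in sync even when the router activates different experts on different ranks.
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

num_moe_blocks = sum(isinstance(m, MixtralSparseMoeBlock) for m in model.modules())
print(f"marked {num_moe_blocks} MixtralSparseMoeBlock modules as ZeRO-3 leaves")
```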