winglian committed
Commit: f5a828a
1 Parent(s): fccb542

Qwen2 (#1166)

* qwen2 multipack support

* fix qwen derived model check so it doesn't break qwen2

* fixes to ensure qwen2 packing works

* bump requirements for qwen2

* requirements typo

requirements.txt CHANGED
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
-transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
+transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b
+accelerate==0.26.1
 deepspeed
 addict
 fire
src/axolotl/core/trainer_builder.py CHANGED
@@ -905,7 +905,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 ]
             ]
             if use_batch_sampler_collator:
-                if self.cfg.model_config_type == "mixtral":
+                if self.cfg.model_config_type in ["mixtral", "qwen2"]:
                     collator = V2BatchSamplerDataCollatorForSeq2Seq
                 else:
                     collator = BatchSamplerDataCollatorForSeq2Seq
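
For orientation (an aside, not part of the diff): with sample packing enabled, several short examples are concatenated into one batch row, and the batch-sampler collators are assumed here to mark each packed example in attention_mask with an incrementing segment id rather than a plain 0/1 mask, which is what the qwen2 multipack patch below relies on. A minimal sketch of that layout:

# Illustrative only: a hypothetical packed batch of two rows. Each non-zero value
# in attention_mask identifies which packed example a token belongs to; zeros are
# padding. The multipack collators are assumed to emit masks of this shape.
import torch

attention_mask = torch.tensor(
    [
        [1, 1, 1, 2, 2, 2, 2, 0],  # row 0: two packed examples + one pad token
        [1, 1, 2, 2, 2, 3, 3, 3],  # row 1: three packed examples, no padding
    ]
)

# Tokens per packed example, in batch order.
seqlens = torch.cat(
    [torch.unique_consecutive(row[row > 0], return_counts=True)[1] for row in attention_mask]
)
print(seqlens.tolist())  # [3, 4, 2, 3, 3]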
src/axolotl/monkeypatch/qwen2/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""
+Patches to support multipack for qwen2
+"""
+import transformers
+
+from axolotl.monkeypatch.utils import get_unpad_data
+
+
+def replace_qwen2_attn_with_multipack_flash_attn():
+    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
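
The patched helper is axolotl.monkeypatch.utils.get_unpad_data; the sketch below is not that implementation, only an illustration (under the segment-id assumption above) of the triple a multipack-aware _get_unpad_data replacement has to return so that flash attention unpads each packed example as its own sequence:

# Hypothetical sketch, not the axolotl implementation: build the
# (indices, cu_seqlens, max_seqlen_in_batch) triple that flash-attention
# unpadding expects, with one entry per packed example instead of per row.
import torch
import torch.nn.functional as F


def multipack_get_unpad_data(attention_mask: torch.Tensor):
    # Length of each packed example, in batch order.
    seqlens_in_batch = torch.cat(
        [
            torch.unique_consecutive(row[row > 0], return_counts=True)[1]
            for row in attention_mask
        ]
    ).to(torch.int32)
    # Positions of all real (non-padding) tokens in the flattened batch.
    indices = torch.nonzero(attention_mask.flatten() > 0, as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative boundaries, prefixed with 0: [0, len_0, len_0 + len_1, ...]
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

With the attention_mask from the earlier sketch, cu_seqlens would come out as [0, 3, 7, 9, 12, 15].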
src/axolotl/utils/config.py CHANGED
@@ -142,17 +142,12 @@ def normalize_config(cfg):
     )
 
     cfg.is_qwen_derived_model = (
-        (
-            hasattr(model_config, "model_type")
-            and model_config.model_type
-            in [
-                "qwen",
-            ]
-        )
-        or cfg.is_qwen_derived_model
-        or "qwen" in cfg.base_model.lower()
-        or (cfg.model_type and "qwen" in cfg.model_type.lower())
-    )
+        hasattr(model_config, "model_type")
+        and model_config.model_type
+        in [
+            "qwen",
+        ]
+    ) or cfg.is_qwen_derived_model
 
     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
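
An aside on why the check was narrowed: qwen2 model names and model_type both contain the substring "qwen", so the old substring heuristics on base_model and cfg.model_type would have flagged qwen2 checkpoints as qwen-derived and applied qwen-specific handling. Keying only on model_config.model_type == "qwen" (plus the explicit cfg override) leaves qwen2 alone. A hypothetical before/after comparison, with an example checkpoint name:

# Illustrative only; "Qwen/Qwen1.5-7B" is just an example of a qwen2-architecture name.
base_model = "Qwen/Qwen1.5-7B"
model_type_from_config = "qwen2"  # assumed AutoConfig model_type for such checkpoints

# Old heuristic: substring matches also catch qwen2, so it is wrongly treated as qwen.
old_check = (
    model_type_from_config == "qwen"
    or "qwen" in base_model.lower()
    or "qwen" in model_type_from_config.lower()
)

# New heuristic: only an exact model_type match (or the explicit cfg flag, omitted here).
new_check = model_type_from_config in ["qwen"]

print(old_check, new_check)  # True False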
src/axolotl/utils/models.py CHANGED
@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()
 
+    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.qwen2 import (
+            replace_qwen2_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching qwen2 with flash attention")
+        replace_qwen2_attn_with_multipack_flash_attn()
+
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
@@ -426,14 +434,14 @@ def load_model(
         cfg.is_llama_derived_model
         or cfg.is_falcon_derived_model
         or cfg.is_mistral_derived_model
-        or model_config.model_type == "mixtral"
+        or model_config.model_type in ["mixtral", "qwen2"]
     ):
         model_kwargs["attn_implementation"] = "flash_attention_2"
         model_config._attn_implementation = (  # pylint: disable=protected-access
             "flash_attention_2"
         )
     else:
-        if model_config.model_type == "mixtral":
+        if model_config.model_type in ["mixtral", "qwen2"]:
             model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"