Nanobit committed
Commit eb41f76
Parent: 383f88d

Feat: Add example for Mistral (#644)


* Feat: Add example for Mistral

* chore: turn off flash

* chore: add is_mistral_derived_model

* chore: update following PR

README.md CHANGED
```diff
@@ -413,9 +413,10 @@ tokenizer_legacy:
 # this is reported to improve training speed on some models
 resize_token_embeddings_to_32x:
 
-# used to identify if the model is falcon/llama based
+# used to identify which architecture the model is based on
 is_falcon_derived_model:
 is_llama_derived_model:
+is_mistral_derived_model:
 
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
```
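The new `is_mistral_derived_model` flag mirrors the auto-detection added to `src/axolotl/utils/config.py` further down, which keys on the checkpoint's Hugging Face `model_type` or on the repo name. A quick way to see what `model_type` a checkpoint reports (a minimal sketch, assuming the `transformers` package and network access to the Hub):

```python
from transformers import AutoConfig

# Fetch config.json for the base model from the Hugging Face Hub and read its
# model_type field; for mistralai/Mistral-7B-v0.1 this is "mistral", which is
# the value the auto-detection below looks for.
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
print(config.model_type)  # -> "mistral"
```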
examples/mistral/config.yml ADDED
```diff
@@ -0,0 +1,62 @@
+base_model: mistralai/Mistral-7B-v0.1
+base_model_config: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+is_mistral_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./out
+
+sequence_len: 8192
+sample_packing:
+pad_to_sequence_len:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+eval_table_max_new_tokens: 128
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
```
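For orientation, the batch settings in this example give `micro_batch_size × gradient_accumulation_steps` sequences per device per optimizer step. A small arithmetic sketch (values copied from the config above; the GPU count is an assumption for illustration):

```python
# Effective batch size implied by examples/mistral/config.yml.
micro_batch_size = 2               # from the config
gradient_accumulation_steps = 4    # from the config
sequence_len = 8192                # from the config
num_gpus = 2                       # assumption, not part of the config

sequences_per_optimizer_step = micro_batch_size * gradient_accumulation_steps * num_gpus
max_tokens_per_optimizer_step = sequences_per_optimizer_step * sequence_len

print(sequences_per_optimizer_step)   # 16
print(max_tokens_per_optimizer_step)  # 131072 (upper bound at full sequence length)
```

With `sample_packing` left unset, actual token counts per step will typically sit well below that upper bound, since most alpaca-style examples are far shorter than 8192 tokens.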
src/axolotl/utils/config.py CHANGED
```diff
@@ -82,7 +82,7 @@ def normalize_config(cfg):
     cfg.is_llama_derived_model = (
         (hasattr(model_config, "model_type") and model_config.model_type == "llama")
         or cfg.is_llama_derived_model
-        or "llama" in cfg.base_model
+        or "llama" in cfg.base_model.lower()
         or (cfg.model_type and "llama" in cfg.model_type.lower())
     )
 
@@ -98,10 +98,23 @@ def normalize_config(cfg):
             ]
         )
         or cfg.is_falcon_derived_model
-        or "falcon" in cfg.base_model
+        or "falcon" in cfg.base_model.lower()
         or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
     )
 
+    cfg.is_mistral_derived_model = (
+        (
+            hasattr(model_config, "model_type")
+            and model_config.model_type
+            in [
+                "mistral",
+            ]
+        )
+        or cfg.is_mistral_derived_model
+        or "mistral" in cfg.base_model.lower()
+        or (cfg.model_type and "mistral" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
```
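The three `is_*_derived_model` checks all follow the same precedence: the checkpoint's Hugging Face `model_type`, an explicit flag in the config, a case-insensitive substring match on `base_model` (the reason for the added `.lower()` calls), or a substring of the configured `model_type`. A standalone restatement of that logic for the Mistral case, illustrative only and not the axolotl API, using a made-up repo name:

```python
from typing import Optional


def is_mistral_derived(
    base_model: str,
    explicit_flag: Optional[bool] = None,
    hf_model_type: Optional[str] = None,
    cfg_model_type: Optional[str] = None,
) -> bool:
    """Simplified restatement of the detection chain in normalize_config."""
    return bool(
        hf_model_type in ["mistral"]
        or explicit_flag
        or "mistral" in base_model.lower()
        or (cfg_model_type and "mistral" in cfg_model_type.lower())
    )


# Detected from the Hub config's model_type field:
assert is_mistral_derived("some-org/some-model", hf_model_type="mistral")
# Detected from the repo name; .lower() is what lets "Mistral" match "mistral"
# ("my-org/Mistral-7B-sft" is a hypothetical repo name):
assert is_mistral_derived("my-org/Mistral-7B-sft")
# Not flagged for a Llama checkpoint:
assert not is_mistral_derived("meta-llama/Llama-2-7b-hf", cfg_model_type="LlamaForCausalLM")
```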