Maxime Nanobit winglian committed on
Commit
0f6af36
1 Parent(s): 3f69571

Mps mistral lora (#1292) [skip ci]


* Lora example for Mistral on MPS backend

* Add some MPS documentation

* Update examples/mistral/lora-mps.yml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/mistral/lora-mps.yml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update README.md

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>

Files changed (4)
  1. .gitignore +5 -0
  2. README.md +17 -1
  3. docs/mac.md +18 -0
  4. examples/mistral/lora-mps.yml +79 -0
.gitignore CHANGED
@@ -167,3 +167,8 @@ cython_debug/
 # WandB
 # wandb creates a folder to store logs for training runs
 wandb
+
+# Runs
+lora-out/*
+qlora-out/*
+mlruns/*
README.md CHANGED
@@ -99,7 +99,23 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
 
 **Requirements**: Python >=3.9 and Pytorch >=2.1.1.
 
-`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
+### For developers
+```bash
+git clone https://github.com/OpenAccess-AI-Collective/axolotl
+cd axolotl
+
+pip3 install packaging
+```
+
+General case:
+```
+pip3 install -e '.[flash-attn,deepspeed]'
+```
+
+Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
+```
+pip3 install -e '.'
+```
 
 ### Usage
 ```bash
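
The developer install above pairs with the example config added later in this commit. As a minimal sketch of the end-to-end flow on a Mac, assuming the `axolotl.cli.train` entry point from the README's Usage section (not shown in this diff) is available after the editable install:

```bash
# Sketch only: editable Mac install followed by a LoRA run with the config
# added in this commit. Verify the train entry point against your checkout's
# README before relying on it.
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e '.'   # Mac/MPS install path from the hunk above

python -m axolotl.cli.train examples/mistral/lora-mps.yml
```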
docs/mac.md ADDED
@@ -0,0 +1,18 @@
+# Mac M series support
+
+Currently Axolotl on Mac is only partially usable: many of Axolotl's dependencies, including Pytorch, do not support MPS or have incomplete support.
+
+Current support:
+- [x] Support for all models
+- [x] Full training of models
+- [x] LoRA training
+- [x] Sample packing
+- [ ] FP16 and BF16 (awaiting AMP support for MPS in Pytorch)
+- [ ] Tri-dao's flash-attn (use sdp_attention as an alternative until it is supported)
+- [ ] xformers
+- [ ] bitsandbytes (meaning no 4/8-bit loading and no bnb optimizers)
+- [ ] qlora
+- [ ] DeepSpeed
+
+Untested:
+- FSDP
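
The unsupported items above correspond to config keys that a Mac config has to turn off explicitly. A short sketch of the relevant switches, mirroring the values used in examples/mistral/lora-mps.yml from this same commit (key names are taken from that file):

```yaml
# MPS-relevant settings, copied from examples/mistral/lora-mps.yml in this commit.
load_in_8bit: false     # bitsandbytes is unsupported, so no 8-bit loading
load_in_4bit: false     # and no 4-bit (qlora) loading
flash_attention: false  # Tri-dao's flash-attn is unavailable on MPS
sdp_attention: true     # PyTorch scaled-dot-product attention as the alternative
deepspeed:              # left unset; DeepSpeed is unsupported on MPS
```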
examples/mistral/lora-mps.yml ADDED
@@ -0,0 +1,79 @@
+base_model: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0
+output_dir: ./lora-out
+eval_sample_packing: false
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16: false
+tf32: true
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: false
+sdp_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_table_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
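
With gradient_accumulation_steps: 8 and micro_batch_size: 1, the config trains at an effective batch size of 8 per device, and the adapter lands in ./lora-out, which the .gitignore hunk above now excludes. As a hedged sketch of trying the result, assuming the `axolotl.cli.inference` entry point and its `--lora_model_dir` flag (neither is part of this diff):

```bash
# Sketch only: load the base model plus the trained LoRA adapter for a quick
# interactive check. Entry point and flag are assumptions, not part of this commit.
python -m axolotl.cli.inference examples/mistral/lora-mps.yml \
    --lora_model_dir="./lora-out"
```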