winglian committed
Commit fd2c981 • 2 parents: c9a149f 93dacba

Merge branch 'main' into flash-optimum
FAQS.md CHANGED
@@ -2,3 +2,6 @@
 
 - Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
 - Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
+- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
+  `/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
+  This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
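The reinstall suggested in that new FAQ entry is not spelled out in this commit; a minimal sketch, assuming the upstream GitHub source for transformers and that bitsandbytes is rebuilt per its own docs, might look like:

```bash
# Sketch only -- exact bitsandbytes build steps depend on your CUDA setup; see its README.
pip3 uninstall -y transformers bitsandbytes
pip3 install -U git+https://github.com/huggingface/transformers.git
pip3 install -U bitsandbytes  # or build bitsandbytes from source per its docs
```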
README.md CHANGED
@@ -16,13 +16,14 @@
 
 ## Axolotl supports
 
-| | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
-|---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
-| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
+| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
+|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
+| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
+| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
+| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
+| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
+| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
 
 
 ## Quickstart ⚡
@@ -38,10 +39,10 @@ pip3 install -U git+https://github.com/huggingface/peft.git
 accelerate config
 
 # finetune lora
-accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
 
 # inference
-accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
     --inference --lora_model_dir="./lora-out"
 ```
 
@@ -218,6 +219,14 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"conversations": [{"role": "...", "value": "..."}]}
   ```
+- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
+  ```json
+  {"conversations": [{"role": "...", "value": "..."}]}
+  ```
+- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
+  ```json
+  {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
+  ```
 
 </details>
 
@@ -381,6 +390,8 @@ num_epochs: 3
 warmup_steps: 100
 learning_rate: 0.00003
 logging_steps:
+save_steps:
+eval_steps:
 
 # whether to mask out or include the human's prompt from the training labels
 train_on_inputs: false
@@ -497,6 +508,11 @@ Pass the appropriate flag to the train command:
 ```bash
 --inference --base_model ./completed-model
 ```
+- Full weights finetune w/ a prompt from a text file:
+  ```bash
+  cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
+    --base_model ./completed-model --inference --prompter=None --load_in_8bit=True
+  ```
 
 ### Merge LORA to base
 
@@ -524,7 +540,7 @@ Try set `fp16: true`
 
 Try to turn off xformers.
 
-## Need help? 🙋‍♂️
+## Need help? 🙋♂️
 
 Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
 
configs/accelerate/default_config.yaml DELETED
@@ -1,15 +0,0 @@
-compute_environment: LOCAL_MACHINE
-distributed_type: 'NO'
-downcast_bf16: 'no'
-gpu_ids: all
-machine_rank: 0
-main_training_function: main
-mixed_precision: bf16
-num_machines: 1
-num_processes: 1
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
configs/cerebras_1_3B_alpaca.yml DELETED
@@ -1,40 +0,0 @@
1
- base_model: cerebras/Cerebras-GPT-1.3B
2
- model_type: AutoModelForCausalLM
3
- tokenizer_type: AutoTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
- type: alpaca
8
- - path: data/vicuna_cleaned.jsonl
9
- type: sharegpt
10
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
- type: gpteacher
12
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
- type: gpteacher
14
- dataset_prepared_path: last_run_prepared
15
- val_set_size: 0.05
16
- adapter: lora
17
- sequence_len: 2048
18
- lora_r: 8
19
- lora_alpha: 16
20
- lora_dropout: 0.05
21
- lora_target_modules:
22
- - c_attn
23
- lora_fan_in_fan_out: false
24
- wandb_project: pythia-1.4b-lora
25
- wandb_watch:
26
- wandb_run_id:
27
- wandb_log_model:
28
- output_dir: ./lora-alpaca
29
- gradient_accumulation_steps: 1
30
- micro_batch_size: 4
31
- num_epochs: 5
32
- learning_rate: 0.0003
33
- train_on_inputs: false
34
- group_by_length: false
35
- bf16: True
36
- tf32: True
37
- gradient_checkpointing:
38
- early_stopping_patience:
39
- resume_from_checkpoint:
40
- local_rank:
configs/galactica_1_3B.yml DELETED
@@ -1,41 +0,0 @@
1
- base_model: facebook/galactica-1.3b
2
- model_type: AutoModelForCausalLM
3
- tokenizer_type: AutoTokenizer
4
- load_in_8bit: false
5
- datasets:
6
- - path: tatsu-lab/alpaca
7
- type: alpaca
8
- dataset_prepared_path: last_run_prepared
9
- val_set_size: 0.1
10
- adapter:
11
- lora_model_dir:
12
- sequence_len: 1024
13
- max_packed_sequence_len: 1024
14
- lora_r: 8
15
- lora_alpha: 16
16
- lora_dropout: 0.05
17
- lora_target_modules:
18
- - q_proj
19
- - v_proj
20
- lora_fan_in_fan_out: false
21
- wandb_project:
22
- wandb_watch:
23
- wandb_run_id:
24
- wandb_log_model:
25
- output_dir: ./lora-llama-alpaca
26
- gradient_accumulation_steps: 1
27
- micro_batch_size: 16
28
- num_epochs: 3
29
- learning_rate: 0.00003
30
- train_on_inputs: false
31
- group_by_length: false
32
- bf16: false
33
- tf32: false
34
- early_stopping_patience:
35
- resume_from_checkpoint:
36
- local_rank:
37
- tokens:
38
- pad_token: "[PAD]"
39
- bos_token: "<s>"
40
- eos_token: "</s>"
41
- unk_token: "<unk>"
configs/llama_13B_alpaca.yml DELETED
@@ -1,39 +0,0 @@
1
- base_model: huggyllama/llama-13b
2
- model_type: LlamaForCausalLM
3
- tokenizer_type: LlamaTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: anon8231489123/ShareGPT_Vicuna_unfiltered
7
- data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
8
- type: sharegpt
9
- dataset_prepared_path: last_run_prepared
10
- val_set_size: 0.002
11
- adapter:
12
- lora_model_dir:
13
- sequence_len: 2048
14
- lora_r: 8
15
- lora_alpha: 16
16
- lora_dropout: 0.05
17
- lora_target_modules:
18
- - q_proj
19
- - v_proj
20
- lora_fan_in_fan_out: false
21
- wandb_project:
22
- wandb_watch:
23
- wandb_run_id:
24
- wandb_log_model:
25
- output_dir: ./llama-13b-sharegpt
26
- gradient_accumulation_steps: 1
27
- micro_batch_size: 2
28
- warmup_steps: 1000
29
- save_steps:
30
- eval_steps:
31
- num_epochs: 5
32
- learning_rate: 0.00003
33
- train_on_inputs: false
34
- group_by_length: false
35
- bf16: true
36
- tf32: true
37
- early_stopping_patience: 5
38
- resume_from_checkpoint:
39
- local_rank:
configs/llama_65B_alpaca.yml DELETED
@@ -1,44 +0,0 @@
1
- base_model: huggyllama/llama-65b
2
- model_type: LlamaForCausalLM
3
- tokenizer_type: LlamaTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
- type: alpaca
8
- - path: anon8231489123/ShareGPT_Vicuna_unfiltered
9
- data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
10
- type: sharegpt
11
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
12
- type: gpteacher
13
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
14
- type: gpteacher
15
- dataset_prepared_path: last_run_prepared
16
- val_set_size: 0.04
17
- adapter: lora
18
- lora_model_dir:
19
- sequence_len: 2048
20
- lora_r: 8
21
- lora_alpha: 16
22
- lora_dropout: 0.05
23
- lora_target_modules:
24
- - q_proj
25
- - v_proj
26
- lora_fan_in_fan_out: false
27
- wandb_project: llama-65b-lora
28
- wandb_watch:
29
- wandb_run_id:
30
- wandb_log_model:
31
- output_dir: ./lora-llama-alpaca
32
- gradient_accumulation_steps: 1
33
- micro_batch_size: 16
34
- warmup_steps: 1000
35
- save_steps:
36
- num_epochs: 5
37
- learning_rate: 0.00003
38
- train_on_inputs: false
39
- group_by_length: false
40
- bf16: true
41
- tf32: true
42
- early_stopping_patience:
43
- resume_from_checkpoint:
44
- local_rank:
configs/llama_7B_4bit.yml DELETED
@@ -1,45 +0,0 @@
1
- base_model: decapoda-research/llama-7b-hf-int4
2
- base_model_config: decapoda-research/llama-7b-hf
3
- model_type: LlamaForCausalLM
4
- tokenizer_type: LlamaTokenizer
5
- load_in_8bit: true
6
- datasets:
7
- - path: tatsu-lab/alpaca # original alpaca dataset
8
- type: alpaca
9
- dataset_prepared_path: data/last_run_prepared
10
- val_set_size: 0.04
11
- adapter: lora
12
- lora_model_dir:
13
- sequence_len: 2048
14
- max_packed_sequence_len: 1024
15
- lora_r: 8
16
- lora_alpha: 16
17
- lora_dropout: 0.05
18
- lora_target_modules:
19
- - q_proj
20
- - v_proj
21
- # - k_proj
22
- # - o_proj
23
- lora_fan_in_fan_out: false
24
- wandb_project:
25
- wandb_watch:
26
- wandb_run_id:
27
- wandb_log_model:
28
- output_dir: ./lora-test
29
- gradient_accumulation_steps: 1
30
- micro_batch_size: 2
31
- num_epochs: 3
32
- warmup_steps: 100
33
- learning_rate: 0.00003
34
- train_on_inputs: false
35
- group_by_length: false
36
- bf16: true
37
- tf32: true
38
- gradient_checkpointing: false
39
- early_stopping_patience: 3
40
- resume_from_checkpoint:
41
- auto_resume_from_checkpoints: true
42
- local_rank:
43
- load_4bit: true
44
- xformers_attention: true
45
- flash_attention:
configs/llama_7B_alpaca.yml DELETED
@@ -1,41 +0,0 @@
1
- base_model: huggyllama/llama-7b
2
- model_type: LlamaForCausalLM
3
- tokenizer_type: LlamaTokenizer
4
- load_in_8bit: true
5
- datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
- type: alpaca
8
- - path: data/vicuna_cleaned.jsonl
9
- type: sharegpt
10
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
- type: gpteacher
12
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
- type: gpteacher
14
- dataset_prepared_path: last_run_prepared
15
- val_set_size: 0.04
16
- adapter: lora
17
- lora_model_dir:
18
- sequence_len: 2048
19
- lora_r: 8
20
- lora_alpha: 16
21
- lora_dropout: 0.05
22
- lora_target_modules:
23
- - q_proj
24
- - v_proj
25
- lora_fan_in_fan_out: false
26
- wandb_project: llama-7b-lora
27
- wandb_watch:
28
- wandb_run_id:
29
- wandb_log_model:
30
- output_dir: ./lora-llama-alpaca
31
- gradient_accumulation_steps: 1
32
- micro_batch_size: 16
33
- num_epochs: 5
34
- learning_rate: 0.00003
35
- train_on_inputs: false
36
- group_by_length: false
37
- bf16: true
38
- tf32: true
39
- early_stopping_patience:
40
- resume_from_checkpoint:
41
- local_rank:
configs/quickstart.yml DELETED
@@ -1,45 +0,0 @@
1
- base_model: decapoda-research/llama-7b-hf-int4
2
- base_model_config: decapoda-research/llama-7b-hf
3
- model_type: LlamaForCausalLM
4
- tokenizer_type: LlamaTokenizer
5
- load_in_8bit: true
6
- datasets:
7
- - path: tatsu-lab/alpaca # original alpaca dataset
8
- type: alpaca
9
- dataset_prepared_path: data/last_run_prepared
10
- val_set_size: 0.04
11
- adapter: lora
12
- lora_model_dir:
13
- sequence_len: 1024
14
- max_packed_sequence_len: 1024
15
- lora_r: 8
16
- lora_alpha: 16
17
- lora_dropout: 0.05
18
- lora_target_modules:
19
- - q_proj
20
- - v_proj
21
- # - k_proj
22
- # - o_proj
23
- lora_fan_in_fan_out: false
24
- wandb_project:
25
- wandb_watch:
26
- wandb_run_id:
27
- wandb_log_model:
28
- output_dir: ./lora-test
29
- gradient_accumulation_steps: 1
30
- micro_batch_size: 1
31
- num_epochs: 3
32
- warmup_steps: 100
33
- learning_rate: 0.00003
34
- train_on_inputs: false
35
- group_by_length: false
36
- bf16: true
37
- tf32: true
38
- gradient_checkpointing: false
39
- early_stopping_patience: 3
40
- resume_from_checkpoint:
41
- auto_resume_from_checkpoints: true
42
- local_rank:
43
- gptq: true
44
- xformers_attention: true
45
- flash_attention:
configs/sample.yml DELETED
@@ -1,87 +0,0 @@
1
- # this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
2
- # this can also be a relative path to a model on disk
3
- base_model: decapoda-research/llama-7b-hf-int4
4
- # you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
5
- base_model_ignore_patterns:
6
- # if the base_model repo on hf hub doesn't include configuration .json files,
7
- # you can set that here, or leave this empty to default to base_model
8
- base_model_config: decapoda-research/llama-7b-hf
9
- # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
10
- model_type: AutoModelForCausalLM
11
- # Corresponding tokenizer for the model AutoTokenizer is a good choice
12
- tokenizer_type: AutoTokenizer
13
- # whether you are training a 4-bit quantized model
14
- load_4bit: true
15
- # this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
16
- load_in_8bit: true
17
- # a list of one or more datasets to finetune the model with
18
- datasets:
19
- # this can be either a hf dataset, or relative path
20
- - path: vicgalle/alpaca-gpt4
21
- # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
22
- type: alpaca
23
- # axolotl attempts to save the dataset as an arrow after packing the data together so
24
- # subsequent training attempts load faster, relative path
25
- dataset_prepared_path: data/last_run_prepared
26
- # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
27
- val_set_size: 0.04
28
- # if you want to use lora, leave blank to train all parameters in original model
29
- adapter: lora
30
- # if you already have a lora model trained that you want to load, put that here
31
- lora_model_dir:
32
- # the maximum length of an input to train with, this should typically be less than 2048
33
- # as most models have a token/context limit of 2048
34
- sequence_len: 2048
35
- # max sequence length to concatenate training samples together up to
36
- # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
37
- max_packed_sequence_len: 1024
38
- # lora hyperparameters
39
- lora_r: 8
40
- lora_alpha: 16
41
- lora_dropout: 0.05
42
- lora_target_modules:
43
- - q_proj
44
- - v_proj
45
- # - k_proj
46
- # - o_proj
47
- lora_fan_in_fan_out: false
48
- # wandb configuration if your're using it
49
- wandb_project:
50
- wandb_watch:
51
- wandb_run_id:
52
- wandb_log_model:
53
- # where to save the finsihed model to
54
- output_dir: ./completed-model
55
- # training hyperparameters
56
- gradient_accumulation_steps: 1
57
- batch_size:
58
- micro_batch_size: 2
59
- num_epochs: 3
60
- warmup_steps: 100
61
- learning_rate: 0.00003
62
- # whether to mask out or include the human's prompt from the training labels
63
- train_on_inputs: false
64
- # don't use this, leads to wonky training (according to someone on the internet)
65
- group_by_length: false
66
- # Use CUDA bf16
67
- bf16: true
68
- # Use CUDA tf32
69
- tf32: true
70
- # does not work with current implementation of 4-bit LoRA
71
- gradient_checkpointing: false
72
- # stop training after this many evaluation losses have increased in a row
73
- # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
74
- early_stopping_patience: 3
75
- # specify a scheduler to use with the optimizer. only one_cycle is supported currently
76
- lr_scheduler:
77
- # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
78
- xformers_attention:
79
- # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
80
- flash_attention:
81
- # resume from a specific checkpoint dir
82
- resume_from_checkpoint:
83
- # if resume_from_checkpoint isn't set and you simply want it to start where it left off
84
- # be careful with this being turned on between different models
85
- auto_resume_from_checkpoints: false
86
- # don't mess with this, it's here for accelerate and torchrun
87
- local_rank:
configs/vicuna_13B_4bit_reflect.yml DELETED
@@ -1,45 +0,0 @@
1
- base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
2
- base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
3
- model_type: LlamaForCausalLM
4
- tokenizer_type: LlamaTokenizer
5
- load_in_8bit: false
6
- load_4bit: true
7
- gptq_groupsize: 128
8
- gptq_model_v1: false
9
- datasets:
10
- # https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
11
- - path: data/alpaca_reflect_pruned.jsonl
12
- type: reflection
13
- dataset_prepared_path: data/last_run_prepared
14
- val_set_size: 0.04
15
- adapter: lora
16
- lora_model_dir:
17
- sequence_len: 2048
18
- max_packed_sequence_len: 2048
19
- lora_r: 8
20
- lora_alpha: 16
21
- lora_dropout: 0.05
22
- lora_target_modules:
23
- - q_proj
24
- - v_proj
25
- # - k_proj
26
- # - o_proj
27
- lora_fan_in_fan_out: false
28
- wandb_project:
29
- wandb_watch:
30
- wandb_run_id:
31
- wandb_log_model:
32
- output_dir: ./lora-reflect
33
- gradient_accumulation_steps: 1
34
- micro_batch_size: 2
35
- num_epochs: 3
36
- learning_rate: 0.00003
37
- train_on_inputs: false
38
- group_by_length: false
39
- bf16: true
40
- tf32: true
41
- gradient_checkpointing: false
42
- early_stopping_patience: 3
43
- resume_from_checkpoint:
44
- local_rank:
45
- flash_attention: true
examples/cerebras/qlora.yml ADDED
@@ -0,0 +1,60 @@
1
+ base_model: cerebras/Cerebras-GPT-1.3B
2
+ base_model_config: cerebras/Cerebras-GPT-1.3B
3
+ load_in_8bit: false
4
+ load_in_4bit: true
5
+ strict: false
6
+ push_dataset_to_hub:
7
+ datasets:
8
+ - path: teknium/GPT4-LLM-Cleaned
9
+ type: alpaca
10
+ dataset_prepared_path: last_run_prepared
11
+ val_set_size: 0.01
12
+ adapter: qlora
13
+ lora_model_dir:
14
+ sequence_len: 2048
15
+ max_packed_sequence_len: 2048
16
+ lora_r: 16
17
+ lora_alpha: 32
18
+ lora_dropout: 0.05
19
+ lora_target_modules:
20
+ - c_fc
21
+ - c_attn
22
+ - c_proj
23
+ lora_target_linear:
24
+ lora_fan_in_fan_out:
25
+ wandb_project:
26
+ wandb_watch:
27
+ wandb_run_id:
28
+ wandb_log_model:
29
+ output_dir: ./qlora-out
30
+ batch_size: 4
31
+ micro_batch_size: 4
32
+ num_epochs: 2
33
+ optimizer: paged_adamw_8bit
34
+ torchdistx_path:
35
+ lr_scheduler: cosine
36
+ learning_rate: 0.0002
37
+ train_on_inputs: false
38
+ group_by_length: true
39
+ bf16: true
40
+ fp16: false
41
+ tf32: true
42
+ gradient_checkpointing: true
43
+ early_stopping_patience:
44
+ resume_from_checkpoint:
45
+ local_rank:
46
+ logging_steps: 1
47
+ xformers_attention: true
48
+ flash_attention:
49
+ gptq_groupsize:
50
+ gptq_model_v1:
51
+ warmup_steps: 10
52
+ eval_steps: 20
53
+ save_steps:
54
+ debug:
55
+ deepspeed:
56
+ weight_decay: 0.1
57
+ fsdp:
58
+ fsdp_config:
59
+ special_tokens:
60
+ pad_token: "<|endoftext|>"
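Presumably this new Cerebras QLoRA example is launched like the other example configs in this commit (a sketch following the README's quickstart pattern; not a command present in this diff):

```shell
# Assumed invocation, mirroring the launch pattern used for the other example configs.
accelerate launch scripts/finetune.py examples/cerebras/qlora.yml
```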
examples/falcon/config-7b-lora.yml CHANGED
@@ -23,7 +23,7 @@ lora_dropout: 0.0
 lora_target_modules:
 lora_target_linear: true
 lora_fan_in_fan_out:
-wandb_project: falcon-7b
+wandb_project:
 wandb_watch:
 wandb_run_id:
 wandb_log_model:
examples/falcon/config-7b.yml CHANGED
@@ -23,7 +23,7 @@ lora_dropout: 0.0
 lora_target_modules:
 lora_target_linear: true
 lora_fan_in_fan_out:
-wandb_project: falcon-7b
+wandb_project:
 wandb_watch:
 wandb_run_id:
 wandb_log_model:
configs/stability_3b.yml β†’ examples/gptj/qlora.yml RENAMED
@@ -1,38 +1,42 @@
1
- base_model: stabilityai/stablelm-base-alpha-3b
2
- base_model_config: stabilityai/stablelm-base-alpha-3b
3
  load_in_8bit: false
 
 
 
4
  datasets:
5
- - path: vicgalle/alpaca-gpt4
6
  type: alpaca
7
  dataset_prepared_path: last_run_prepared
8
- val_set_size: 0.04
9
- adapter:
10
  lora_model_dir:
11
- sequence_len: 4096
12
- max_packed_sequence_len: 4096
13
  lora_r: 8
14
- lora_alpha: 16
15
  lora_dropout: 0.05
16
  lora_target_modules:
17
- - q_proj
18
- - v_proj
19
- lora_fan_in_fan_out: false
20
- wandb_project: stable-alpaca-3b
21
  wandb_watch:
22
  wandb_run_id:
23
  wandb_log_model:
24
- output_dir: ./stable-alpaca-3b
25
- gradient_accumulation_steps: 1
26
- micro_batch_size: 1
27
- num_epochs: 1
28
- optimizer: adamw_bnb_8bit
29
  torchdistx_path:
30
  lr_scheduler: cosine
31
- learning_rate: 0.0000002
32
  train_on_inputs: false
33
- group_by_length: false
34
  bf16: true
 
35
  tf32: true
 
36
  early_stopping_patience:
37
  resume_from_checkpoint:
38
  local_rank:
@@ -41,16 +45,13 @@ xformers_attention: true
41
  flash_attention:
42
  gptq_groupsize:
43
  gptq_model_v1:
44
- warmup_steps: 100
45
- eval_steps: 50
46
- save_steps: 200
47
  debug:
48
  deepspeed:
49
- weight_decay: 0.01
50
  fsdp:
51
  fsdp_config:
52
- #tokens:
53
- # pad_token: "[PAD]"
54
- # bos_token: "<s>"
55
- # eos_token: "</s>"
56
- # unk_token: "<unk>"
 
1
+ base_model: EleutherAI/gpt-j-6b
2
+ base_model_config: EleutherAI/gpt-j-6b
3
  load_in_8bit: false
4
+ load_in_4bit: true
5
+ strict: false
6
+ push_dataset_to_hub:
7
  datasets:
8
+ - path: teknium/GPT4-LLM-Cleaned
9
  type: alpaca
10
  dataset_prepared_path: last_run_prepared
11
+ val_set_size: 0.01
12
+ adapter: qlora
13
  lora_model_dir:
14
+ sequence_len: 2048
15
+ max_packed_sequence_len:
16
  lora_r: 8
17
+ lora_alpha: 32
18
  lora_dropout: 0.05
19
  lora_target_modules:
20
+ lora_target_linear: true
21
+ lora_fan_in_fan_out:
22
+ wandb_project:
 
23
  wandb_watch:
24
  wandb_run_id:
25
  wandb_log_model:
26
+ output_dir: ./qlora-out
27
+ gradient_accumulation_steps: 2
28
+ micro_batch_size: 2
29
+ num_epochs: 2
30
+ optimizer: paged_adamw_8bit
31
  torchdistx_path:
32
  lr_scheduler: cosine
33
+ learning_rate: 0.0001
34
  train_on_inputs: false
35
+ group_by_length: true
36
  bf16: true
37
+ fp16: false
38
  tf32: true
39
+ gradient_checkpointing: true
40
  early_stopping_patience:
41
  resume_from_checkpoint:
42
  local_rank:
 
45
  flash_attention:
46
  gptq_groupsize:
47
  gptq_model_v1:
48
+ warmup_steps: 10
49
+ eval_steps: 20
50
+ save_steps:
51
  debug:
52
  deepspeed:
53
+ weight_decay: 0.1
54
  fsdp:
55
  fsdp_config:
56
+ special_tokens:
57
+ pad_token: "<|endoftext|>"
 
 
 
examples/gptq-lora-7b/README.md CHANGED
@@ -3,6 +3,6 @@
 This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
 
 ```shell
-accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
+accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
 
 ```
configs/llama_7B_jeopardy.yml β†’ examples/jeopardy-bot/config.yml RENAMED
@@ -7,30 +7,28 @@ datasets:
7
  - path: openaccess-ai-collective/jeopardy
8
  type: jeopardy
9
  dataset_prepared_path: last_run_prepared
10
- val_set_size: 0.01
11
  adapter:
12
  lora_model_dir:
13
- sequence_len: 2048
14
- max_packed_sequence_len: 2048
15
- lora_r: 8
16
- lora_alpha: 16
17
- lora_dropout: 0.05
18
  lora_target_modules:
19
- - q_proj
20
- - v_proj
21
  lora_fan_in_fan_out: false
22
- wandb_project: jeopardy-bot-7b
23
  wandb_watch:
24
  wandb_run_id:
25
  wandb_log_model:
26
  output_dir: ./jeopardy-bot-7b
27
- gradient_accumulation_steps: 2
28
  micro_batch_size: 1
29
- num_epochs: 2
30
  optimizer: adamw_bnb_8bit
31
  torchdistx_path:
32
  lr_scheduler: cosine
33
- learning_rate: 0.0000002
34
  train_on_inputs: false
35
  group_by_length: false
36
  bf16: true
@@ -48,11 +46,10 @@ eval_steps: 110
48
  save_steps: 660
49
  debug:
50
  deepspeed:
51
- weight_decay: 0.0001
52
  fsdp:
53
  fsdp_config:
54
  tokens:
55
- pad_token: "[PAD]"
56
  bos_token: "<s>"
57
  eos_token: "</s>"
58
  unk_token: "<unk>"
 
7
  - path: openaccess-ai-collective/jeopardy
8
  type: jeopardy
9
  dataset_prepared_path: last_run_prepared
10
+ val_set_size: 0.02
11
  adapter:
12
  lora_model_dir:
13
+ sequence_len: 512
14
+ max_packed_sequence_len:
15
+ lora_r:
16
+ lora_alpha:
17
+ lora_dropout:
18
  lora_target_modules:
 
 
19
  lora_fan_in_fan_out: false
20
+ wandb_project:
21
  wandb_watch:
22
  wandb_run_id:
23
  wandb_log_model:
24
  output_dir: ./jeopardy-bot-7b
25
+ gradient_accumulation_steps: 1
26
  micro_batch_size: 1
27
+ num_epochs: 3
28
  optimizer: adamw_bnb_8bit
29
  torchdistx_path:
30
  lr_scheduler: cosine
31
+ learning_rate: 0.00003
32
  train_on_inputs: false
33
  group_by_length: false
34
  bf16: true
 
46
  save_steps: 660
47
  debug:
48
  deepspeed:
49
+ weight_decay: 0.1
50
  fsdp:
51
  fsdp_config:
52
  tokens:
 
53
  bos_token: "<s>"
54
  eos_token: "</s>"
55
  unk_token: "<unk>"
examples/openllama-3b/README.md ADDED
@@ -0,0 +1,16 @@
+# openllama-3b
+
+Basic full tune
+```shell
+accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
+```
+
+LoRA
+```shell
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
+```
+
+QLoRA
+```shell
+accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
+```
examples/openllama-3b/config.yml ADDED
@@ -0,0 +1,61 @@
1
+ base_model: openlm-research/open_llama_3b
2
+ base_model_config: openlm-research/open_llama_3b
3
+ model_type: LlamaForCausalLM
4
+ tokenizer_type: LlamaTokenizer
5
+ load_in_8bit: false
6
+ load_in_4bit: false
7
+ strict: false
8
+ push_dataset_to_hub:
9
+ datasets:
10
+ - path: teknium/GPT4-LLM-Cleaned
11
+ type: alpaca
12
+ dataset_prepared_path: last_run_prepared
13
+ val_set_size: 0.02
14
+ adapter:
15
+ lora_model_dir:
16
+ sequence_len: 256
17
+ max_packed_sequence_len:
18
+ lora_r:
19
+ lora_alpha:
20
+ lora_dropout:
21
+ lora_target_modules:
22
+ lora_target_linear:
23
+ lora_fan_in_fan_out:
24
+ wandb_project:
25
+ wandb_watch:
26
+ wandb_run_id:
27
+ wandb_log_model:
28
+ output_dir: ./openllama-out
29
+ batch_size: 16
30
+ micro_batch_size: 4
31
+ num_epochs: 3
32
+ optimizer: adamw_bnb_8bit
33
+ torchdistx_path:
34
+ lr_scheduler: cosine
35
+ learning_rate: 0.0002
36
+ train_on_inputs: false
37
+ group_by_length: false
38
+ bf16: false
39
+ fp16: true
40
+ tf32: false
41
+ gradient_checkpointing: true
42
+ early_stopping_patience:
43
+ resume_from_checkpoint:
44
+ local_rank:
45
+ logging_steps: 1
46
+ xformers_attention: true
47
+ flash_attention:
48
+ gptq_groupsize:
49
+ gptq_model_v1:
50
+ warmup_steps: 10
51
+ eval_steps: 50
52
+ save_steps:
53
+ debug:
54
+ deepspeed:
55
+ weight_decay: 0.0
56
+ fsdp:
57
+ fsdp_config:
58
+ special_tokens:
59
+ bos_token: "<s>"
60
+ eos_token: "</s>"
61
+ unk_token: "<unk>"
examples/{lora-openllama-3b/config.yml → openllama-3b/lora.yml} RENAMED
@@ -1,5 +1,5 @@
-base_model: openlm-research/open_llama_3b_600bt_preview
-base_model_config: openlm-research/open_llama_3b_600bt_preview
+base_model: openlm-research/open_llama_3b
+base_model_config: openlm-research/open_llama_3b
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: true
@@ -49,7 +49,7 @@ early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
 logging_steps: 1
-xformers_attention:
+xformers_attention: true
 flash_attention:
 gptq_groupsize:
 gptq_model_v1:
examples/{qlora-openllama-3b/config.yml → openllama-3b/qlora.yml} RENAMED
@@ -1,5 +1,5 @@
-base_model: openlm-research/open_llama_3b_600bt_preview
-base_model_config: openlm-research/open_llama_3b_600bt_preview
+base_model: openlm-research/open_llama_3b
+base_model_config: openlm-research/open_llama_3b
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: false
configs/pythia_1_2B_alpaca.yml β†’ examples/pythia/lora.yml RENAMED
@@ -1,36 +1,29 @@
1
  base_model: EleutherAI/pythia-1.4b-deduped
2
- model_type: GPTNeoXForCausalLM
3
- tokenizer_type: AutoTokenizer
4
  load_in_8bit: true
5
  datasets:
6
- - path: data/alpaca_data_gpt4.jsonl
7
  type: alpaca
8
- - path: data/vicuna_cleaned.jsonl
9
- type: sharegpt
10
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
- type: gpteacher
12
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
- type: gpteacher
14
  dataset_prepared_path: last_run_prepared
15
  val_set_size: 0.05
16
  adapter: lora
17
  lora_model_dir:
18
- sequence_len: 2048
19
- lora_r: 8
20
  lora_alpha: 32
21
  lora_dropout: 0.05
22
  lora_target_modules:
23
  - query_key_value
24
- # - xxx
25
  lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
26
- wandb_project: pythia-1.4b-lora
27
  wandb_watch:
28
  wandb_run_id:
29
  wandb_log_model:
30
- output_dir: ./lora-alpaca
31
  gradient_accumulation_steps: 1
32
  micro_batch_size: 4
33
- num_epochs: 5
34
  learning_rate: 0.00001
35
  train_on_inputs: false
36
  group_by_length: false
@@ -39,3 +32,6 @@ tf32: True
39
  early_stopping_patience:
40
  resume_from_checkpoint:
41
  local_rank:
1
  base_model: EleutherAI/pythia-1.4b-deduped
2
+ base_model_config: EleutherAI/pythia-1.4b-deduped
 
3
  load_in_8bit: true
4
  datasets:
5
+ - path: teknium/GPT4-LLM-Cleaned
6
  type: alpaca
 
  dataset_prepared_path: last_run_prepared
8
  val_set_size: 0.05
9
  adapter: lora
10
  lora_model_dir:
11
+ sequence_len: 512
12
+ lora_r: 16
13
  lora_alpha: 32
14
  lora_dropout: 0.05
15
  lora_target_modules:
16
  - query_key_value
17
+ lora_target_linear:
18
  lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
19
+ wandb_project:
20
  wandb_watch:
21
  wandb_run_id:
22
  wandb_log_model:
23
+ output_dir: ./lora-alpaca-pythia
24
  gradient_accumulation_steps: 1
25
  micro_batch_size: 4
26
+ num_epochs: 3
27
  learning_rate: 0.00001
28
  train_on_inputs: false
29
  group_by_length: false
 
32
  early_stopping_patience:
33
  resume_from_checkpoint:
34
  local_rank:
35
+ weight_decay: 0.1
36
+ eval_steps: 20
37
+ logging_steps: 1
examples/qlora-openllama-3b/README.md DELETED
@@ -1,6 +0,0 @@
-# qlora-openllama-3b
-
-```shell
-accelerate launch scripts/finetune.py examples/qlora-openllama-3b/config.yml
-
-```
scripts/finetune.py CHANGED
@@ -72,7 +72,19 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         if not (cfg.special_tokens and token in cfg.special_tokens):
             tokenizer.add_special_tokens({token: symbol})
 
-    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
+    prompter_module = None
+    if prompter:
+        prompter_module = getattr(
+            importlib.import_module("axolotl.prompters"), prompter
+        )
+
+    if cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+        set_model_mem_id(model, tokenizer)
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
 
     while True:
         print("=" * 80)
@@ -80,10 +92,14 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt: str = next(
-            prompter_module().build_prompt(instruction=instruction.strip("\n"))
-        )
+        if prompter_module:
+            prompt: str = next(
+                prompter_module().build_prompt(instruction=instruction.strip("\n"))
+            )
+        else:
+            prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
         print("=" * 40)
         model.eval()
         with torch.no_grad():
@@ -159,7 +175,7 @@ def train(
     cfg_keys = cfg.keys()
     for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or cfg.strict is False:
+        if k in cfg_keys or not cfg.strict:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
@@ -199,8 +215,8 @@ def train(
     logging.info(f"loading tokenizer... {tokenizer_config}")
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
-    if check_not_in(
-        ["inference", "shard", "merge_lora"], kwargs
+    if (
+        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ): # don't need to load dataset for these
         if not cfg.pretraining_dataset:
             train_dataset, eval_dataset = load_prepare_datasets(
@@ -239,7 +255,6 @@ def train(
         tokenizer,
         cfg,
         adapter=cfg.adapter,
-        inference=("inference" in kwargs),
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -252,9 +267,15 @@ def train(
         model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
-    if "inference" in kwargs:
+    if cfg.inference:
         logging.info("calling do_inference function")
-        do_inference(cfg, model, tokenizer)
+        inf_kwargs: Dict[str, Any] = {}
+        if "prompter" in kwargs:
+            if kwargs["prompter"] == "None":
+                inf_kwargs["prompter"] = None
+            else:
+                inf_kwargs["prompter"] = kwargs["prompter"]
+        do_inference(cfg, model, tokenizer, **inf_kwargs)
         return
 
     if "shard" in kwargs:
src/axolotl/datasets.py CHANGED
@@ -33,12 +33,16 @@ class TokenizedPromptDataset(IterableDataset):
 
     def __iter__(self):
         iterator = iter(self.dataset)
+        count = 0
         # Loop through the entire dataset
        for example in iterator:
             try:
                 yield self.prompt_tokenizer.tokenize_prompt(example)
+                count += 1
             except InvalidDataException:
                 pass
+        if count == 0:
+            raise RuntimeError("Expected at least one datapoint in dataset.")
 
 
 # TODO this isn't the best since it can't interleave datasets
src/axolotl/monkeypatch/llama_landmark_attn.py CHANGED
@@ -28,15 +28,24 @@ from typing import List, Optional, Tuple, Union
28
  import torch
29
  import torch.utils.checkpoint
30
  from torch import nn
31
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
- from transformers.activations import ACT2FN
33
  from transformers.modeling_outputs import (
34
  BaseModelOutputWithPast,
35
  CausalLMOutputWithPast,
36
- SequenceClassifierOutputWithPast,
37
  )
38
- from transformers.modeling_utils import PreTrainedModel
39
  from transformers.models.llama.configuration_llama import LlamaConfig
 
40
  from transformers.utils import (
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
@@ -51,131 +60,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
51
  MEM_TOKEN = "<landmark>" # nosec
52
 
53
 
54
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
55
- def _make_causal_mask(
56
- input_ids_shape: torch.Size,
57
- dtype: torch.dtype,
58
- device: torch.device,
59
- past_key_values_length: int = 0,
60
- ):
61
- """
62
- Make causal mask used for bi-directional self-attention.
63
- """
64
- bsz, tgt_len = input_ids_shape
65
- mask = torch.full(
66
- (tgt_len, tgt_len),
67
- torch.tensor(torch.finfo(dtype).min, device=device),
68
- device=device,
69
- )
70
- mask_cond = torch.arange(mask.size(-1), device=device)
71
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
72
- mask = mask.to(dtype)
73
-
74
- if past_key_values_length > 0:
75
- mask = torch.cat(
76
- [
77
- torch.zeros(
78
- tgt_len, past_key_values_length, dtype=dtype, device=device
79
- ),
80
- mask,
81
- ],
82
- dim=-1,
83
- )
84
- return mask[None, None, :, :].expand(
85
- bsz, 1, tgt_len, tgt_len + past_key_values_length
86
- )
87
-
88
-
89
- # Copied from transformers.models.bart.modeling_bart._expand_mask
90
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
91
- """
92
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
93
- """
94
- bsz, src_len = mask.size()
95
- tgt_len = tgt_len if tgt_len is not None else src_len
96
-
97
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
98
-
99
- inverted_mask = 1.0 - expanded_mask
100
-
101
- return inverted_mask.masked_fill(
102
- inverted_mask.to(torch.bool), torch.finfo(dtype).min
103
- )
104
-
105
-
106
- class LlamaRMSNorm(nn.Module):
107
- def __init__(self, hidden_size, eps=1e-6):
108
- """
109
- LlamaRMSNorm is equivalent to T5LayerNorm
110
- """
111
- super().__init__()
112
- self.weight = nn.Parameter(torch.ones(hidden_size))
113
- self.variance_epsilon = eps
114
-
115
- def forward(self, hidden_states):
116
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
117
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
118
-
119
- # convert into half-precision if necessary
120
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
121
- hidden_states = hidden_states.to(self.weight.dtype)
122
-
123
- return self.weight * hidden_states
124
-
125
-
126
- class LlamaRotaryEmbedding(torch.nn.Module):
127
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
128
- super().__init__()
129
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
130
- self.register_buffer("inv_freq", inv_freq)
131
-
132
- # Build here to make `torch.jit.trace` work.
133
- self.max_seq_len_cached = max_position_embeddings
134
- t = torch.arange(
135
- self.max_seq_len_cached,
136
- device=self.inv_freq.device,
137
- dtype=self.inv_freq.dtype,
138
- )
139
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
140
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
141
- emb = torch.cat((freqs, freqs), dim=-1)
142
- self.register_buffer(
143
- "cos_cached", emb.cos()[None, None, :, :], persistent=False
144
- )
145
- self.register_buffer(
146
- "sin_cached", emb.sin()[None, None, :, :], persistent=False
147
- )
148
-
149
- def forward(self, x, seq_len=None):
150
- # x: [bs, num_attention_heads, seq_len, head_size]
151
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
152
- if seq_len > self.max_seq_len_cached:
153
- self.max_seq_len_cached = seq_len
154
- t = torch.arange(
155
- self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
156
- )
157
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
158
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
159
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
160
- self.register_buffer(
161
- "cos_cached", emb.cos()[None, None, :, :], persistent=False
162
- )
163
- self.register_buffer(
164
- "sin_cached", emb.sin()[None, None, :, :], persistent=False
165
- )
166
- return (
167
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
168
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
169
- )
170
-
171
-
172
- def rotate_half(x):
173
- """Rotates half the hidden dims of the input."""
174
- x1 = x[..., : x.shape[-1] // 2]
175
- x2 = x[..., x.shape[-1] // 2 :]
176
- return torch.cat((-x2, x1), dim=-1)
177
-
178
-
179
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
180
  # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
181
  cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -190,24 +74,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
190
  return q_embed, k_embed
191
 
192
 
193
- class LlamaMLP(nn.Module):
194
- def __init__(
195
- self,
196
- hidden_size: int,
197
- intermediate_size: int,
198
- hidden_act: str,
199
- ):
200
- super().__init__()
201
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
202
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
203
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
204
- self.act_fn = ACT2FN[hidden_act]
205
-
206
- def forward(self, x):
207
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
208
-
209
-
210
  class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
 
 
 
 
211
  # Note that forward, setup_context, and backward are @staticmethods
212
  @staticmethod
213
  def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
@@ -682,16 +553,14 @@ class LlamaAttention(nn.Module):
682
  # upcast attention to fp32
683
  if is_mem is None:
684
  raise ValueError("Don't use this without landmarks")
685
- # attn_weights = nn.functional.softmax(
686
- # attn_weights, dim=-1, dtype=torch.float32
687
- # ).to(query_states.dtype)
688
- else:
689
- attn_weights = landmark_grouped_softmax(
690
- attn_weights,
691
- dim=-1,
692
- is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
693
- last_section_mask=last_section_mask,
694
- ).to(query_states.dtype)
695
  if attn_prefix is not None:
696
  attn_prefix, attn_weights = torch.split(
697
  attn_weights,
@@ -722,6 +591,10 @@ class LlamaAttention(nn.Module):
722
 
723
 
724
  class LlamaDecoderLayer(nn.Module):
 
 
 
 
725
  def __init__(self, config: LlamaConfig):
726
  super().__init__()
727
  self.hidden_size = config.hidden_size
@@ -802,114 +675,6 @@ class LlamaDecoderLayer(nn.Module):
802
  return outputs
803
 
804
 
805
- LLAMA_START_DOCSTRING = r"""
806
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
807
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
808
- etc.)
809
-
810
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
811
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
812
- and behavior.
813
-
814
- Parameters:
815
- config ([`LlamaConfig`]):
816
- Model configuration class with all the parameters of the model. Initializing with a config file does not
817
- load the weights associated with the model, only the configuration. Check out the
818
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
819
- """
820
-
821
-
822
- @add_start_docstrings(
823
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
824
- LLAMA_START_DOCSTRING,
825
- )
826
- class LlamaPreTrainedModel(PreTrainedModel):
827
- config_class = LlamaConfig
828
- base_model_prefix = "model"
829
- supports_gradient_checkpointing = True
830
- _no_split_modules = ["LlamaDecoderLayer"]
831
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
832
-
833
- def _init_weights(self, module):
834
- std = self.config.initializer_range
835
- if isinstance(module, nn.Linear):
836
- module.weight.data.normal_(mean=0.0, std=std)
837
- if module.bias is not None:
838
- module.bias.data.zero_()
839
- elif isinstance(module, nn.Embedding):
840
- module.weight.data.normal_(mean=0.0, std=std)
841
- if module.padding_idx is not None:
842
- module.weight.data[module.padding_idx].zero_()
843
-
844
- def _set_gradient_checkpointing(self, module, value=False):
845
- if isinstance(module, LlamaModel):
846
- module.gradient_checkpointing = value
847
-
848
-
849
- LLAMA_INPUTS_DOCSTRING = r"""
850
- Args:
851
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
852
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
853
- it.
854
-
855
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
856
- [`PreTrainedTokenizer.__call__`] for details.
857
-
858
- [What are input IDs?](../glossary#input-ids)
859
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
860
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
861
-
862
- - 1 for tokens that are **not masked**,
863
- - 0 for tokens that are **masked**.
864
-
865
- [What are attention masks?](../glossary#attention-mask)
866
-
867
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
868
- [`PreTrainedTokenizer.__call__`] for details.
869
-
870
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
871
- `past_key_values`).
872
-
873
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
874
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
875
- information on the default strategy.
876
-
877
- - 1 indicates the head is **not masked**,
878
- - 0 indicates the head is **masked**.
879
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
880
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
881
- config.n_positions - 1]`.
882
-
883
- [What are position IDs?](../glossary#position-ids)
884
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
885
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
886
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
887
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
888
-
889
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
890
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
891
-
892
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
893
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
894
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
895
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
896
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
897
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
898
- model's internal embedding lookup matrix.
899
- use_cache (`bool`, *optional*):
900
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
901
- `past_key_values`).
902
- output_attentions (`bool`, *optional*):
903
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
904
- tensors for more detail.
905
- output_hidden_states (`bool`, *optional*):
906
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
907
- more detail.
908
- return_dict (`bool`, *optional*):
909
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
910
- """
911
-
912
-
913
  @add_start_docstrings(
914
  "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
915
  LLAMA_START_DOCSTRING,
@@ -1178,6 +943,10 @@ class LlamaModel(LlamaPreTrainedModel):
1178
 
1179
 
1180
  class LlamaForCausalLM(LlamaPreTrainedModel):
 
 
 
 
1181
  def __init__(self, config):
1182
  super().__init__(config)
1183
  self.model = LlamaModel(config)
@@ -1448,148 +1217,33 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1448
  return reordered_past
1449
 
1450
 
1451
- @add_start_docstrings(
1452
- """
1453
- The LLaMa Model transformer with a sequence classification head on top (linear layer).
1454
-
1455
- [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1456
- (e.g. GPT-2) do.
1457
-
1458
- Since it does classification on the last token, it requires to know the position of the last token. If a
1459
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1460
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1461
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1462
- each row of the batch).
1463
- """,
1464
- LLAMA_START_DOCSTRING,
1465
- )
1466
- class LlamaForSequenceClassification(LlamaPreTrainedModel):
1467
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1468
-
1469
- def __init__(self, config):
1470
- super().__init__(config)
1471
- self.num_labels = config.num_labels
1472
- self.model = LlamaModel(config)
1473
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1474
-
1475
- # Initialize weights and apply final processing
1476
- self.post_init()
1477
-
1478
- def get_input_embeddings(self):
1479
- return self.model.embed_tokens
1480
-
1481
- def set_input_embeddings(self, value):
1482
- self.model.embed_tokens = value
1483
-
1484
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1485
- def forward(
1486
- self,
1487
- input_ids: torch.LongTensor = None,
1488
- attention_mask: Optional[torch.Tensor] = None,
1489
- position_ids: Optional[torch.LongTensor] = None,
1490
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1491
- inputs_embeds: Optional[torch.FloatTensor] = None,
1492
- labels: Optional[torch.LongTensor] = None,
1493
- use_cache: Optional[bool] = None,
1494
- output_attentions: Optional[bool] = None,
1495
- output_hidden_states: Optional[bool] = None,
1496
- return_dict: Optional[bool] = None,
1497
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1498
- r"""
1499
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1500
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1501
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1502
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1503
- """
1504
- return_dict = (
1505
- return_dict if return_dict is not None else self.config.use_return_dict
1506
- )
1507
-
1508
- transformer_outputs = self.model(
1509
- input_ids,
1510
- attention_mask=attention_mask,
1511
- position_ids=position_ids,
1512
- past_key_values=past_key_values,
1513
- inputs_embeds=inputs_embeds,
1514
- use_cache=use_cache,
1515
- output_attentions=output_attentions,
1516
- output_hidden_states=output_hidden_states,
1517
- return_dict=return_dict,
1518
- )
1519
- hidden_states = transformer_outputs[0]
1520
- logits = self.score(hidden_states)
1521
-
1522
- if input_ids is not None:
1523
- batch_size = input_ids.shape[0]
1524
- else:
1525
- batch_size = inputs_embeds.shape[0]
1526
-
1527
- if self.config.pad_token_id is None and batch_size != 1:
1528
- raise ValueError(
1529
- "Cannot handle batch sizes > 1 if no padding token is defined."
1530
- )
1531
- if self.config.pad_token_id is None:
1532
- sequence_lengths = -1
1533
- else:
1534
- if input_ids is not None:
1535
- sequence_lengths = (
1536
- torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1537
- ).to(logits.device)
1538
- else:
1539
- sequence_lengths = -1
1540
-
1541
- pooled_logits = logits[
1542
- torch.arange(batch_size, device=logits.device), sequence_lengths
1543
- ]
1544
-
1545
- loss = None
1546
- if labels is not None:
1547
- labels = labels.to(logits.device)
1548
- if self.config.problem_type is None:
1549
- if self.num_labels == 1:
1550
- self.config.problem_type = "regression"
1551
- elif self.num_labels > 1 and (
1552
- labels.dtype == torch.long or labels.dtype == torch.int
1553
- ):
1554
- self.config.problem_type = "single_label_classification"
1555
- else:
1556
- self.config.problem_type = "multi_label_classification"
1557
-
1558
- if self.config.problem_type == "regression":
1559
- loss_fct = MSELoss()
1560
- if self.num_labels == 1:
1561
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1562
- else:
1563
- loss = loss_fct(pooled_logits, labels)
1564
- elif self.config.problem_type == "single_label_classification":
1565
- loss_fct = CrossEntropyLoss()
1566
- loss = loss_fct(
1567
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1568
- )
1569
- elif self.config.problem_type == "multi_label_classification":
1570
- loss_fct = BCEWithLogitsLoss()
1571
- loss = loss_fct(pooled_logits, labels)
1572
- if not return_dict:
1573
- output = (pooled_logits,) + transformer_outputs[1:]
1574
- return ((loss,) + output) if loss is not None else output
1575
-
1576
- return SequenceClassifierOutputWithPast(
1577
- loss=loss,
1578
- logits=pooled_logits,
1579
- past_key_values=transformer_outputs.past_key_values,
1580
- hidden_states=transformer_outputs.hidden_states,
1581
- attentions=transformer_outputs.attentions,
1582
- )
1583
-
1584
-
1585
  def add_mem_tokens(example, mem_freq, mem_id):
1586
- x = example["input_ids"]
1587
  ret = []
1588
  prev_idx = 0
1589
- for t_idx in range(mem_freq, len(x), mem_freq):
1590
- ret.extend(x[prev_idx:t_idx])
1591
  ret.append(mem_id)
1592
  prev_idx = t_idx
1593
- ret.extend(x[prev_idx:])
1594
  # drop attention_mask
1595
  return {"input_ids": ret}
28
  import torch
29
  import torch.utils.checkpoint
30
  from torch import nn
31
+ from torch.nn import CrossEntropyLoss
32
+ from transformers import LlamaTokenizer
33
  from transformers.modeling_outputs import (
34
  BaseModelOutputWithPast,
35
  CausalLMOutputWithPast,
 
36
  )
 
37
  from transformers.models.llama.configuration_llama import LlamaConfig
38
+ from transformers.models.llama.modeling_llama import (
39
+ LLAMA_INPUTS_DOCSTRING,
40
+ LLAMA_START_DOCSTRING,
41
+ LlamaMLP,
42
+ LlamaPreTrainedModel,
43
+ LlamaRMSNorm,
44
+ LlamaRotaryEmbedding,
45
+ _expand_mask,
46
+ _make_causal_mask,
47
+ rotate_half,
48
+ )
49
  from transformers.utils import (
50
  add_start_docstrings,
51
  add_start_docstrings_to_model_forward,
 
60
  MEM_TOKEN = "<landmark>" # nosec
61
 
62
 
63
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
64
  # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
65
  cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
 
74
  return q_embed, k_embed
75
 
76
 
77
  class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
78
+ """
79
+ Landmark grouped softmax function.
80
+ """
81
+
82
  # Note that forward, setup_context, and backward are @staticmethods
83
  @staticmethod
84
  def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
 
553
  # upcast attention to fp32
554
  if is_mem is None:
555
  raise ValueError("Don't use this without landmarks")
556
+
557
+ attn_weights = landmark_grouped_softmax(
558
+ attn_weights,
559
+ dim=-1,
560
+ is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
561
+ last_section_mask=last_section_mask,
562
+ ).to(query_states.dtype)
563
+
 
 
564
  if attn_prefix is not None:
565
  attn_prefix, attn_weights = torch.split(
566
  attn_weights,
 
591
 
592
 
593
  class LlamaDecoderLayer(nn.Module):
594
+ """
595
+ Llama Decoder layer
596
+ """
597
+
598
  def __init__(self, config: LlamaConfig):
599
  super().__init__()
600
  self.hidden_size = config.hidden_size
 
675
  return outputs
676
 
677
 
678
  @add_start_docstrings(
679
  "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
680
  LLAMA_START_DOCSTRING,
 
943
 
944
 
945
  class LlamaForCausalLM(LlamaPreTrainedModel):
946
+ """
947
+ Llama model with a causal language modeling head.
948
+ """
949
+
950
  def __init__(self, config):
951
  super().__init__(config)
952
  self.model = LlamaModel(config)
 
1217
  return reordered_past
1218
 
1219
 
1220
  def add_mem_tokens(example, mem_freq, mem_id):
1221
+ ids = example["input_ids"]
1222
  ret = []
1223
  prev_idx = 0
1224
+ for t_idx in range(mem_freq, len(ids), mem_freq):
1225
+ ret.extend(ids[prev_idx:t_idx])
1226
  ret.append(mem_id)
1227
  prev_idx = t_idx
1228
+ ret.extend(ids[prev_idx:])
1229
  # drop attention_mask
1230
  return {"input_ids": ret}
1231
+
1232
+
1233
+ def patch_llama_with_landmark_attn():
1234
+ import transformers
1235
+
1236
+ transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
1237
+ transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
1238
+ transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
1239
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
1240
+ transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
1241
+
1242
+
1243
+ def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
1244
+ mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
1245
+ model.set_mem_id(mem_id)
1246
+
1247
+
1248
+ def get_mem_id(tokenizer: LlamaTokenizer):
1249
+ return tokenizer.convert_tokens_to_ids(MEM_TOKEN)
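
Taken together, the helpers above form the small public surface for landmark attention. A rough sketch of how they are meant to be wired up, mirroring the `utils/models.py` and `utils/trainer.py` changes further down; the `base_model` id and the embedding resize are assumptions, not part of this commit:

```python
# Illustrative sketch only -- mirrors how utils/models.py and utils/trainer.py
# below use the new landmark-attention helpers.
from transformers import LlamaTokenizer

from axolotl.monkeypatch.llama_landmark_attn import (
    MEM_TOKEN,
    patch_llama_with_landmark_attn,
    set_model_mem_id,
)


def load_landmark_llama(base_model: str):
    # 1. swap transformers' Llama classes for the landmark-attention versions
    #    *before* the model is instantiated
    patch_llama_with_landmark_attn()

    # 2. register the <landmark> memory token with the tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    # 3. load the (now patched) model and point it at the memory token id
    from transformers.models.llama.modeling_llama import LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained(base_model)
    model.resize_token_embeddings(len(tokenizer))  # assumption: make room for <landmark>
    set_model_mem_id(model, tokenizer)
    return model, tokenizer


# Before training, the dataset also needs memory tokens spliced in, e.g.:
#   dataset.map(partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
#               batched=False)
```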
src/axolotl/prompt_strategies/sharegpt_jokes.py ADDED
@@ -0,0 +1,28 @@
1
+ """Module for Jokes prompts using sharegpt style """
2
+ from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
3
+ from axolotl.prompters import PromptStyle, ShareGPTPrompter
4
+
5
+
6
+ def load(tokenizer, cfg):
7
+ return SimpleJokesShareGPTPromptTokenizingStrategy(
8
+ ShareGPTPrompter(PromptStyle.CHAT.value),
9
+ tokenizer,
10
+ cfg.train_on_inputs,
11
+ cfg.sequence_len,
12
+ )
13
+
14
+
15
+ class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
16
+ """
17
+ Tokenization strategy for asking the bot to tell a joke and then explain why it's funny
18
+ """
19
+
20
+ # title, text, explanation
21
+ def get_conversation_thread(self, prompt):
22
+ title = "" if not prompt["title"] else prompt["title"] + " "
23
+ return [
24
+ {"from": "human", "value": "Tell me a joke."},
25
+ {"from": "gpt", "value": title + prompt["text"]},
26
+ {"from": "human", "value": "Why is that joke funny?"},
27
+ {"from": "gpt", "value": prompt["explanation"]},
28
+ ]
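
For context, a hypothetical row for this jokes strategy carries the three fields read by `get_conversation_thread` above (`title`, `text`, `explanation`); the values below are made up:

```python
# Made-up example row; only the keys matter -- they are the ones
# SimpleJokesShareGPTPromptTokenizingStrategy.get_conversation_thread reads.
row = {
    "title": "The lightbulb one.",
    "text": "How many programmers does it take to change a lightbulb? None, it's a hardware problem.",
    "explanation": "It plays on the usual split between software and hardware responsibilities.",
}

# get_conversation_thread(row) then yields a sharegpt-style thread:
# human: "Tell me a joke."
# gpt:   "The lightbulb one. How many programmers does it take ..."
# human: "Why is that joke funny?"
# gpt:   "It plays on the usual split ..."
```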
src/axolotl/prompt_strategies/sharegpt_simple.py CHANGED
@@ -13,6 +13,15 @@ def load(tokenizer, cfg):
13
  )
14
 
15
 
16
  def load_guanaco(tokenizer, cfg):
17
  return GuanacoShareGPTPromptTokenizingStrategy(
18
  ShareGPTPrompter(PromptStyle.CHAT.value),
@@ -31,6 +40,18 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
31
  return prompt["conversations"]
32
 
33
 
34
  class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
35
  """
36
  sharegpt strategy that remaps oasst data to sharegpt format
 
13
  )
14
 
15
 
16
+ def load_role(tokenizer, cfg):
17
+ return SimpleRoleShareGPTPromptTokenizingStrategy(
18
+ ShareGPTPrompter(PromptStyle.CHAT.value),
19
+ tokenizer,
20
+ cfg.train_on_inputs,
21
+ cfg.sequence_len,
22
+ )
23
+
24
+
25
  def load_guanaco(tokenizer, cfg):
26
  return GuanacoShareGPTPromptTokenizingStrategy(
27
  ShareGPTPrompter(PromptStyle.CHAT.value),
 
40
  return prompt["conversations"]
41
 
42
 
43
+ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
44
+ """
45
+ basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
46
+ """
47
+
48
+ def get_conversation_thread(self, prompt):
49
+ conversations = prompt["conversations"]
50
+ # remap each turn's {"role": "human"/"gpt", "value": ...} => {"from": "human"/"gpt", "value": ...}
51
+ turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
52
+ return turns
53
+
54
+
55
  class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
56
  """
57
  sharegpt strategy that remaps oasst data to sharegpt format
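
To illustrate the new role-based strategy above, a conversation row would look like the sketch below. The field names come from `get_conversation_thread`; wiring it up via a dataset config `type` of `sharegpt_simple.load_role` follows the dotted naming hinted at in `utils/data.py` further down, but that exact key is an assumption here:

```python
# Hypothetical row for SimpleRoleShareGPTPromptTokenizingStrategy: each turn
# uses "role"/"value" and is copied verbatim into the sharegpt "from"/"value"
# form, so the roles are expected to already be "human"/"gpt".
row = {
    "conversations": [
        {"role": "human", "value": "Summarize what axolotl does."},
        {"role": "gpt", "value": "It is a tool for fine-tuning language models."},
    ]
}

turns = [{"from": t["role"], "value": t["value"]} for t in row["conversations"]]
# -> [{"from": "human", ...}, {"from": "gpt", ...}]
```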
src/axolotl/prompters.py CHANGED
@@ -261,28 +261,33 @@ class Conversation:
261
  self.messages.append([role, message])
262
 
263
 
264
- conv_vicuna_v1_1 = Conversation(
265
- system="A chat between a curious user and an artificial intelligence assistant. "
266
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
267
- roles=["USER", "ASSISTANT"],
268
- messages=[],
269
- offset=0,
270
- sep_style=SeparatorStyle.TWO,
271
- sep=" ",
272
- sep2=" ",
273
- )
274
-
275
-
276
  class ShareGPTPrompter: # pylint: disable=too-few-public-methods
277
  """
278
  A prompter that generates prompts for the ShareGPT
279
  """
280
 
281
- def __init__(self, prompt_style=None):
282
  if prompt_style != PromptStyle.CHAT.value:
283
  raise ValueError(
284
  f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
285
  )
286
 
287
  # def match_prompt_style(self):
288
  # if self.prompt_style == PromptStyle.chat.value:
@@ -300,7 +305,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
300
  # also happens on the data splitting leaving empty conversations
301
  raise IndexError
302
 
303
- conv = conv_vicuna_v1_1.copy()
304
  roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
305
 
306
  try:
 
261
  self.messages.append([role, message])
262
 
263
 
264
  class ShareGPTPrompter: # pylint: disable=too-few-public-methods
265
  """
266
  A prompter that generates prompts for the ShareGPT
267
  """
268
 
269
+ def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
270
  if prompt_style != PromptStyle.CHAT.value:
271
  raise ValueError(
272
  f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
273
  )
274
+ system: str = (
275
+ system_prompt
276
+ if system_prompt
277
+ else (
278
+ "A chat between a curious user and an artificial intelligence assistant. "
279
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
280
+ )
281
+ )
282
+ self._conversation = Conversation(
283
+ system=system,
284
+ roles=["USER", "ASSISTANT"],
285
+ messages=[],
286
+ offset=0,
287
+ sep_style=SeparatorStyle.TWO,
288
+ sep=" ",
289
+ sep2=" ",
290
+ )
291
 
292
  # def match_prompt_style(self):
293
  # if self.prompt_style == PromptStyle.chat.value:
 
305
  # also happens on the data splitting leaving empty conversations
306
  raise IndexError
307
 
308
+ conv = self._conversation.copy()
309
  roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
310
 
311
  try:
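
In short, the conversation template that used to live in the module-level `conv_vicuna_v1_1` is now built per prompter instance, with an overridable system message. A minimal sketch of the new constructor in use; the system string is just an example:

```python
from axolotl.prompters import PromptStyle, ShareGPTPrompter

# Omitting system_prompt keeps the original "curious user / AI assistant"
# preamble; passing one swaps it into the Conversation built in __init__.
prompter = ShareGPTPrompter(
    PromptStyle.CHAT.value,
    system_prompt="You are a terse assistant that answers in one sentence.",
)
```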
src/axolotl/utils/data.py CHANGED
@@ -240,8 +240,15 @@ def load_tokenized_prepared_datasets(
240
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
241
  datasets.append(ds_wrapper)
242
  else:
243
- logging.error(f"unhandled prompt tokenization strategy: {d.type}")
244
- raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
245
  logging.info("tokenizing, merging, and shuffling master dataset")
246
 
247
  samples: List[int] = []
 
240
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
241
  datasets.append(ds_wrapper)
242
  else:
243
+ suffix = ""
244
+ if ":load_" in d.type:
245
+ suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
246
+ logging.error(
247
+ f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
248
+ )
249
+ raise ValueError(
250
+ f"unhandled prompt tokenization strategy: {d.type} {suffix}"
251
+ )
252
  logging.info("tokenizing, merging, and shuffling master dataset")
253
 
254
  samples: List[int] = []
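
The new error hint above is purely string-based: when a dataset `type` still uses the old colon form, the message suggests the dotted form. A tiny sketch of that behaviour, using a hypothetical misconfigured value:

```python
# Example of the hint logic added above; the strategy name is only an example.
d_type = "sharegpt_simple:load_role"
if ":load_" in d_type:
    print(f"Did you mean {d_type.replace(':load_', '.load_')}?")
# -> Did you mean sharegpt_simple.load_role?
```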
src/axolotl/utils/models.py CHANGED
@@ -20,15 +20,6 @@ from transformers import (
20
  LlamaConfig,
21
  )
22
 
23
- try:
24
- from transformers import ( # pylint: disable=unused-import # noqa: F401
25
- LlamaForCausalLM,
26
- )
27
- except ImportError:
28
- logging.warning(
29
- "This version of transformers does not support Llama. Consider upgrading."
30
- )
31
-
32
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
33
 
34
  if TYPE_CHECKING:
@@ -78,15 +69,9 @@ def load_tokenizer(
78
 
79
 
80
  def load_model(
81
- base_model,
82
- base_model_config,
83
- model_type,
84
- tokenizer,
85
- cfg,
86
- adapter="lora",
87
- inference=False,
88
  ):
89
- # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
90
  """
91
  Load a model from a base model and a model type.
92
  """
@@ -98,7 +83,7 @@ def load_model(
98
  )
99
 
100
  if cfg.is_llama_derived_model and cfg.flash_attention:
101
- if cfg.device not in ["mps", "cpu"] and inference is False:
102
  from axolotl.flash_attn import replace_llama_attn_with_flash_attn
103
 
104
  logging.info("patching with flash attention")
@@ -118,14 +103,15 @@ def load_model(
118
  logging.info("patching with sdp attention")
119
  hijack_llama_sdp_attention()
120
  elif cfg.is_llama_derived_model and cfg.landmark_attention:
121
- from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
122
  MEM_TOKEN,
123
- LlamaForCausalLM,
124
  )
125
 
126
  logging.info("patching with landmark attention")
 
127
 
128
- # TODO: Check if this would overwrite previous additional_special_tokens
129
  tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
130
 
131
  if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -210,7 +196,9 @@ def load_model(
210
  else True,
211
  )
212
  load_in_8bit = False
213
- elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
 
 
214
  config = LlamaConfig.from_pretrained(base_model_config)
215
  model = LlamaForCausalLM.from_pretrained(
216
  base_model,
@@ -314,7 +302,9 @@ def load_model(
314
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
315
  ):
316
  logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
317
- model = prepare_model_for_kbit_training(model)
 
 
318
 
319
  model, lora_config = load_adapter(model, cfg, adapter)
320
 
@@ -387,7 +377,6 @@ def load_llama_adapter(model, cfg):
387
  model = PeftModel.from_pretrained(
388
  model,
389
  cfg.lora_model_dir,
390
- device_map=cfg.device_map,
391
  torch_dtype=torch.float16,
392
  )
393
  else:
@@ -449,8 +438,7 @@ def load_lora(model, cfg):
449
  model = PeftModel.from_pretrained(
450
  model,
451
  cfg.lora_model_dir,
452
- device_map=cfg.device_map,
453
- # torch_dtype=torch.float16,
454
  )
455
  else:
456
  model = get_peft_model(model, lora_config)
 
20
  LlamaConfig,
21
  )
22
 
 
23
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
24
 
25
  if TYPE_CHECKING:
 
69
 
70
 
71
  def load_model(
72
+ base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 
 
 
 
 
 
73
  ):
74
+ # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
75
  """
76
  Load a model from a base model and a model type.
77
  """
 
83
  )
84
 
85
  if cfg.is_llama_derived_model and cfg.flash_attention:
86
+ if cfg.device not in ["mps", "cpu"] and not cfg.inference:
87
  from axolotl.flash_attn import replace_llama_attn_with_flash_attn
88
 
89
  logging.info("patching with flash attention")
 
103
  logging.info("patching with sdp attention")
104
  hijack_llama_sdp_attention()
105
  elif cfg.is_llama_derived_model and cfg.landmark_attention:
106
+ from axolotl.monkeypatch.llama_landmark_attn import (
107
  MEM_TOKEN,
108
+ patch_llama_with_landmark_attn,
109
  )
110
 
111
  logging.info("patching with landmark attention")
112
+ patch_llama_with_landmark_attn()
113
 
114
+ # Note: This might overwrite previous additional_special_tokens
115
  tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
116
 
117
  if cfg.is_llama_derived_model and cfg.xpos_rope:
 
196
  else True,
197
  )
198
  load_in_8bit = False
199
+ elif cfg.is_llama_derived_model:
200
+ from transformers import LlamaForCausalLM
201
+
202
  config = LlamaConfig.from_pretrained(base_model_config)
203
  model = LlamaForCausalLM.from_pretrained(
204
  base_model,
 
302
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
303
  ):
304
  logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
305
+ model = prepare_model_for_kbit_training(
306
+ model, use_gradient_checkpointing=cfg.gradient_checkpointing
307
+ )
308
 
309
  model, lora_config = load_adapter(model, cfg, adapter)
310
 
 
377
  model = PeftModel.from_pretrained(
378
  model,
379
  cfg.lora_model_dir,
 
380
  torch_dtype=torch.float16,
381
  )
382
  else:
 
438
  model = PeftModel.from_pretrained(
439
  model,
440
  cfg.lora_model_dir,
441
+ is_trainable=not cfg.inference,
 
442
  )
443
  else:
444
  model = get_peft_model(model, lora_config)
src/axolotl/utils/trainer.py CHANGED
@@ -245,16 +245,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
245
  if cfg.is_llama_derived_model and cfg.landmark_attention:
246
  from functools import partial
247
 
248
- from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
 
 
 
 
249
 
250
- mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
251
- model.set_mem_id(mem_id)
252
 
253
  logging.info("Adding landmark attention tokens to dataset")
254
 
255
  for dataset in [train_dataset, eval_dataset]:
256
  dataset = dataset.map(
257
- partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
258
  batched=False,
259
  num_proc=32,
260
  )
 
245
  if cfg.is_llama_derived_model and cfg.landmark_attention:
246
  from functools import partial
247
 
248
+ from axolotl.monkeypatch.llama_landmark_attn import (
249
+ add_mem_tokens,
250
+ get_mem_id,
251
+ set_model_mem_id,
252
+ )
253
 
254
+ set_model_mem_id(model, tokenizer)
 
255
 
256
  logging.info("Adding landmark attention tokens to dataset")
257
 
258
  for dataset in [train_dataset, eval_dataset]:
259
  dataset = dataset.map(
260
+ partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
261
  batched=False,
262
  num_proc=32,
263
  )
src/axolotl/utils/validation.py CHANGED
@@ -59,6 +59,11 @@ def validate_config(cfg):
59
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
60
  raise ValueError("FSDP is not supported for falcon models")
61
 
 
 
 
 
 
62
  if cfg.flash_optimum is True:
63
  if cfg.adapter:
64
  logging.warning(
 
59
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
60
  raise ValueError("FSDP is not supported for falcon models")
61
 
62
+ if (
63
+ cfg.base_model and "mpt" in cfg.base_model.lower()
64
+ ) and cfg.gradient_checkpointing:
65
+ raise ValueError("gradient_checkpointing is not supported for MPT models")
66
+
67
  if cfg.flash_optimum is True:
68
  if cfg.adapter:
69
  logging.warning(
tests/test_validation.py CHANGED
@@ -199,6 +199,20 @@ class ValidationTest(unittest.TestCase):
199
 
200
  validate_config(cfg)
201
 
 
202
  def test_flash_optimum(self):
203
  cfg = DictDefault(
204
  {
 
199
 
200
  validate_config(cfg)
201
 
202
+ def test_mpt_gradient_checkpointing(self):
203
+ regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
204
+
205
+ # Check for lower-case
206
+ cfg = DictDefault(
207
+ {
208
+ "base_model": "mosaicml/mpt-7b",
209
+ "gradient_checkpointing": True,
210
+ }
211
+ )
212
+
213
+ with pytest.raises(ValueError, match=regex_exp):
214
+ validate_config(cfg)
215
+
216
  def test_flash_optimum(self):
217
  cfg = DictDefault(
218
  {