winglian committed
Commit d1aed4c
Parent: a459383

deepspeed doesn't work with flash-attn, and the GPU savings with flash-attn outweigh the deepspeed headaches

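The practical upshot (spelled out in the scripts/finetune.py diff below) is that the LLaMA flash-attn monkey-patch is now applied only for CUDA training runs and skipped for CPU/MPS and inference. A minimal sketch of that gating; the helper name here is illustrative, not part of the repo:

# Illustrative sketch of the gating introduced in load_model(); the function name
# maybe_patch_llama_flash_attn is hypothetical and not defined in the repo.
def maybe_patch_llama_flash_attn(base_model: str, device: str, inference: bool = False):
    # flash-attn only pays off for CUDA training, so skip CPU/MPS devices and inference
    if "llama" in base_model and device not in ["mps", "cpu"] and not inference:
        from axolotl.flash_attn import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()
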
configs/cerebras_1_3B_alpaca.yml CHANGED
@@ -34,6 +34,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: True
 tf32: True
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
 
configs/llama_65B_alpaca.yml CHANGED
@@ -36,6 +36,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: true
 tf32: true
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
 
configs/llama_7B_alpaca.yml CHANGED
@@ -36,6 +36,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: true
 tf32: true
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
 
configs/pythia_1_2B_alpaca.yml CHANGED
@@ -36,6 +36,6 @@ train_on_inputs: false
 group_by_length: false
 bf16: True
 tf32: True
+early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-deepspeed:
 
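All four configs gain an empty early_stopping_patience key. An empty value in YAML parses to None, so early stopping stays disabled unless a patience value is filled in; a quick check (assumes PyYAML is installed):

# Sanity check: an empty YAML value loads as None, which is falsy,
# so no EarlyStoppingCallback gets registered by default.
import yaml

cfg = yaml.safe_load("early_stopping_patience:\n")
assert cfg["early_stopping_patience"] is None
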
ds_config.json CHANGED
@@ -1,6 +1,6 @@
 {
     "bf16": {
-        "enabled": "auto",
+        "enabled": "auto"
     },
     "fp16": {
         "enabled": "auto",
@@ -10,15 +10,6 @@
         "hysteresis": 2,
         "min_loss_scale": 1
     },
-    "optimizer": {
-        "type": "AdamW",
-        "params": {
-            "lr": "auto",
-            "betas": "auto",
-            "eps": "auto",
-            "weight_decay": "auto"
-        }
-    },
     "scheduler": {
         "type": "WarmupLR",
         "params": {
@@ -28,29 +19,19 @@
         }
     },
     "zero_optimization": {
-        "stage": 3,
-        "offload_optimizer": {
-            "device": "cpu",
-            "pin_memory": true
-        },
-        "offload_param": {
-            "device": "cpu",
-            "pin_memory": true
-        },
+        "stage": 2,
         "overlap_comm": true,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 5e8,
         "contiguous_gradients": true,
-        "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
-        "stage3_prefetch_bucket_size": "auto",
-        "stage3_param_persistence_threshold": "auto",
-        "stage3_max_live_parameters": 1e9,
-        "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_16bit_weights_on_model_save": true
+        "reduce_scatter": true
     },
     "gradient_accumulation_steps": "auto",
     "gradient_clipping": "auto",
     "steps_per_print": 5,
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
-    "wall_clock_breakdown": false
+    "wall_clock_breakdown": false,
+    "round_robin_gradients": true
 }
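
The JSON drops the AdamW optimizer block and all ZeRO-3/CPU-offload settings in favor of plain ZeRO stage 2. The finetune script below no longer passes this file to the trainer, but for reference, transformers can consume such a config directly via TrainingArguments; a hedged sketch with placeholder paths and batch sizes:

# Illustrative only: how the trimmed ZeRO-2 config could be handed to the HF Trainer
# if DeepSpeed were re-enabled. This commit's finetune.py does not do this.
import transformers

training_args = transformers.TrainingArguments(
    output_dir="./out",             # placeholder
    per_device_train_batch_size=1,  # placeholder
    deepspeed="ds_config.json",     # ZeRO stage 2, no CPU offload after this change
)
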
scripts/finetune.py CHANGED
@@ -20,7 +20,13 @@ from peft import (
     PeftModel,
 )
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    LlamaTokenizer,
+    EarlyStoppingCallback,
+)
 
 # add src to the pythonpath so we don't need to pip install this
 from transformers.trainer_pt_utils import get_parameter_names
@@ -54,11 +60,11 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
 
 
-def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
+def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
-        if cfg.device not in ["mps", "cpu"]:
+        if cfg.device not in ["mps", "cpu"] and inference is False:
            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
            replace_llama_attn_with_flash_attn()
 
@@ -185,7 +191,7 @@ def do_inference(cfg, model, tokenizer):
        generated = model.generate(inputs=batch["input_ids"],
                                   do_sample=True, use_cache=True,
                                   repetition_penalty=1.1,
-                                  max_new_tokens=50,
+                                  max_new_tokens=100,
                                   temperature=0.9,
                                   top_p=0.95,
                                   top_k=40,
@@ -224,19 +230,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
+    warmup_steps = min(int(0.03 * total_num_steps), 100)
+    logging_steps = min(int(0.005 * total_num_steps), 10)
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
-
-    if not cfg.deepspeed:
-        warmup_steps = min(int(0.03 * total_num_steps), 100)
-        logging_steps = min(int(0.005 * total_num_steps), 10)
-
-        training_arguments_kwargs["warmup_steps"] = warmup_steps
-        training_arguments_kwargs["logging_steps"] = logging_steps
-        training_arguments_kwargs["logging_steps"] = logging_steps
-        training_arguments_kwargs["bf16"] = cfg.bf16
-        training_arguments_kwargs["tf32"] = cfg.tf32
+    training_arguments_kwargs["bf16"] = cfg.bf16
+    training_arguments_kwargs["tf32"] = cfg.tf32
+    training_arguments_kwargs["warmup_steps"] = warmup_steps
+    training_arguments_kwargs["logging_steps"] = logging_steps
 
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
@@ -258,37 +260,40 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     )
 
     trainer_kwargs = {}
-
-    if not cfg.deepspeed:
-        decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-        decay_parameters = [name for name in decay_parameters if "bias" not in name]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in model.named_parameters() if n in decay_parameters],
-                "weight_decay": training_args.weight_decay,
-            },
-            {
-                "params": [
-                    p for n, p in model.named_parameters() if n not in decay_parameters
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
+    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters() if n not in decay_parameters
+            ],
+            "weight_decay": 0.0,
+        },
+    ]
 
-        adam_bnb_optim = bnb.optim.Adam8bit(
-            optimizer_grouped_parameters,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            eps=training_args.adam_epsilon,
-            lr=training_args.learning_rate,
-        )
+    adam_bnb_optim = bnb.optim.Adam8bit(
+        optimizer_grouped_parameters,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        lr=training_args.learning_rate,
+    )
 
-        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
-            adam_bnb_optim,
-            training_args.warmup_steps,
-            total_num_steps,
-        )
-        trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
-
+    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
+        adam_bnb_optim,
+        training_args.warmup_steps,
+        total_num_steps,
+    )
+    trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
+
+    if cfg.early_stopping_patience:
+        early_stop_cb = EarlyStoppingCallback(
+            cfg.early_stopping_patience,
+        )
+        trainer_kwargs["callbacks"] = [early_stop_cb]
 
     trainer = transformers.Trainer(
         model=model,
@@ -340,7 +345,7 @@ def train(
 
     # Load the model and tokenizer
     model, tokenizer, lora_config = load_model(
-        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter
+        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
    )
 
    if "inference" in kwargs:
@@ -422,17 +427,19 @@ def train(
        lora_config.save_pretrained(cfg.output_dir)
 
    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
-    signal.signal(
-        signal.SIGINT,
-        lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
-    )
+    if cfg.local_rank == 0:
+        signal.signal(
+            signal.SIGINT,
+            lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
+        )
 
    logging.info("Starting trainer...")
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
 
-    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
-    model.save_pretrained(cfg.output_dir)
+    if cfg.local_rank == 0:
+        # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+        logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+        model.save_pretrained(cfg.output_dir)
 
 
 if __name__ == "__main__":
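
A caveat on the new early-stopping path: transformers' EarlyStoppingCallback expects evaluation to run and a best-model metric to be tracked, i.e. load_best_model_at_end=True and metric_for_best_model set on TrainingArguments. A minimal sketch of that wiring, with placeholder step counts rather than the script's computed ones:

# Sketch only: shows the TrainingArguments flags EarlyStoppingCallback relies on.
# Step counts are placeholders; finetune.py derives them from the dataset size.
import transformers
from transformers import EarlyStoppingCallback

def build_trainer_with_early_stopping(model, train_dataset, eval_dataset, patience=3):
    training_args = transformers.TrainingArguments(
        output_dir="./out",
        evaluation_strategy="steps",
        eval_steps=200,
        save_steps=200,
        load_best_model_at_end=True,        # required by EarlyStoppingCallback
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    return transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)],
    )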