tmm1 committed
Commit 868530c
1 Parent(s): 17605b8

let transformers handle adamw_bnb_8bit

Files changed (1):
  1. src/axolotl/utils/trainer.py  +2 -71
src/axolotl/utils/trainer.py CHANGED
@@ -10,19 +10,13 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union
 
-import bitsandbytes as bnb
 import numpy as np
 import torch.cuda
-import transformers
 from datasets import Dataset, set_caching_enabled
-from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import (
-    SequentialDistributedSampler,
-    get_parameter_names,
-)
+from transformers.trainer_pt_utils import SequentialDistributedSampler
 
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
@@ -32,10 +26,7 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.schedulers import (
-    InterpolatingLogScheduler,
-    get_cosine_schedule_with_quadratic_warmup,
-)
+from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 LOG = logging.getLogger("axolotl")
 
@@ -570,66 +561,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         if Path(cfg.torchdistx_path).exists():
             sys.path.append(cfg.torchdistx_path)
             importlib.import_module("torchdistx")
-    if (
-        cfg.optimizer == "adamw_bnb_8bit"
-        and not cfg.gptq
-        and "deepspeed" not in training_arguments_kwargs
-        and not cfg.fsdp
-    ):
-        decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-        decay_parameters = [name for name in decay_parameters if "bias" not in name]
-        optimizer_grouped_parameters = [
-            {
-                "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if (n in decay_parameters and p.requires_grad)
-                ],
-                "weight_decay": training_args.weight_decay,
-            },
-            {
-                "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if (n not in decay_parameters and p.requires_grad)
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
-
-        optimizer = bnb.optim.Adam8bit(
-            optimizer_grouped_parameters,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            eps=training_args.adam_epsilon,
-            lr=training_args.learning_rate,
-        )
-
-        if cfg.lr_scheduler == "one_cycle":
-            lr_scheduler_kwargs = (
-                cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
-            )
-            lr_scheduler = OneCycleLR(
-                optimizer,
-                cfg.learning_rate,
-                total_steps=total_num_steps,
-                epochs=cfg.num_epochs,
-                div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
-                **lr_scheduler_kwargs,
-            )
-        elif cfg.lr_scheduler == "log_sweep":
-            lr_scheduler = InterpolatingLogScheduler(
-                optimizer,
-                cfg.warmup_steps,
-                cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
-                cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
-            )
-        else:
-            lr_scheduler = transformers.get_cosine_schedule_with_warmup(
-                optimizer,
-                training_args.warmup_steps,
-                total_num_steps,
-            )
-        trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
     callbacks = []
     callbacks.append(GPUStatsCallback(cfg))
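
For context (not part of the commit): transformers accepts optim="adamw_bnb_8bit" in TrainingArguments, and Trainer then constructs the bitsandbytes 8-bit optimizer and the learning-rate scheduler itself, including the decay/no-decay parameter grouping that the deleted block reproduced by hand. Below is a minimal sketch of that configuration-only path, assuming a model and train_dataset prepared elsewhere (both names, and all hyperparameter values, are placeholders for illustration):

from transformers import Trainer, TrainingArguments

# Hypothetical values; optim= is the only setting this sketch is about.
training_args = TrainingArguments(
    output_dir="./out",            # placeholder output path
    optim="adamw_bnb_8bit",        # Trainer instantiates the bitsandbytes 8-bit AdamW
    learning_rate=2e-5,
    weight_decay=0.01,             # Trainer splits decay vs. no-decay parameter groups internally
    lr_scheduler_type="cosine",
    warmup_steps=100,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,                   # placeholder: model prepared elsewhere
    args=training_args,
    train_dataset=train_dataset,   # placeholder: dataset prepared elsewhere
)
trainer.train()

With this change, axolotl relies on that built-in path for adamw_bnb_8bit instead of the special-cased branch removed above, which is also why the hand-rolled scheduler wiring tied to that branch could go.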