winglian committed
Commit 9493b1b
1 Parent(s): 1b3e401

be able to use adam bnb 8bit and one cycle scheduler w fsdp

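The idea behind the change: the commit notes that optimizers can't be set directly on the trainer when using FSDP, so the optimizer is forwarded by name through `TrainingArguments(optim=...)` and a small `Trainer` subclass builds the OneCycleLR schedule. A minimal sketch of that setup outside axolotl, assuming a recent `transformers` with `bitsandbytes` installed; the output directory, FSDP options, and hyperparameters below are illustrative, not taken from this commit:

from transformers import TrainingArguments

# Sketch only: naming the optimizer lets the Trainer construct 8-bit Adam
# internally, after FSDP has wrapped the model's parameters.
training_args = TrainingArguments(
    output_dir="./outputs",        # hypothetical path
    optim="adamw_bnb_8bit",        # corresponds to cfg.optimizer in axolotl
    fsdp="full_shard auto_wrap",   # example FSDP setting; axolotl reads cfg.fsdp
    learning_rate=3e-4,            # illustrative value
    warmup_steps=100,              # illustrative value
)

Correspondingly, the manual adamw_bnb_8bit optimizer setup is skipped when FSDP is active, which is the `and not cfg.fsdp` guard added in the trainer.py diff below.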
src/axolotl/utils/data.py CHANGED
@@ -7,7 +7,7 @@ from datasets import (
     load_dataset,
     IterableDataset,
     Dataset,
-    concatenate_datasets,
+    concatenate_datasets, DatasetDict,
 )
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -37,7 +37,7 @@ from axolotl.prompters import (
 )
 
 
-def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -196,7 +196,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
     return dataset
 
 
-def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path):
+def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
    )
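Per the annotations above, `load_tokenized_prepared_datasets` returns a `DatasetDict` while `load_prepare_datasets` returns a `(Dataset, Dataset)` train/eval pair. A self-contained sketch of that distinction using the `datasets` library directly, with toy rows standing in for tokenized prompts (none of the data below comes from this commit):

from datasets import Dataset, DatasetDict, concatenate_datasets

# Toy stand-ins for tokenized prompt datasets; real ones hold input_ids etc.
ds_a = Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]})
ds_b = Dataset.from_dict({"input_ids": [[5, 6]]})

# The shape load_tokenized_prepared_datasets advertises: a split-keyed DatasetDict.
prepared = DatasetDict({"train": concatenate_datasets([ds_a, ds_b])})

# The shape load_prepare_datasets advertises: a (train, eval) pair of Dataset objects.
split = prepared["train"].train_test_split(test_size=0.5)
train_dataset, eval_dataset = split["train"], split["test"]
print(len(train_dataset), len(eval_dataset))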
src/axolotl/utils/trainer.py CHANGED
@@ -9,13 +9,31 @@ import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback
+from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.schedulers import InterpolatingLogScheduler
 from axolotl.utils.callbacks import SavePeftModelCallback
 
 
+class OneCycleLRSchedulerTrainer(Trainer):
+    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+        optimizer=self.optimizer if optimizer is None else optimizer
+        num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
+        num_training_steps=num_training_steps
+        pct_start = num_warmup_steps / num_training_steps
+
+        lr_scheduler = OneCycleLR(
+            optimizer,
+            max_lr=self.args.learning_rate,
+            total_steps=num_training_steps,
+            pct_start=pct_start,
+            div_factor=6,
+        )
+
+        return lr_scheduler
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -63,6 +81,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["fsdp"] = cfg.fsdp
         if cfg.fsdp_config:
             training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
+        # can't set optimizers directly on trainer when using fsdp, so set them here
+        if cfg.optimizer:
+            training_arguments_kwargs["optim"] = cfg.optimizer
 
     # deepspeed
     if (
@@ -119,6 +140,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         cfg.optimizer == "adamw_bnb_8bit"
         and not cfg.load_4bit
        and not "deepspeed" in training_arguments_kwargs
+        and not cfg.fsdp
     ):
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
@@ -194,7 +216,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
-    trainer = transformers.Trainer(
+    trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
+    trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
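The `OneCycleLRSchedulerTrainer` above derives `pct_start` from the Trainer's warmup steps, so the rising phase of the one-cycle schedule matches the configured warmup. A standalone sketch with assumed numbers (toy model, SGD, and learning rate are placeholders, not values from this commit) showing the same formula and the effect of `div_factor=6`:

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(4, 4)                        # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

num_training_steps = 1000                            # assumed totals
num_warmup_steps = 100
pct_start = num_warmup_steps / num_training_steps    # 0.1, same formula as create_scheduler

scheduler = OneCycleLR(
    optimizer,
    max_lr=3e-4,                                     # stands in for args.learning_rate
    total_steps=num_training_steps,
    pct_start=pct_start,
    div_factor=6,                                    # start at max_lr / 6, as in the commit
)

# LR climbs for the first ~10% of steps, then anneals for the remainder.
for _ in range(num_training_steps):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())

In the commit itself this subclass is only selected when both cfg.lr_scheduler == "one_cycle" and cfg.fsdp are set; otherwise the stock transformers.Trainer is used.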