winglian committed on
Commit dd449c5
1 Parent(s): 40a88e8

support galore once upstreamed into transformers (#1409)


* support galore once upstreamed into transformers
* update module name for llama in readme and fix typing for all linear
* bump trl for deprecation fixes from newer transformers
* include galore as an extra and install in docker image
* fix optim_args type
* fix optim_args
* update dependencies for galore
* add galore to cicd dockerfile

README.md CHANGED
@@ -907,7 +907,26 @@ lr_div_factor: # Learning rate div factor
 # - paged_adamw_8bit
 # - paged_lion_32bit
 # - paged_lion_8bit
+# - galore_adamw
+# - galore_adamw_8bit
+# - galore_adafactor
+# - galore_adamw_layerwise
+# - galore_adamw_8bit_layerwise
+# - galore_adafactor_layerwise
 optimizer:
+# Dictionary of arguments to pass to the optimizer
+optim_args:
+# For Galore Optimizers the following optim_args are available
+# rank: # type: int
+# update_proj_gap # type: int
+# scale # type: float
+# proj_type: # type: str, default = std
+
+# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
+optim_target_modules:
+# - self_attn # for llama
+# - mlp
+
 # Specify weight decay
 weight_decay:
 # adamw hyperparams
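For orientation, the following Python sketch (not part of the commit) shows what a GaLore configuration using the options documented above could look like once the YAML is loaded into a plain mapping; the concrete values rank=128, update_proj_gap=200 and scale=0.25 are illustrative placeholders, not project defaults.

# Hypothetical GaLore config, expressed as the Python dict the YAML above
# would load into. Values are placeholders, not defaults.
galore_cfg = {
    "optimizer": "galore_adamw",  # any of the galore_* names listed above
    "optim_args": {               # may also be written as "rank=128,update_proj_gap=200,scale=0.25"
        "rank": 128,              # type: int
        "update_proj_gap": 200,   # type: int
        "scale": 0.25,            # type: float
        "proj_type": "std",       # type: str, default = std
    },
    "optim_target_modules": ["self_attn", "mlp"],  # module names to optimize (llama example)
}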
cicd/Dockerfile.jinja CHANGED
@@ -23,9 +23,9 @@ RUN git fetch origin +$GITHUB_REF && \
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
docker/Dockerfile CHANGED
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.9.0
-transformers==4.38.2
+transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
 tokenizers==0.15.0
 bitsandbytes>=0.43.0
 accelerate==0.26.1
@@ -39,5 +39,5 @@ s3fs
 gcsfs
 # adlfs
 
-trl>=0.7.9
+trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
 fastcore>=1.5.29
setup.py CHANGED
@@ -89,5 +89,8 @@ setup(
         "lion-pytorch": [
             "lion-pytorch==0.1.2",
         ],
+        "galore": [
+            "galore_torch",
+        ],
     },
 )
src/axolotl/core/trainer_builder.py CHANGED
@@ -220,7 +220,7 @@ class AxolotlTrainer(Trainer):
         num_epochs=1,
         bench_data_collator=None,
         eval_data_collator=None,
-        **kwargs
+        **kwargs,
     ):
         self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
@@ -239,6 +239,7 @@ class AxolotlTrainer(Trainer):
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
+                opt_model,
             )
 
             loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -1150,6 +1151,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
         )
+        if self.cfg.optim_args:
+            if isinstance(self.cfg.optim_args, dict):
+                optim_args = ",".join(
+                    [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
+                )
+            else:
+                optim_args = self.cfg.optim_args
+            training_arguments_kwargs["optim_args"] = optim_args
+        if self.cfg.optim_target_modules:
+            training_arguments_kwargs[
+                "optim_target_modules"
+            ] = self.cfg.optim_target_modules
         training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
         training_arguments_kwargs[
             "loraplus_lr_embedding"
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -313,6 +313,15 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = None
     optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
+    optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
+        default=None, metadata={"help": "Optional arguments to supply to optimizer."}
+    )
+    optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
+        default=None,
+        metadata={
+            "help": "The target modules to optimize, i.e. the module names that you would like to train."
+        },
+    )
     torchdistx_path: Optional[str] = None
     lr_scheduler: Optional[SchedulerType] = None
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
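A reduced Pydantic sketch of the two new fields follows; the class name OptimizerSettingsSketch is made up for illustration and drops the help metadata, but the type unions mirror the ones added above, so optim_args validates as either a raw key=value string or a dict, and optim_target_modules as either a list of module names or the literal "all_linear".

from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field


class OptimizerSettingsSketch(BaseModel):
    # Same type unions as HyperparametersConfig above, without the help metadata.
    optim_args: Optional[Union[str, Dict[str, Any]]] = Field(default=None)
    optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
        default=None
    )


# Both spellings validate:
OptimizerSettingsSketch(optim_args={"rank": 128, "proj_type": "std"})
OptimizerSettingsSketch(
    optim_args="rank=128,proj_type=std",
    optim_target_modules="all_linear",
)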