winglian committed
Commit c0f50d9
1 Parent(s): 4e705ed

wire up gradient checkpointing for 4bit

Files changed (1)
  1. src/axolotl/utils/trainer.py +7 -1
src/axolotl/utils/trainer.py CHANGED
@@ -28,7 +28,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
     if cfg.gradient_checkpointing is not None:
-        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+        if cfg.load_4bit:
+            from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing
+            gradient_checkpointing_ratio = cfg.gradient_checkpointing_ratio if cfg.gradient_checkpointing_ratio else 1.0
+            apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
+        else:
+            training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+
 
     # deepspeed
     if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
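
For reference, a minimal standalone sketch of the new dispatch, assuming a namespace-style cfg. The wire_gradient_checkpointing wrapper and the SimpleNamespace config below are illustrative, not part of axolotl; the apply_gradient_checkpointing(model, checkpoint_ratio=...) call and the cfg field names are taken from the diff, and running the 4-bit path requires alpaca_lora_4bit to be installed:

from types import SimpleNamespace

def wire_gradient_checkpointing(cfg, model, training_arguments_kwargs):
    # Same branch as the commit: 4-bit loads use the alpaca_lora_4bit hook,
    # everything else defers to the HF Trainer's gradient_checkpointing flag.
    if cfg.gradient_checkpointing is not None:
        if cfg.load_4bit:
            from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing

            # Fall back to a ratio of 1.0 when none is configured.
            ratio = cfg.gradient_checkpointing_ratio or 1.0
            apply_gradient_checkpointing(model, checkpoint_ratio=ratio)
        else:
            training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing

# Hypothetical config exercising the 4-bit path with an explicit ratio.
cfg = SimpleNamespace(
    gradient_checkpointing=True,
    load_4bit=True,
    gradient_checkpointing_ratio=0.5,
)

The else split keeps the HF Trainer flag untouched for 4-bit loads, where checkpointing is applied directly to the model by alpaca_lora_4bit instead of through TrainingArguments.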