Jan Philipp Harries Jan Philipp Harries committed on
Commit
be75668
1 Parent(s): aeec7c4

set fsdp state dict (#584)

Browse files

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>

Files changed (1) hide show
  1. src/axolotl/train.py +4 -0
src/axolotl/train.py CHANGED
@@ -117,6 +117,10 @@ def train(
117
 
118
  LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
119
 
 
 
 
 
120
  if cfg.relora_steps:
121
  if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
122
  model = model.merge_and_unload()
 
117
 
118
  LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
119
 
120
+ if trainer.is_fsdp_enabled:
121
+ trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
122
+ LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
123
+
124
  if cfg.relora_steps:
125
  if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
126
  model = model.merge_and_unload()