import os

from transformers import (
    Seq2SeqTrainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that writes the model into each checkpoint folder.

    On every save event the model received through ``kwargs["model"]`` is
    written with ``save_pretrained`` to
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the model under the current step's checkpoint directory.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; only ``global_step`` is read.
            control: Trainer control object, handed back untouched.
            **kwargs: Must contain ``"model"``, an object exposing
                ``save_pretrained(path)`` (presumably a PEFT-wrapped model,
                going by the class name -- verify against the caller).

        Returns:
            The ``control`` object that was passed in.

        Raises:
            KeyError: If ``"model"`` is missing from ``kwargs``.
        """
        # Folder name follows the "<prefix>-<step>" pattern built from
        # PREFIX_CHECKPOINT_DIR and the current global step.
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        adapter_dir = os.path.join(args.output_dir, step_dir_name, "adapter_model")

        model = kwargs["model"]
        model.save_pretrained(adapter_dir)
        return control