"""Callbacks for Trainer class"""

import os
import re
import shutil

from optimum.bettertransformer import BetterTransformer
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

# Matches directory names of the form "<PREFIX_CHECKPOINT_DIR>-<step>".
_CHECKPOINT_DIR_RE = re.compile(rf"{re.escape(PREFIX_CHECKPOINT_DIR)}-(\d+)")


def _prune_old_checkpoints(output_dir, save_total_limit, protected=()):
    """Delete the oldest step-numbered checkpoint directories beyond a limit.

    Args:
        output_dir: Directory that contains the ``<prefix>-<step>`` folders.
        save_total_limit: Maximum number of checkpoint folders to keep.
            ``None`` or a non-positive value disables pruning entirely.
        protected: Paths that must never be removed (``None`` entries are
            ignored), e.g. the checkpoint that was just written.
    """
    if save_total_limit is None or save_total_limit <= 0:
        return

    try:
        entries = os.listdir(output_dir)
    except OSError:
        # Nothing to prune if the output directory cannot be listed.
        return

    numbered = []
    for entry in entries:
        match = _CHECKPOINT_DIR_RE.fullmatch(entry)
        path = os.path.join(output_dir, entry)
        if match and os.path.isdir(path):
            numbered.append((int(match.group(1)), path))

    excess = len(numbered) - save_total_limit
    if excess <= 0:
        return

    keep = {os.path.abspath(path) for path in protected if path}
    # Walk from the lowest step upwards so the oldest checkpoints go first.
    for _, path in sorted(numbered):
        if excess <= 0:
            break
        if os.path.abspath(path) in keep:
            continue
        # ignore_errors: this callback runs in every process that receives
        # the event, so another process may already have removed the folder.
        shutil.rmtree(path, ignore_errors=True)
        excess -= 1


class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
    """Callback to save the PEFT adapter"""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Write the adapter into ``adapter_model`` inside the step's checkpoint folder.

        The model is taken from ``kwargs["model"]``; ``control`` is returned
        unchanged.
        """
        checkpoint_folder = os.path.join(
            args.output_dir,
            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
        )

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        return control


class SaveBetterTransformerModelCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods
    """Callback to save the BetterTransformer wrapped model"""

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save a reversed (un-wrapped) copy of the model when a save is due.

        A save is due when ``control.should_save`` is already set, or when the
        step-based save strategy is active and ``state.global_step`` is a
        multiple of ``args.save_steps``. After saving, checkpoint folders in
        excess of ``args.save_total_limit`` are pruned and
        ``control.should_save`` is cleared. The (possibly modified) ``control``
        is returned.
        """
        # Save
        if (
            args.save_strategy == IntervalStrategy.STEPS
            and args.save_steps > 0
            and state.global_step % args.save_steps == 0
        ):
            control.should_save = True

        if control.should_save:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            )

            model = BetterTransformer.reverse(kwargs["model"])
            model.save_pretrained(checkpoint_folder)

            # should_save is cleared below, so the trainer's own save path does
            # not run for this step; apply save_total_limit here instead.
            _prune_old_checkpoints(
                args.output_dir,
                getattr(args, "save_total_limit", None),
                protected=(
                    checkpoint_folder,
                    getattr(state, "best_model_checkpoint", None),
                ),
            )

            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
            control.should_save = False
        return control