tmm1 committed
Commit 7181022
Parent: 176b888

simplify load_model signature

scripts/finetune.py CHANGED
@@ -252,14 +252,7 @@ def train(
 
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
-    model, peft_config = load_model(
-        cfg.base_model,
-        cfg.base_model_config,
-        cfg.model_type,
-        tokenizer,
-        cfg,
-        adapter=cfg.adapter,
-    )
+    model, peft_config = load_model(cfg, tokenizer)
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
         LOG.info("running merge of LoRA with base model")
src/axolotl/utils/models.py CHANGED
@@ -77,12 +77,15 @@ def load_tokenizer(
 
 
 def load_model(
-    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
-):
-    # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    cfg, tokenizer
+):  # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
-    Load a model from a base model and a model type.
+    Load a model for a given configuration and tokenizer.
     """
+    base_model = cfg.base_model
+    base_model_config = cfg.base_model_config
+    model_type = cfg.model_type
+    adapter = cfg.adapter
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
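
For reference, a minimal sketch of the new call convention, assuming a DictDefault-style cfg carrying the fields load_model now reads off it; the field values and the DictDefault import path are illustrative, not part of this commit:

# Sketch only: load_model now reads base_model, base_model_config,
# model_type, and adapter off cfg instead of taking them positionally.
from axolotl.utils.dict import DictDefault  # assumed import path
from axolotl.utils.models import load_model

cfg = DictDefault(
    {
        "base_model": "huggyllama/llama-7b",         # hypothetical values
        "base_model_config": "huggyllama/llama-7b",
        "model_type": "LlamaForCausalLM",
        "adapter": "lora",
        "load_in_8bit": True,
    }
)

# Tokenizer loading is unchanged by this commit, so its call is elided.
tokenizer = ...  # stand-in for the load_tokenizer(...) result
model, peft_config = load_model(cfg, tokenizer)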