tmm1 committed
Commit 11ddccb
2 Parent(s): 9643121 7181022

Merge pull request #356 from tmm1/load_model-args

scripts/finetune.py CHANGED
@@ -255,14 +255,7 @@ def train(
 
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
-    model, peft_config = load_model(
-        cfg.base_model,
-        cfg.base_model_config,
-        cfg.model_type,
-        tokenizer,
-        cfg,
-        adapter=cfg.adapter,
-    )
+    model, peft_config = load_model(cfg, tokenizer)
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
         LOG.info("running merge of LoRA with base model")
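
The call site shrinks from six arguments to two: the tokenizer, plus the config object that already carried every other value. A minimal sketch of the pattern, with a stand-in DictDefault (an assumption here; it mimics axolotl's attribute-access config dict, which returns None for missing keys):

class DictDefault(dict):
    """Stand-in (assumption) for axolotl's attribute-access config dict."""

    def __getattr__(self, key):
        return self.get(key)  # missing keys read as None instead of raising


cfg = DictDefault(
    base_model="huggyllama/llama-7b",  # illustrative values only
    base_model_config="huggyllama/llama-7b",
    model_type="LlamaForCausalLM",
    adapter="lora",
)

# Old call: load_model(cfg.base_model, cfg.base_model_config,
#                      cfg.model_type, tokenizer, cfg, adapter=cfg.adapter)
# New call: everything except the tokenizer travels inside cfg.
# model, peft_config = load_model(cfg, tokenizer)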
src/axolotl/utils/models.py CHANGED
@@ -78,12 +78,15 @@ def load_tokenizer(
 
 
 def load_model(
-    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
-):
-    # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    cfg, tokenizer
+):  # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
-    Load a model from a base model and a model type.
+    Load a model for a given configuration and tokenizer.
     """
+    base_model = cfg.base_model
+    base_model_config = cfg.base_model_config
+    model_type = cfg.model_type
+    adapter = cfg.adapter
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
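
The new signature keeps the Python 2 style "# type:" comment. For reference, the equivalent inline annotations would read as below; the DictDefault import path is an assumption, while the other names come straight from the type comment.

from typing import Optional, Tuple

from peft import PeftConfig
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from axolotl.utils.dict import DictDefault  # import path assumed


def load_model(
    cfg: DictDefault, tokenizer: PreTrainedTokenizerBase
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
    ...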