winglian commited on
Commit
1d21aa6
1 Parent(s): 71b7ea3

ensure merged model matches the training dtype (#902)

Browse files

* ensure merged model matches the training dtype

* Update src/axolotl/cli/__init__.py

* Update src/axolotl/cli/__init__.py

Files changed (1) hide show
  1. src/axolotl/cli/__init__.py +1 -1
src/axolotl/cli/__init__.py CHANGED
@@ -72,7 +72,7 @@ def do_merge_lora(
72
 
73
  LOG.info("running merge of LoRA with base model")
74
  model = model.merge_and_unload()
75
- model.to(dtype=torch.float16)
76
 
77
  if cfg.local_rank == 0:
78
  LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
 
72
 
73
  LOG.info("running merge of LoRA with base model")
74
  model = model.merge_and_unload()
75
+ model.to(dtype=cfg.torch_dtype)
76
 
77
  if cfg.local_rank == 0:
78
  LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")