winglian committed
Commit aef00b6
1 Parent(s): 0d28df0

fix torch_dtype for model load

Files changed (1):
  1. src/axolotl/utils/models.py (+6 -3)
src/axolotl/utils/models.py CHANGED

@@ -62,9 +62,12 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-    torch_dtype = (
-        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
-    )
+    if cfg.bf16:
+        torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
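For context, the bug this commit fixes: the old one-liner mapped cfg.bf16 to torch.float16, so bf16 configs silently loaded the model in the wrong dtype. Below is a minimal, self-contained sketch of the corrected precedence and where the resolved dtype would typically flow next. The SimpleNamespace config and the model id are illustrative stand-ins, not taken from the repo; the from_pretrained call is a common usage pattern, not necessarily axolotl's exact loading code.

from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM

# Illustrative stand-in for axolotl's cfg object (hypothetical values).
cfg = SimpleNamespace(bf16=True, fp16=False, load_in_8bit=False)

# Corrected precedence: bf16 wins, 8-bit/fp16 fall back to float16,
# and float32 is the default. Before this commit, cfg.bf16 resolved
# to torch.float16.
if cfg.bf16:
    torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16:
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

# Hypothetical usage: pass the resolved dtype through to the loader.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder model id
    torch_dtype=torch_dtype,
)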