winglian commited on
Commit
69164da
1 Parent(s): e107643

imrpove llama check and fix safetensors file check

Browse files
Files changed (1) hide show
  1. scripts/finetune.py +2 -4
scripts/finetune.py CHANGED
@@ -85,14 +85,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
85
  raise e
86
 
87
  try:
88
- if cfg.load_4bit and "llama" in base_model:
89
  from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
90
  from huggingface_hub import snapshot_download
91
 
92
  cache_model_path = Path(snapshot_download(base_model))
93
- # TODO search .glob for a .pt, .safetensor, or .bin
94
- cache_model_path.glob("*.pt")
95
- files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
96
  if len(files) > 0:
97
  model_path = str(files[0])
98
  else:
 
85
  raise e
86
 
87
  try:
88
+ if cfg.load_4bit and "llama" in base_model or "llama" in cfg.model_type.lower():
89
  from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
90
  from huggingface_hub import snapshot_download
91
 
92
  cache_model_path = Path(snapshot_download(base_model))
93
+ files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
 
 
94
  if len(files) > 0:
95
  model_path = str(files[0])
96
  else: