winglian committed
Commit eb80890
1 Parent(s): 3f3f561

fix llama check

Files changed (1)
  1. scripts/finetune.py +12 -4
scripts/finetune.py CHANGED
@@ -60,12 +60,14 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
+    is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
 
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
-    if "llama" in base_model and cfg.flash_attention:
+    if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
 
     torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
@@ -85,7 +87,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
         raise e
 
     try:
-        if cfg.load_4bit and ("llama" in base_model or "llama" in cfg.model_type.lower()):
+        if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
 
@@ -104,7 +106,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
                 is_v1_model=cfg.gptq_model_v1 if cfg.gptq_model_v1 is not None else True,
             )
             load_in_8bit = False
-        elif "llama" in base_model:
+        elif is_llama_derived_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
@@ -128,13 +130,18 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
 
     if not tokenizer:
         try:
-            if "llama" in base_model:
+            if is_llama_derived_model:
                 tokenizer = LlamaTokenizer.from_pretrained(model)
             else:
                 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
         except:
             tokenizer = AutoTokenizer.from_pretrained(base_model)
 
+    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
@@ -144,6 +151,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
     if load_in_8bit and not cfg.load_4bit:
+        logging.info("converting model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
     lora_config = LoraConfig(
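
Context for the "fix llama check" change: before this commit, the flash-attention, model-loading, and tokenizer branches each tested only "llama" in base_model, so a llama-derived checkpoint whose path does not contain the word "llama" took the generic code path. The commit consolidates the test into a single is_llama_derived_model flag that also inspects cfg.model_type. The snippet below is a minimal, standalone sketch (not code from the commit) of how that expression evaluates; the Cfg dataclass and the example model names are illustrative assumptions, not values from the repository.

# Minimal sketch of the consolidated check (illustrative only, not from the commit).
# The Cfg dataclass and example model names are assumptions for this example;
# in axolotl the values come from the training YAML via `cfg`.
from dataclasses import dataclass


@dataclass
class Cfg:
    model_type: str


def is_llama_derived(base_model: str, cfg: Cfg) -> bool:
    # Same expression the commit assigns to `is_llama_derived_model`:
    # match on the checkpoint path *or* on the configured model_type.
    return "llama" in base_model or "llama" in cfg.model_type.lower()


# Paths containing "llama" matched before and still match.
print(is_llama_derived("huggyllama/llama-7b", Cfg("LlamaForCausalLM")))        # True
# A llama-derived checkpoint whose path lacks "llama" is now caught via model_type;
# the old `"llama" in base_model` check alone would have missed it.
print(is_llama_derived("some-org/vicuna-7b", Cfg("LlamaForCausalLM")))         # True
# Non-llama models are unaffected.
print(is_llama_derived("EleutherAI/pythia-1b", Cfg("GPTNeoXForCausalLM")))     # False

As in the pre-existing 4-bit branch, the expression assumes cfg.model_type is a string; a config that omits model_type would fail on the .lower() call.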