winglian committed
Commit 42410c7
1 Parent(s): aef00b6

more fixes

Files changed (1): src/axolotl/utils/models.py (+6 -2)
src/axolotl/utils/models.py CHANGED
@@ -184,7 +184,8 @@ def load_model(
     for k, v in cfg.tokens.items():
         tokenizer.add_special_tokens({k: v})
 
-    model.resize_token_embeddings(len(tokenizer))
+    # this should only be needed if you are messing with new tokens in the vocab
+    # model.resize_token_embeddings(len(tokenizer))
 
     if cfg.adapter and load_in_8bit and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
@@ -207,7 +208,10 @@ def load_model(
                 m.scales = m.scales.half()
                 m.bias = m.bias.half()
 
-    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1 and cfg.load_4bit:
+        # llama is PROBABLY model parallelizable, but the default isn't that it is
+        # so let's only set it for the 4bit, see
+        # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
         model.is_parallelizable = True
         model.model_parallel = True
 
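Note on the first hunk: the new comment points out that resizing token embeddings is only needed when tokens are actually added to the vocabulary. Below is a minimal sketch of that case, assuming the standard transformers API; the checkpoint name and the num_added guard are illustrative and not part of this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "huggyllama/llama-7b"  # hypothetical checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Mirrors the cfg.tokens loop in load_model(): add_special_tokens returns the
# number of tokens that were actually new to the vocabulary.
num_added = tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Only grow the embedding matrix when the vocab actually changed; when no new
# tokens are added the resize is unnecessary, which is why the commit comments it out.
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))

The second hunk is a narrower change: the is_parallelizable / model_parallel flags are now set only for multi-GPU 4-bit runs, following the alpaca_lora_4bit finetune script linked in the in-code comment.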