winglian committed
Commit 964d858
1 Parent(s): 10388a8

fix model parallel (#816)

Files changed (1)
  1. src/axolotl/utils/models.py +1 -8
src/axolotl/utils/models.py CHANGED
```diff
@@ -442,14 +442,7 @@ def load_model(
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
-    if (
-        torch.cuda.device_count() > 1
-        and int(os.getenv("WORLD_SIZE", "1")) > 1
-        and (cfg.load_in_4bit)
-    ):
-        # llama is PROBABLY model parallelizable, but the default isn't that it is
-        # so let's only set it for the 4bit, see
-        # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
         setattr(model, "is_parallelizable", True)
         setattr(model, "model_parallel", True)
 
```
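The fix flips the WORLD_SIZE check: the flags are now set whenever multiple GPUs are visible to a single process, and never under DDP (WORLD_SIZE > 1), where each rank owns a full model replica and marking it model-parallel is wrong. A minimal sketch of the resulting behavior, assuming a stock transformers model; the checkpoint name here is a hypothetical stand-in, not something from the commit:

```python
import os

import torch
from transformers import AutoModelForCausalLM

# Hypothetical stand-in for whatever base model the axolotl config points at.
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

# Same condition as the patched load_model(): several visible GPUs but a
# single process. DDP launchers (e.g. torchrun) export WORLD_SIZE > 1, so
# this branch is skipped for distributed runs.
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
    # When both attributes are truthy, the HF Trainer treats the model as
    # already spread across GPUs (naive model parallelism) and skips
    # wrapping it in DataParallel.
    setattr(model, "is_parallelizable", True)
    setattr(model, "model_parallel", True)
```

This is why the old condition misfired: by requiring WORLD_SIZE > 1 (plus cfg.load_in_4bit), it enabled the model-parallel flags precisely in the distributed case where they do not apply, and never in the single-process multi-GPU case where they do.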