Bryan Thornbury winglian committed on
Commit
992e742
1 Parent(s): a1da39c

Support device_map=sequential & max_memory config parameters (#903)

Browse files

* Support device_map sequential (and others). Support max_memory in cfg.

* Update documentation in README accordingly.

* Update README.md

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

README.md CHANGED
@@ -612,6 +612,12 @@ eval_sample_packing:
612
  sample_packing_eff_est:
613
  total_num_tokens:
614
 
 
 
 
 
 
 
615
  # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
616
  adapter: lora
617
  # If you already have a lora model trained that you want to load, put that here.
 
612
  sample_packing_eff_est:
613
  total_num_tokens:
614
 
615
+ # Passed through to transformers when loading the model when launched without accelerate
616
+ # Use `sequential` when training w/ model parallelism to limit memory
617
+ device_map:
618
+ # Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
619
+ max_memory:
620
+
621
  # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
622
  adapter: lora
623
  # If you already have a lora model trained that you want to load, put that here.
src/axolotl/utils/config.py CHANGED
@@ -27,7 +27,7 @@ def choose_device(cfg):
27
 
28
  cfg.device = get_device()
29
  if cfg.world_size == 1:
30
- cfg.device_map = "auto"
31
  else:
32
  if cfg.device.startswith("cuda"):
33
  cfg.device_map = {"": torch.cuda.current_device()}
 
27
 
28
  cfg.device = get_device()
29
  if cfg.world_size == 1:
30
+ cfg.device_map = cfg.device_map or "auto"
31
  else:
32
  if cfg.device.startswith("cuda"):
33
  cfg.device_map = {"": torch.cuda.current_device()}
src/axolotl/utils/models.py CHANGED
@@ -216,6 +216,7 @@ def load_model(
216
  model_kwargs = {}
217
 
218
  model_kwargs["device_map"] = cfg.device_map
 
219
  model_kwargs["torch_dtype"] = cfg.torch_dtype
220
 
221
  if cfg.model_revision:
 
216
  model_kwargs = {}
217
 
218
  model_kwargs["device_map"] = cfg.device_map
219
+ model_kwargs["max_memory"] = cfg.max_memory
220
  model_kwargs["torch_dtype"] = cfg.torch_dtype
221
 
222
  if cfg.model_revision: