Commit 9190ada by winglian
Parent: 4dbef09

8bit and deepspeed changes

Files changed (2):
  1. ds_config.json (+5 -3)
  2. src/axolotl/utils/models.py (+6 -13)
ds_config.json CHANGED
@@ -20,10 +20,12 @@
     }
   },
   "scheduler": {
-    "type": "OneCycle",
+    "type": "WarmupDecayLR",
     "params": {
-      "cycle_min_lr": 1e-7,
-      "cycle_max_lr": 1e-4
+      "warmup_min_lr": "auto",
+      "warmup_max_lr": "auto",
+      "warmup_num_steps": "auto",
+      "total_num_steps": "auto"
     }
   },
   "zero_optimization": {
src/axolotl/utils/models.py CHANGED
@@ -101,19 +101,12 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
-        if not cfg.load_in_8bit:
-            model = LlamaForCausalLM.from_pretrained(
-                base_model,
-                device_map=cfg.device_map,
-            )
-        else:
-            model = LlamaForCausalLM.from_pretrained(
-                base_model,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
-                device_map=cfg.device_map,
-            )
-
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            torch_dtype=torch_dtype,
+            device_map=cfg.device_map,
+        )
     elif model_type:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
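
The net effect of the models.py change is that the two Llama load paths collapse into a single call, with 8-bit loading gated on an adapter being configured. A minimal sketch of the resulting behavior follows; the load_llama wrapper and its signature are illustrative, not part of the commit:

  import torch
  from transformers import LlamaForCausalLM

  def load_llama(cfg, base_model: str, torch_dtype=torch.float16):
      """Illustrative wrapper mirroring the consolidated load path."""
      # 8-bit loading (via bitsandbytes) is only requested when an adapter
      # such as LoRA is configured; otherwise the model loads at torch_dtype.
      return LlamaForCausalLM.from_pretrained(
          base_model,
          load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
          torch_dtype=torch_dtype,
          device_map=cfg.device_map,
      )

Consolidating the branches also removes the earlier asymmetry where the non-8-bit path silently ignored torch_dtype.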