winglian committed
Commit 3b4d055
Parent: 2ae936f

integrate qlora? maybe?

Files changed (2):
  1. requirements.txt +1 -1
  2. src/axolotl/utils/models.py +32 -2
requirements.txt CHANGED
@@ -1,10 +1,10 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
 accelerate>=0.19.0
 sentencepiece
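The pin on bitsandbytes==0.37.2 moves to a build from git because 4-bit (NF4) quantization had not landed in a tagged bitsandbytes release at the time. A minimal sanity check for the installed build (the Linear4bit attribute check is an assumption about how 4-bit support is exposed, not something this repo does):

import bitsandbytes as bnb

print("bitsandbytes version:", bnb.__version__)
# Builds with QLoRA support expose 4-bit layers; the 0.37.2 release does not.
assert hasattr(bnb.nn, "Linear4bit"), "this bitsandbytes build lacks 4-bit support"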
src/axolotl/utils/models.py CHANGED
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
 
 import torch
 import transformers
+from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig,
+    AutoConfig, BitsAndBytesConfig,
 )
 
 try:
@@ -81,6 +82,16 @@ def load_model(
         logging.exception(e)
         raise e
 
+    model_kwargs = {}
+    if cfg.adapter == "qlora":
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
     try:
         if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
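The quantization_config built above is what transformers' from_pretrained consumes to load the base weights in 4-bit NF4 with double quantization and fp16 compute. A minimal standalone sketch of the same settings (the checkpoint name is illustrative, not taken from this repo):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4-bit at load time
    bnb_4bit_quant_type="nf4",             # NormalFloat4, the QLoRA data type
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16,  # matmuls run in fp16
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                 # illustrative checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)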
@@ -125,6 +136,7 @@ def load_model(
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
+                **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
         #     This is a WIP, still an issue with the backward pass
@@ -159,6 +171,7 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
@@ -172,6 +185,7 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
     except Exception as e:
         logging.error(
@@ -184,8 +198,24 @@ def load_model(
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
         )
 
+    """### Post-processing on the model
+    Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
+    """
+    if cfg.adapter == "qlora":
+        for param in model.parameters():
+            param.requires_grad = False  # freeze the model - train adapters later
+            if param.ndim == 1:
+                # cast the small parameters (e.g. layernorm) to fp32 for stability
+                param.data = param.data.to(torch.float32)
+        class CastOutputToFloat(nn.Sequential):
+            def forward(self, x):
+                return super().forward(x).to(torch.float32)
+
+        model.lm_head = CastOutputToFloat(model.lm_head)
+
     if not tokenizer:
         try:
             if is_llama_derived_model and "LlamaTokenizer" in globals():
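The freeze-and-upcast loop added here mirrors the PEFT int8 notebook (note the docstring still says "8-bit" even though this path loads 4-bit weights). Recent peft builds ship a helper that does essentially the same preparation; a hedged sketch, assuming a peft version that provides prepare_model_for_kbit_training:

from peft import prepare_model_for_kbit_training

# Freezes the base weights and upcasts the remaining non-quantized parameters
# (e.g. layer norms) to fp32, covering the same stability concerns as the
# manual loop in this hunk; for k-bit models it also enables gradient
# checkpointing by default.
model = prepare_model_for_kbit_training(model)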
@@ -270,7 +300,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
         return model, None
-    if adapter == "lora":
+    if adapter == "lora" or adapter == "qlora":
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
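Routing both "lora" and "qlora" through load_lora makes sense because QLoRA is ordinary LoRA adapters trained on top of the 4-bit base weights. For illustration only, the wrapping that a load_lora-style helper performs with peft looks roughly like this (hyperparameter values are examples, not the ones axolotl reads from its config):

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                                 # adapter rank (example value)
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # typical LLaMA attention projections
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights remain trainable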
 