Nanobit and hamel committed
Commit 41353d2
Parent: f6ecf14

feat: expose bnb kwargs (#1018)


* feat: expose bnb kwargs

* chore: added examples and link per suggestion

* Uncomment defaults per suggestion for readability

Co-authored-by: Hamel Husain <hamel.husain@gmail.com>

---------

Co-authored-by: Hamel Husain <hamel.husain@gmail.com>

Files changed (2):
  1. README.md +8 -0
  2. src/axolotl/utils/models.py +13 -6
README.md CHANGED
@@ -520,6 +520,14 @@ model_config:
   type: # linear | dynamic
   factor: # float
 
+# optional overrides to the bnb 4bit quantization configuration
+# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
+bnb_config_kwargs:
+  # These are default values
+  llm_int8_has_fp16_weight: false
+  bnb_4bit_quant_type: nf4
+  bnb_4bit_use_double_quant: true
+
 
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
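
The YAML keys above map one-to-one onto keyword arguments of transformers' BitsAndBytesConfig, so any key that class accepts can be passed through bnb_config_kwargs. A minimal sketch of the equivalent construction, assuming a bfloat16 compute dtype (axolotl actually derives it from cfg.torch_dtype) and a hypothetical fp4 override to illustrate a non-default value:

import torch
from transformers import BitsAndBytesConfig

# README defaults, with bnb_4bit_quant_type hypothetically overridden from
# "nf4" to "fp4"; the bfloat16 compute dtype is an assumption for this sketch.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_use_double_quant=True,
)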
src/axolotl/utils/models.py CHANGED
@@ -301,13 +301,20 @@ def load_model(
             **model_config.quantization_config
         )
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
+        bnb_config = {
+            "load_in_4bit": True,
+            "llm_int8_threshold": 6.0,
+            "llm_int8_has_fp16_weight": False,
+            "bnb_4bit_compute_dtype": cfg.torch_dtype,
+            "bnb_4bit_use_double_quant": True,
+            "bnb_4bit_quant_type": "nf4",
+        }
+
+        if cfg.bnb_config_kwargs:
+            bnb_config.update(cfg.bnb_config_kwargs)
+
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
-            load_in_4bit=True,
-            llm_int8_threshold=6.0,
-            llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=cfg.torch_dtype,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
+            **bnb_config,
         )
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
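
To make the override semantics concrete, here is a self-contained sketch of the merge logic introduced above; the user_overrides dict is a hypothetical stand-in for cfg.bnb_config_kwargs, and torch.bfloat16 stands in for cfg.torch_dtype:

import torch
from transformers import BitsAndBytesConfig

# Hard-coded defaults, as in load_model above.
bnb_config = {
    "load_in_4bit": True,
    "llm_int8_threshold": 6.0,
    "llm_int8_has_fp16_weight": False,
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
}

# dict.update overwrites duplicate keys, so user-supplied kwargs always win
# over the defaults on collision.
user_overrides = {"bnb_4bit_use_double_quant": False}
bnb_config.update(user_overrides)

quantization_config = BitsAndBytesConfig(**bnb_config)
assert quantization_config.bnb_4bit_use_double_quant is False

Spelling the defaults out in the dict, rather than leaning on BitsAndBytesConfig's own defaults, appears to be a deliberate readability choice; see the "Uncomment defaults per suggestion for readability" note in the commit message above.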