Nanobit committed
Commit 1ffa386
1 Parent(s): 62ba160

Feat: Warns to add to modules_to_save when adding tokens or switching special_tokens (#787)

* Feat: Auto add to modules_to_save when adding tokens

* fix: swap to error instead of warning

* feat: add check when special_tokens differ and add test
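For downstream users, the practical fix is to list both embedding-side modules in lora_modules_to_save whenever an adapter run adds new tokens or changes a special token. Below is a minimal sketch of a config that the new validation accepts, written with the same DictDefault helper the added tests use; the token value is only an example, and the validate_config import path is an assumption based on the file touched in this commit.

# Sketch: a qlora config that passes the new validate_config check.
from axolotl.utils.config import validate_config  # assumed import path
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "adapter": "qlora",
        "load_in_4bit": True,
        "tokens": ["<|imstart|>"],
        # both entries are required once new tokens are added with an adapter
        "lora_modules_to_save": ["embed_tokens", "lm_head"],
    }
)
validate_config(cfg)  # raises ValueError if either module is missing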

src/axolotl/utils/config.py CHANGED
@@ -448,6 +448,20 @@ def validate_config(cfg):
     if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
         raise ValueError("neftune_noise_alpha must be > 0.0")
 
+    if (
+        cfg.adapter
+        and cfg.tokens
+        and (
+            not cfg.lora_modules_to_save
+            or not all(
+                x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
+            )
+        )
+    ):
+        raise ValueError(
+            "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
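The condition above only accepts lora_modules_to_save when both embed_tokens and lm_head are present; ordering and extra entries do not matter. A standalone illustration of just that membership check (a hypothetical helper, not axolotl code):

# Hypothetical helper mirroring the validation condition above.
def modules_to_save_ok(lora_modules_to_save):
    return bool(lora_modules_to_save) and all(
        x in lora_modules_to_save for x in ["embed_tokens", "lm_head"]
    )

print(modules_to_save_ok(None))                         # False -> ValueError
print(modules_to_save_ok(["embed_tokens"]))             # False -> ValueError
print(modules_to_save_ok(["lm_head", "embed_tokens"]))  # True  -> passes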
src/axolotl/utils/models.py CHANGED
@@ -136,6 +136,23 @@ def load_tokenizer(cfg):
 
     if cfg.special_tokens:
         for k, val in cfg.special_tokens.items():
+            # check if new special token is not already in tokenizer and
+            # is adapter training to make sure lora_modules_to_save is set
+            if (
+                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
+                and cfg.adapter
+                and (
+                    not cfg.lora_modules_to_save
+                    or not all(
+                        x in cfg.lora_modules_to_save
+                        for x in ["embed_tokens", "lm_head"]
+                    )
+                )
+            ):
+                raise ValueError(
+                    "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
+                )
+
             tokenizer.add_special_tokens(
                 {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
            )
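The getattr(tokenizer, k) comparison means the error only fires when the configured value actually differs from the tokenizer's current special token. A small illustration with the Hugging Face tokenizer the tests use (a sketch assuming the tokenizer files can be downloaded):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
print(tokenizer.bos_token)  # "<s>", so special_tokens: {bos_token: "<s>"} is a no-op
# Configuring bos_token as "[INST]" differs from the current value, so with an
# adapter set the ValueError above is raised unless lora_modules_to_save
# contains both "embed_tokens" and "lm_head".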
tests/test_tokenizers.py CHANGED
@@ -3,6 +3,8 @@ Test cases for the tokenizer loading
 """
 import unittest
 
+import pytest
+
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_tokenizer
 
@@ -31,6 +33,40 @@ class TestTokenizers(unittest.TestCase):
         tokenizer = load_tokenizer(cfg)
         assert "Fast" not in tokenizer.__class__.__name__
 
+    def test_special_tokens_modules_to_save(self):
+        # setting special_tokens to new token
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        with pytest.raises(
+            ValueError,
+            match=r".*Please set lora_modules_to_save*",
+        ):
+            load_tokenizer(cfg)
+
+        # setting special_tokens but not changing from default
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "<s>"},
+            }
+        )
+        load_tokenizer(cfg)
+
+        # non-adapter setting special_tokens
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        load_tokenizer(cfg)
+
 
 if __name__ == "__main__":
     unittest.main()
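Note that the new tokenizer test loads the huggyllama/llama-7b tokenizer, so it needs network access or a warm Hugging Face cache; it can be run in isolation with, for example, pytest tests/test_tokenizers.py -k test_special_tokens_modules_to_save.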
tests/test_validation.py CHANGED
@@ -682,6 +682,43 @@ class ValidationTest(unittest.TestCase):
 
         validate_config(cfg)
 
+    def test_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        validate_config(cfg)
+
 
 class ValidationWandbTest(ValidationTest):
     """