winglian committed
Commit 827ec3d · 1 Parent(s): 8b79ff0

refactor neft patch to be more re-usable similar to trl's impl (#796)

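The refactor keeps the NEFTune behaviour from https://arxiv.org/abs/2310.05914 (uniform noise added to the input embeddings during training only) but moves it out of the per-architecture monkeypatches and into trainer hooks, similar to trl's implementation. The noise scale is unchanged between the removed and the new code; a rough sketch of what it works out to, using an illustrative sequence length and hidden size that are not part of this commit:

import math

# NEFT noise bound used by both the removed patches and the new neft_embeddings module:
# values are drawn from Uniform(-mag_norm, mag_norm) with mag_norm = alpha / sqrt(seq_len * hidden_dim)
alpha, seq_len, hidden_dim = 5, 512, 4096  # illustrative numbers only
mag_norm = alpha / math.sqrt(seq_len * hidden_dim)
print(f"{mag_norm:.5f}")  # ~0.00345 for these dimensions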
gitbook/README.md CHANGED
@@ -1,2 +1 @@
  # Page
-
src/axolotl/monkeypatch/llama_embeddings_hijack.py DELETED
@@ -1,40 +0,0 @@
- """
- patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
- """
-
- import torch
- import transformers.models.llama.modeling_llama
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
-
- def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
-     # pylint: disable=duplicate-code
-     def noised_embed(orig_embed, noise_alpha, model):
-         def new_func(input_ids):
-             # during training, we add noise to the embedding
-             # during generation, we don't add noise to the embedding
-             if model.training:
-                 embed_init = orig_embed(input_ids)
-                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
-                 mag_norm = noise_alpha / torch.sqrt(dims)
-                 return embed_init + torch.zeros_like(embed_init).uniform_(
-                     -mag_norm, mag_norm
-                 )
-             return orig_embed(input_ids)
-
-         return new_func
-
-     def post_init(orig_post_init):
-         def new_func(self):
-             orig_post_init(self)
-             self.embed_tokens.forward = noised_embed(
-                 self.embed_tokens.forward, noise_alpha, self
-             )
-
-         return new_func
-
-     transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
-         transformers.models.llama.modeling_llama.LlamaModel.post_init
-     )
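For context, this removed patch (and the identical Mistral one below) rewires LlamaModel.post_init at the class level, so it has to be applied before the model is built, it affects every Llama model constructed afterwards, and there is no way to undo it. A simplified sketch of how load_model used to apply it (see the src/axolotl/utils/models.py hunk at the end of this diff):

# old, class-level approach removed in this commit (simplified from the deleted
# load_model code): patch first, then build the model
from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)
# every LlamaModel constructed after this point adds uniform noise to its input
# embeddings while training; there is no matching "unpatch" helper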
src/axolotl/monkeypatch/mistral_embeddings_hijack.py DELETED
@@ -1,40 +0,0 @@
- """
- patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
- """
-
- import torch
- import transformers.models.mistral.modeling_mistral
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
-
- def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
-     # pylint: disable=duplicate-code
-     def noised_embed(orig_embed, noise_alpha, model):
-         def new_func(input_ids):
-             # during training, we add noise to the embedding
-             # during generation, we don't add noise to the embedding
-             if model.training:
-                 embed_init = orig_embed(input_ids)
-                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
-                 mag_norm = noise_alpha / torch.sqrt(dims)
-                 return embed_init + torch.zeros_like(embed_init).uniform_(
-                     -mag_norm, mag_norm
-                 )
-             return orig_embed(input_ids)
-
-         return new_func
-
-     def post_init(orig_post_init):
-         def new_func(self):
-             orig_post_init(self)
-             self.embed_tokens.forward = noised_embed(
-                 self.embed_tokens.forward, noise_alpha, self
-             )
-
-         return new_func
-
-     transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
-         transformers.models.mistral.modeling_mistral.MistralModel.post_init
-     )
src/axolotl/monkeypatch/neft_embeddings.py ADDED
@@ -0,0 +1,65 @@
+ """
+ patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
+ """
+ import torch
+ from peft import PeftModel
+ from transformers import PreTrainedModel
+
+
+ def patch_neft(alpha, model):
+     embeddings = None
+     if isinstance(model, PreTrainedModel):
+         embeddings = model.get_input_embeddings()
+     if isinstance(model, PeftModel):
+         embeddings = model.base_model.get_input_embeddings()
+     if not embeddings:
+         raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
+     embeddings.noisy_embedding_alpha = alpha
+     old_forward = embeddings.forward
+
+     # This hack seems to be needed to properly use a custom forward pass
+     # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
+     bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
+         embeddings, embeddings.__class__
+     )
+     setattr(embeddings, "forward", bound_method)
+
+     embeddings._old_forward = old_forward  # pylint: disable=protected-access
+     return model
+
+
+ def unpatch_neft(model):
+     embeddings = None
+     if isinstance(model, PreTrainedModel):
+         embeddings = model.get_input_embeddings()
+     if isinstance(model, PeftModel):
+         embeddings = model.base_model.get_input_embeddings()
+     if not embeddings:
+         raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
+     if hasattr(embeddings, "_old_forward"):
+         embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
+         del embeddings._old_forward  # pylint: disable=protected-access
+         del embeddings.noisy_embedding_alpha
+
+
+ def neft_forward(self, inputs: torch.Tensor):
+     embeddings = self._old_forward(inputs)  # pylint: disable=protected-access
+
+     if self.training:
+         dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
+         mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
+         embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
+             -mag_norm, mag_norm
+         )
+
+     return embeddings
+
+
+ def pretrain_hook(cfg, trainer):
+     if cfg.noisy_embedding_alpha:
+         trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
+
+
+ def post_train_hook(cfg, trainer):
+     if cfg.noisy_embedding_alpha:
+         unpatch_neft(trainer.model)
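A minimal sketch of how the new helpers can be exercised outside the trainer; the checkpoint name is illustrative, and per the isinstance checks above any PreTrainedModel (or PeftModel) should work:

import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.neft_embeddings import patch_neft, unpatch_neft

# illustrative tiny checkpoint so the sketch is cheap to run
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model = patch_neft(5, model)  # alpha=5, same default the removed patches used

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))

model.train()
noisy = model.get_input_embeddings()(input_ids)  # training mode: uniform noise is added

model.eval()
clean = model.get_input_embeddings()(input_ids)  # eval mode: plain embedding lookup

unpatch_neft(model)  # restores the original embedding forward and removes the alpha attribute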
src/axolotl/train.py CHANGED
@@ -16,6 +16,7 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 
  from axolotl.common.cli import TrainerCliArgs
  from axolotl.logging_config import configure_logging
+ from axolotl.monkeypatch import neft_embeddings
  from axolotl.utils.dict import DictDefault
  from axolotl.utils.models import load_model, load_tokenizer
  from axolotl.utils.trainer import setup_trainer
@@ -107,6 +108,7 @@ def train(
      if cfg.group_by_length:
          LOG.info("hang tight... sorting dataset for group_by_length")
 
+     pretrain_hooks(cfg, trainer)
      if cfg.flash_optimum:
          with torch.backends.cuda.sdp_kernel(
              enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -114,6 +116,7 @@ def train(
              trainer.train(resume_from_checkpoint=resume_from_checkpoint)
      else:
          trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+     post_train_hooks(cfg, trainer)
 
      LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
@@ -163,3 +166,23 @@ def train(
      trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
 
      return model, tokenizer
+
+
+ def pretrain_hooks(cfg, trainer):
+     """
+     Run hooks right before kicking off the training
+     :param cfg:
+     :param trainer:
+     :return:
+     """
+     neft_embeddings.pretrain_hook(cfg, trainer)
+
+
+ def post_train_hooks(cfg, trainer):
+     """
+     Run hooks right after training completes
+     :param cfg:
+     :param trainer:
+     :return:
+     """
+     neft_embeddings.post_train_hook(cfg, trainer)
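The hooks are gated on the existing noisy_embedding_alpha config key, and unpatching runs before the model is saved, so the checkpoint keeps the original embedding forward rather than the noisy wrapper. A hedged sketch of the same flow outside axolotl's train() (the config values, the tiny checkpoint, and the SimpleNamespace stand-in for the trainer are all illustrative):

from types import SimpleNamespace

from transformers import AutoModelForCausalLM

from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"noisy_embedding_alpha": 5})  # the flag both hooks key off
trainer = SimpleNamespace(  # stand-in for the HF Trainer used in train()
    model=AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
)

neft_embeddings.pretrain_hook(cfg, trainer)    # patches trainer.model's input embeddings
# ... trainer.train(resume_from_checkpoint=...) would run here ...
neft_embeddings.post_train_hook(cfg, trainer)  # unpatches before the model would be saved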
src/axolotl/utils/models.py CHANGED
@@ -180,26 +180,6 @@ def load_model(
          LOG.info("patching with flash attention")
          replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-     if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
-         from axolotl.monkeypatch.llama_embeddings_hijack import (
-             replace_llama_embeddings_with_uniform_distribution,
-         )
-
-         LOG.info("patching with noisy embeddings")
-         replace_llama_embeddings_with_uniform_distribution(
-             noise_alpha=cfg.noisy_embedding_alpha
-         )
-
-     if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
-         from axolotl.monkeypatch.mistral_embeddings_hijack import (
-             replace_mistral_embeddings_with_uniform_distribution,
-         )
-
-         LOG.info("patching with noisy embeddings")
-         replace_mistral_embeddings_with_uniform_distribution(
-             noise_alpha=cfg.noisy_embedding_alpha
-         )
-
      if cfg.is_llama_derived_model and cfg.xpos_rope:
          from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
              replace_llama_rope_with_xpos_rope,