winglian committed
Commit 32e6fe9
1 Parent(s): bbfc333

load the tokenizer separately from the model

Files changed (2)
  1. scripts/finetune.py +21 -12
  2. src/axolotl/utils/models.py +41 -42
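
The net effect of the change: train() now loads the tokenizer once up front and passes it into load_model, which no longer returns one. A minimal sketch of the new call order (function and argument names are taken from the diffs below; cfg is assumed to be the already-validated axolotl config object):

# Sketch only: mirrors the new call sites in scripts/finetune.py shown below.
from axolotl.utils.models import load_model, load_tokenizer

tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)

# datasets are tokenized with the already-loaded tokenizer
train_dataset, eval_dataset = load_prepare_datasets(
    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)

# load_model now receives the tokenizer and returns only (model, peft_config)
model, peft_config = load_model(
    cfg.base_model,
    cfg.base_model_config,
    cfg.model_type,
    tokenizer,
    cfg,
    adapter=cfg.adapter,
)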
scripts/finetune.py CHANGED
@@ -21,7 +21,7 @@ src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
 from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.models import load_model
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars
 
@@ -161,13 +161,30 @@ def train(
 
     validate_config(cfg)
 
+    # load the tokenizer first
+    logging.info("loading tokenizer...")
+    tokenizer = load_tokenizer(
+        cfg.base_model_config,
+        cfg.tokenizer_type,
+        cfg
+    )
+
+    if "inference" not in kwargs and "shard" not in kwargs:  # don't need to load dataset for these
+        train_dataset, eval_dataset = load_prepare_datasets(
+            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+        )
+
+    if prepare_ds_only:
+        logging.info("Finished preparing dataset. Exiting...")
+        return
+
     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and peft_config...")
-    model, tokenizer, peft_config = load_model(
+    logging.info("loading model and peft_config...")
+    model, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
-        cfg.tokenizer_type,
+        tokenizer,
         cfg,
         adapter=cfg.adapter,
         inference=("inference" in kwargs),
@@ -192,10 +209,6 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return
 
-    train_dataset, eval_dataset = load_prepare_datasets(
-        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-    )
-
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
@@ -205,10 +218,6 @@ def train(
             tokenizer,
         )
 
-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
     model.config.use_cache = False
src/axolotl/utils/models.py CHANGED
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
 import bitsandbytes as bnb
 import torch
 import transformers
-from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -34,20 +33,56 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
 
+def load_tokenizer(
+    base_model_config,
+    tokenizer_type,
+    cfg,
+):
+    if tokenizer_type:
+        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
+            base_model_config,
+            trust_remote_code=True if cfg.trust_remote_code is True else False,
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_config,
+            trust_remote_code=True if cfg.trust_remote_code is True else False,
+        )
+
+    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
+
+    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    if cfg.special_tokens:
+        for k, v in cfg.special_tokens.items():
+            tokenizer.add_special_tokens({k: v})
+    if cfg.tokens:
+        tokenizer.add_tokens(list(cfg.tokens))
+
+    return tokenizer
+
+
 def load_model(
     base_model,
     base_model_config,
     model_type,
-    tokenizer_type,
+    tokenizer,
     cfg,
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    tokenizer = None
     is_llama_derived_model = "llama" in base_model or (
         cfg.model_type and "llama" in cfg.model_type.lower()
     )
@@ -122,7 +157,7 @@ def load_model(
             model_path = str(cache_model_path)
         except:
             model_path = cfg.base_model
-        model, tokenizer = load_llama_model_4bit_low_ram(
+        model, _ = load_llama_model_4bit_low_ram(
            base_model_config if base_model_config else base_model,
            model_path,
            device_map=cfg.device_map,
@@ -207,42 +242,6 @@ def load_model(
            **model_kwargs,
        )
 
-    if not tokenizer:
-        try:
-            if is_llama_derived_model and "LlamaTokenizer" in globals():
-                tokenizer = LlamaTokenizer.from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-            else:
-                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-        except:
-            tokenizer = AutoTokenizer.from_pretrained(
-                base_model_config,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
-            )
-
-    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
-    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
-    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
-    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
-
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
-        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
-
-    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-    if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            tokenizer.add_special_tokens({k: v})
-    if cfg.tokens:
-        tokenizer.add_tokens(list(cfg.tokens))
-
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
@@ -291,7 +290,7 @@ def load_model(
     model.config.use_cache = False
 
     # TODO resume_from_checkpoint handling
-    return model, tokenizer, lora_config
+    return model, lora_config
 
 
 def load_adapter(model, cfg, adapter):
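
For reference, a hedged usage sketch of the new load_tokenizer helper on its own. The config fields it reads (trust_remote_code, special_tokens, tokens) are taken from the function body above; the SimpleNamespace stand-in, the repo id, and the token values are illustrative assumptions, not part of this commit:

# Illustrative only: a minimal attribute-style config carrying the fields load_tokenizer reads.
from types import SimpleNamespace

from axolotl.utils.models import load_tokenizer

cfg = SimpleNamespace(
    trust_remote_code=False,
    special_tokens={"pad_token": "[PAD]"},  # merged via tokenizer.add_special_tokens
    tokens=None,                            # extra tokens would be added via tokenizer.add_tokens
)

# Passing a tokenizer_type resolves the class from transformers via getattr;
# passing None falls back to AutoTokenizer.from_pretrained.
tokenizer = load_tokenizer("huggyllama/llama-7b", "LlamaTokenizer", cfg)
print(tokenizer.pad_token, tokenizer.pad_token_id)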