winglian committed
Commit 488a67d
1 Parent(s): 71a43f8

experimental expansion of ctx len

Files changed (2)
  1. scripts/finetune.py +26 -18
  2. src/axolotl/utils/data.py +31 -1
scripts/finetune.py CHANGED
@@ -6,22 +6,20 @@ import os
 import random
 import signal
 import sys
-from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig, TextStreamer
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 
 # add src to the pythonpath so we don't need to pip install this
 from optimum.bettertransformer import BetterTransformer
+from transformers import GenerationConfig, TextStreamer
 
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -204,9 +202,19 @@ def train(
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
     ): # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            if cfg.pretraining_dataset is True:
+                pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
+            else:
+                pretraining_dataset = cfg.pretraining_dataset
+            train_dataset = load_pretraining_dataset(
+                pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
+            )
+            eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -256,7 +264,7 @@ def train(
         logging.info("check_dataset_labels...")
         check_dataset_labels(
             train_dataset.select(
-                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
+                [random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
             ),
             tokenizer,
         )
@@ -265,10 +273,7 @@ def train(
         logging.info("Finished preparing dataset. Exiting...")
        return
 
-    try:
-        model.train()
-    except:
-        pass
+    model.train()
 
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
@@ -285,14 +290,15 @@ def train(
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
-        def terminate_handler(signum, frame, model):
+
+        def terminate_handler(_, __, model):
             if cfg.flash_optimum:
                 model = BetterTransformer.reverse(model)
             model.save_pretrained(cfg.output_dir)
             sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signum, frame: terminate_handler(signum, frame, model)
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )
 
     logging.info("Starting trainer...")
@@ -316,7 +322,9 @@ def train(
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
     if cfg.flash_optimum:
-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
             trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
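The main behavioural change in train() is the branch on cfg.pretraining_dataset: when it is unset the existing load_prepare_datasets path runs; when it is set to true the dataset defaults to togethercomputer/RedPajama-Data-1T; any other value is used as the dataset path, with cfg.sequence_len sizing the streamed chunks and eval_dataset left as None. The commit also drops the try/except around model.train() and wraps trainer.train() in torch.backends.cuda.sdp_kernel when cfg.flash_optimum is set. A minimal sketch of the dataset-name resolution only, using SimpleNamespace as a stand-in for the real config object (axolotl uses DictDefault; the values shown are illustrative, not from the commit):

from types import SimpleNamespace  # stand-in for the cfg object, for illustration only

# Hypothetical config values; in axolotl these come from the YAML config file.
cfg = SimpleNamespace(pretraining_dataset=True, sequence_len=2048)

if not cfg.pretraining_dataset:
    dataset_name = None  # falls through to the existing load_prepare_datasets() path
elif cfg.pretraining_dataset is True:
    dataset_name = "togethercomputer/RedPajama-Data-1T"  # default chosen by the commit
else:
    dataset_name = cfg.pretraining_dataset  # any other value is treated as a dataset path

print(dataset_name, cfg.sequence_len)  # togethercomputer/RedPajama-Data-1T 2048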
src/axolotl/utils/data.py CHANGED
@@ -5,7 +5,8 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
-from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
+import torch
+from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
@@ -392,3 +393,32 @@ def load_prepare_datasets(
     eval_dataset = dataset["test"]
 
     return train_dataset, eval_dataset
+
+
+class PretrainingDatasetWrapper(IterableDataset):
+    """
+    Wrapper for pretraining dataset that avoids loading the dataset into memory
+    """
+
+    def __init__(self, tokenizer, dataset_path, max_tokens=2048):
+        self.tokenizer = tokenizer
+        self.dataset_path = dataset_path
+        self.max_tokens = max_tokens
+
+    def __iter__(self):
+        buffer = []
+        for sample in load_dataset(
+            self.dataset_path,
+            name="all",
+            split="train",
+            streaming=True,
+        ).shuffle(buffer_size=10000):
+            buffer += self.tokenizer(sample["text"])["input_ids"]
+            buffer += [self.tokenizer.eos_token_id]
+            while len(buffer) > self.max_tokens:
+                yield torch.tensor(buffer[: self.max_tokens])
+                buffer = buffer[self.max_tokens :]
+
+
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+    return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)
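PretrainingDatasetWrapper streams the dataset, tokenizes each sample into a running token buffer with an EOS token appended after every document, and yields fixed-length tensors of max_tokens ids, so the configured sequence_len can be used without materializing the corpus in memory. A minimal sketch of the same packing loop on a toy in-memory corpus (the gpt2 tokenizer and sample texts are assumptions for illustration, not part of the commit):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer, illustration only
samples = ["first toy document", "second toy document", "third toy document"]
max_tokens = 8  # stands in for cfg.sequence_len


def packed_chunks(texts, tokenizer, max_tokens):
    buffer = []
    for text in texts:
        buffer += tokenizer(text)["input_ids"]
        buffer += [tokenizer.eos_token_id]  # document boundary, as in the wrapper
        while len(buffer) > max_tokens:
            yield torch.tensor(buffer[:max_tokens])
            buffer = buffer[max_tokens:]


for chunk in packed_chunks(samples, tokenizer, max_tokens):
    print(chunk.shape)  # every yielded chunk has exactly max_tokens ids

As in the wrapper itself, any tail shorter than max_tokens stays in the buffer and is only emitted once later samples push it past the threshold.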