winglian committed on
Commit
a546ca2
1 Parent(s): 3355706

misc fixes/improvements (#513)

Browse files
src/axolotl/train.py CHANGED
@@ -88,6 +88,11 @@ def train(
88
  if peft_config:
89
  LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
90
  peft_config.save_pretrained(cfg.output_dir)
 
 
 
 
 
91
 
92
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
93
  if cfg.local_rank == 0:
@@ -106,9 +111,6 @@ def train(
106
  if cfg.group_by_length:
107
  LOG.info("hang tight... sorting dataset for group_by_length")
108
 
109
- if not Path(cfg.output_dir).is_dir():
110
- os.makedirs(cfg.output_dir, exist_ok=True)
111
- tokenizer.save_pretrained(cfg.output_dir)
112
  if cfg.flash_optimum:
113
  with torch.backends.cuda.sdp_kernel(
114
  enable_flash=True, enable_math=True, enable_mem_efficient=True
 
88
  if peft_config:
89
  LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
90
  peft_config.save_pretrained(cfg.output_dir)
91
+ # additionally presave the tokenizer and model configs
92
+ if not Path(cfg.output_dir).is_dir():
93
+ os.makedirs(cfg.output_dir, exist_ok=True)
94
+ tokenizer.save_pretrained(str(Path(cfg.output_dir)))
95
+ model.config.save_pretrained(str(Path(cfg.output_dir)))
96
 
97
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
98
  if cfg.local_rank == 0:
 
111
  if cfg.group_by_length:
112
  LOG.info("hang tight... sorting dataset for group_by_length")
113
 
 
 
 
114
  if cfg.flash_optimum:
115
  with torch.backends.cuda.sdp_kernel(
116
  enable_flash=True, enable_math=True, enable_mem_efficient=True
src/axolotl/utils/trainer.py CHANGED
@@ -33,6 +33,7 @@ from axolotl.utils.callbacks import (
33
  )
34
  from axolotl.utils.collators import DataCollatorForSeq2Seq
35
  from axolotl.utils.dataloader import MultipackDistributedDataloader
 
36
  from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
37
 
38
  LOG = logging.getLogger("axolotl")
@@ -375,14 +376,17 @@ def disable_datasets_caching():
375
 
376
  def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
377
  drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
378
- train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
379
- if eval_dataset:
380
- eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
381
-
382
- if cfg.sample_packing:
383
- train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
384
  if eval_dataset:
385
- eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
 
 
 
 
 
 
 
386
  return train_dataset, eval_dataset
387
 
388
 
 
33
  )
34
  from axolotl.utils.collators import DataCollatorForSeq2Seq
35
  from axolotl.utils.dataloader import MultipackDistributedDataloader
36
+ from axolotl.utils.distributed import is_main_process, zero_first
37
  from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
38
 
39
  LOG = logging.getLogger("axolotl")
 
376
 
377
  def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
378
  drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
379
+ with zero_first(is_main_process()):
380
+ train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
 
 
 
 
381
  if eval_dataset:
382
+ eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
383
+
384
+ if cfg.sample_packing:
385
+ train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
386
+ if eval_dataset:
387
+ eval_dataset = eval_dataset.map(
388
+ add_position_ids, num_proc=os.cpu_count()
389
+ )
390
  return train_dataset, eval_dataset
391
 
392