winglian commited on
Commit
e5bb22a
1 Parent(s): fdb777b

add optimization for group-by-len (#563)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +10 -0
src/axolotl/utils/trainer.py CHANGED
@@ -358,7 +358,14 @@ class ReLoRATrainer(AxolotlTrainer):
358
 
359
 
360
  def add_position_ids(sample):
 
361
  sample["position_ids"] = torch.arange(len(sample["input_ids"]))
 
 
 
 
 
 
362
  return sample
363
 
364
 
@@ -382,6 +389,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
382
  if eval_dataset:
383
  eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
384
 
 
 
 
385
  if cfg.sample_packing:
386
  train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
387
  if eval_dataset:
 
358
 
359
 
360
  def add_position_ids(sample):
361
+ sample_len = len(sample["input_ids"])
362
  sample["position_ids"] = torch.arange(len(sample["input_ids"]))
363
+ sample["length"] = sample_len
364
+ return sample
365
+
366
+
367
+ def add_length(sample):
368
+ sample["length"] = len(sample["input_ids"])
369
  return sample
370
 
371
 
 
389
  if eval_dataset:
390
  eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
391
 
392
+ if cfg.group_by_length:
393
+ train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
394
+
395
  if cfg.sample_packing:
396
  train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
397
  if eval_dataset: