winglian committed on
Commit
50682a3
1 Parent(s): 5a1985b

always drop samples that are too long (#452)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +7 -7
src/axolotl/utils/trainer.py CHANGED
@@ -284,15 +284,15 @@ def disable_datasets_caching():
284
 
285
 
286
  def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
 
 
 
 
 
287
  if cfg.sample_packing:
288
- drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
289
- train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
290
- add_position_ids, num_proc=os.cpu_count()
291
- )
292
  if eval_dataset:
293
- eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
294
- add_position_ids, num_proc=os.cpu_count()
295
- )
296
  return train_dataset, eval_dataset
297
 
298
 
 
284
 
285
 
286
  def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
287
+ drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
288
+ train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
289
+ if eval_dataset:
290
+ eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
291
+
292
  if cfg.sample_packing:
293
+ train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
 
 
 
294
  if eval_dataset:
295
+ eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
 
 
296
  return train_dataset, eval_dataset
297
 
298