winglian committed
Commit 21ec195
1 Parent(s): 62eaee7

optionally configure sample packing for evals (#589)

Files changed (1)
  1. src/axolotl/utils/trainer.py +11 -2
src/axolotl/utils/trainer.py CHANGED
 
@@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
     )
+    eval_sample_packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Use sample packing for efficient evals."},
+    )
     sample_packing_efficiency: float = field(
         default=1.0,
         metadata={"help": "Sample packing efficiency for calculating batch length."},
 
@@ -212,7 +216,11 @@ class AxolotlTrainer(Trainer):
     def _get_eval_sampler(
         self, eval_dataset: Dataset
     ) -> Optional[torch.utils.data.Sampler]:
-        if self.args.world_size > 1 and self.args.sample_packing:
+        if (
+            self.args.world_size > 1
+            and self.args.sample_packing
+            and self.args.eval_sample_packing is not False
+        ):
             return SequentialDistributedSampler(
                 eval_dataset,
                 num_replicas=self.args.world_size,
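The guard uses `is not False` rather than a plain truthiness test because the default None has to preserve the pre-existing behavior (pack evals whenever training packs). A quick illustration of the difference:

# Truthiness would read the default None as "disabled" and silently change
# behavior for existing configs; `is not False` only honors an explicit opt-out.
for value in (None, True, False):
    print(f"{value!r}: truthy={bool(value)}, is-not-False={value is not False}")
# None: truthy=False, is-not-False=True   <- default stays packed
# True: truthy=True, is-not-False=True
# False: truthy=False, is-not-False=False <- explicit opt-out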
 
@@ -241,7 +249,7 @@ class AxolotlTrainer(Trainer):
     def get_eval_dataloader(
         self, eval_dataset: Optional[Dataset] = None
     ) -> Union[DataLoader, MultipackDistributedDataloader]:
-        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
            )
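When the condition fails, the method presumably falls through to the stock Hugging Face Trainer path, which would explain the Union[DataLoader, MultipackDistributedDataloader] return annotation. A hypothetical dispatch helper (illustrative only, not the real method) showing the two outcomes:

def eval_loader_kind(sample_packing: bool, eval_sample_packing: object) -> str:
    # Mirrors the branch above; the fall-through path stands in for the
    # default Trainer dataloader construction.
    if sample_packing and eval_sample_packing is not False:
        return "MultipackDistributedDataloader"
    return "DataLoader"

assert eval_loader_kind(True, None) == "MultipackDistributedDataloader"
assert eval_loader_kind(True, False) == "DataLoader"
assert eval_loader_kind(False, True) == "DataLoader"  # no packing at all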
 
@@ -659,6 +667,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
+        eval_sample_packing=cfg.eval_sample_packing,
         sample_packing_seq_len_multiplier=cfg.micro_batch_size,
         relora_steps=cfg.relora_steps,
         relora_warmup_steps=cfg.relora_warmup_steps,
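Note that setup_trainer forwards cfg.eval_sample_packing as-is, with no default applied, so the three-valued setting survives from the YAML config into the training arguments. A round-trip sketch (assumes PyYAML is installed; a plain dict stands in for axolotl's cfg object, which also supports attribute access):

import yaml  # assumption: PyYAML

cfg = yaml.safe_load("""
sample_packing: true
eval_sample_packing: false
""")
assert cfg["sample_packing"] is True
assert cfg["eval_sample_packing"] is False  # packed training, unpacked evals
# Leaving eval_sample_packing out of the config leaves it None, which the
# `is not False` checks treat the same as True.
assert cfg.get("not_in_config") is None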