ittailup committed on
Commit
3c2ad00
1 Parent(s): 5d48a10

Feat(config): add max steps (#387)

Browse files
scripts/finetune.py CHANGED
@@ -209,7 +209,13 @@ def train(
209
  cfg, train_dataset, eval_dataset
210
  )
211
  barrier()
212
- total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
 
 
 
 
 
 
213
 
214
  if cfg.debug or "debug" in kwargs:
215
  LOG.info("check_dataset_labels...")
 
209
  cfg, train_dataset, eval_dataset
210
  )
211
  barrier()
212
+ if cfg.max_steps:
213
+ total_num_steps = min(
214
+ calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
215
+ )
216
+ LOG.info(f"Maximum number of steps set at {total_num_steps}")
217
+ else:
218
+ total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
219
 
220
  if cfg.debug or "debug" in kwargs:
221
  LOG.info("check_dataset_labels...")
src/axolotl/utils/trainer.py CHANGED
@@ -461,7 +461,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
461
  evaluation_strategy = "steps"
462
 
463
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
464
- # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
465
  max_seq_length=cfg.sequence_len,
466
  per_device_train_batch_size=cfg.micro_batch_size,
467
  per_device_eval_batch_size=cfg.eval_batch_size
 
461
  evaluation_strategy = "steps"
462
 
463
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
464
+ max_steps=total_num_steps if cfg.max_steps else -1,
465
  max_seq_length=cfg.sequence_len,
466
  per_device_train_batch_size=cfg.micro_batch_size,
467
  per_device_eval_batch_size=cfg.eval_batch_size