winglian committed on
Commit
7710e81
•
1 Parent(s): 48434be

log supervised token count (#448)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +10 -0
src/axolotl/utils/trainer.py CHANGED
@@ -401,6 +401,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
401
  LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
402
  cfg.total_num_tokens = total_num_tokens
403
 
 
 
 
 
 
 
 
 
 
 
404
  if cfg.sample_packing_eff_est:
405
  total_num_steps = (
406
  # match count to len est in dataloader
 
401
  LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
402
  cfg.total_num_tokens = total_num_tokens
403
 
404
+ if not cfg.total_supervised_tokens:
405
+ total_supervised_tokens = (
406
+ train_dataset.data.column("labels")
407
+ .to_pandas()
408
+ .apply(lambda x: np.sum(np.array(x) != -100))
409
+ .sum()
410
+ )
411
+ LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
412
+ cfg.total_supervised_tokens = total_supervised_tokens
413
+
414
  if cfg.sample_packing_eff_est:
415
  total_num_steps = (
416
  # match count to len est in dataloader