winglian committed
Commit 2642cae (1 parent: f34648c)

refactor to set eval_batch_size earlier if unset, so we can warn if mismatched (#662)

README.md CHANGED
@@ -571,7 +571,7 @@ torch_compile_backend: # Optional[str]
 # training hyperparameters
 gradient_accumulation_steps: 1
 micro_batch_size: 2
-eval_batch_size: 2
+eval_batch_size:
 num_epochs: 3
 warmup_steps: 100
 learning_rate: 0.00003
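The README example now leaves eval_batch_size blank rather than pinning it to 2. A blank YAML scalar loads as None, which is the signal normalize_config (see the next file) uses to fall back to micro_batch_size. A minimal sketch of the parse behavior, assuming the config is read with PyYAML's safe_load (the exact loader axolotl uses is not shown in this diff):

import yaml

# Blank value, as in the README example above; safe_load maps it to None.
cfg = yaml.safe_load("micro_batch_size: 2\neval_batch_size:\n")
assert cfg["eval_batch_size"] is None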
src/axolotl/utils/config.py CHANGED
@@ -49,6 +49,8 @@ def normalize_config(cfg):
     cfg.batch_size = (
         cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
     )
+    if cfg.eval_batch_size is None:
+        cfg.eval_batch_size = cfg.micro_batch_size
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     cfg.eval_table_size = cfg.eval_table_size or 0
@@ -157,6 +159,11 @@ def validate_config(cfg):
         "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
         "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
     )
+    if cfg.eval_batch_size != cfg.micro_batch_size:
+        LOG.warning(
+            "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
+        )
+
     if cfg.load_4bit:
         raise ValueError("cfg.load_4bit parameter has been deprecated")
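Taken together, the two hunks above move the fallback up into normalize_config so that by the time validate_config compares values there is something concrete to compare, and the new warning only fires when a user explicitly set a mismatched eval_batch_size. Below is a standalone sketch of that interplay; it uses a plain dict in place of axolotl's attribute-style cfg object, and it assumes normalize_config runs before validate_config, which is what setting the default "earlier" implies:

import logging

logging.basicConfig()
LOG = logging.getLogger("axolotl")

def normalize_config(cfg):
    # Mirrors the first hunk: an unset eval_batch_size defaults to
    # micro_batch_size before anything downstream reads it.
    if cfg.get("eval_batch_size") is None:
        cfg["eval_batch_size"] = cfg["micro_batch_size"]

def validate_config(cfg):
    # Mirrors the second hunk: by now a None has been replaced, so the
    # warning only reflects a deliberate user choice.
    if cfg["eval_batch_size"] != cfg["micro_batch_size"]:
        LOG.warning(
            "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
        )

cfg = {"micro_batch_size": 2, "eval_batch_size": None}  # unset in YAML
normalize_config(cfg)
validate_config(cfg)  # silent: defaulted to 2

cfg = {"micro_batch_size": 2, "eval_batch_size": 4}  # explicit mismatch
normalize_config(cfg)
validate_config(cfg)  # warns about possible VRAM instability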
src/axolotl/utils/trainer.py CHANGED
@@ -668,9 +668,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
-        per_device_eval_batch_size=cfg.eval_batch_size
-        if cfg.eval_batch_size is not None
-        else cfg.micro_batch_size,
+        per_device_eval_batch_size=cfg.eval_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
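With normalize_config guaranteeing a value, the inline None-fallback that previously lived here is dead code, and the argument collapses to a direct pass-through. A hedged sketch of the resulting shape, with a hypothetical SimpleNamespace standing in for the real normalized cfg and only the kwargs visible in this hunk:

from types import SimpleNamespace

# Hypothetical stand-in for the normalized axolotl config.
cfg = SimpleNamespace(
    micro_batch_size=2,
    eval_batch_size=2,  # already filled in by normalize_config when unset
    gradient_accumulation_steps=1,
)

# eval_batch_size passes straight through; the removed
# "if cfg.eval_batch_size is not None else cfg.micro_batch_size"
# branch can never produce a different result once the default
# is applied upstream.
training_args_kwargs = dict(
    per_device_train_batch_size=cfg.micro_batch_size,
    per_device_eval_batch_size=cfg.eval_batch_size,
    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
    eval_accumulation_steps=cfg.gradient_accumulation_steps,
)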