winglian commited on
Commit
13ac4d8
2 Parent(s): f74edd5 19cf0bd

Merge pull request #268 from OpenAccess-AI-Collective/fix-adam-args

Browse files
src/axolotl/utils/validation.py CHANGED
@@ -87,7 +87,7 @@ def validate_config(cfg):
87
  "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
88
  )
89
 
90
- if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
91
  not cfg.optimizer or "adamw" not in cfg.optimizer
92
  ):
93
  logging.warning("adamw hyperparameters found, but no adamw optimizer set")
 
87
  "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
88
  )
89
 
90
+ if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
91
  not cfg.optimizer or "adamw" not in cfg.optimizer
92
  ):
93
  logging.warning("adamw hyperparameters found, but no adamw optimizer set")
tests/test_validation.py CHANGED
@@ -268,7 +268,7 @@ class ValidationTest(unittest.TestCase):
268
  cfg = DictDefault(
269
  {
270
  "optimizer": None,
271
- "adamw_epsilon": 0.0001,
272
  }
273
  )
274
 
@@ -283,7 +283,7 @@ class ValidationTest(unittest.TestCase):
283
  cfg = DictDefault(
284
  {
285
  "optimizer": "adafactor",
286
- "adamw_beta1": 0.0001,
287
  }
288
  )
289
 
@@ -298,9 +298,9 @@ class ValidationTest(unittest.TestCase):
298
  cfg = DictDefault(
299
  {
300
  "optimizer": "adamw_bnb_8bit",
301
- "adamw_beta1": 0.0001,
302
- "adamw_beta2": 0.0001,
303
- "adamw_epsilon": 0.0001,
304
  }
305
  )
306
 
 
268
  cfg = DictDefault(
269
  {
270
  "optimizer": None,
271
+ "adam_epsilon": 0.0001,
272
  }
273
  )
274
 
 
283
  cfg = DictDefault(
284
  {
285
  "optimizer": "adafactor",
286
+ "adam_beta1": 0.0001,
287
  }
288
  )
289
 
 
298
  cfg = DictDefault(
299
  {
300
  "optimizer": "adamw_bnb_8bit",
301
+ "adam_beta1": 0.9,
302
+ "adam_beta2": 0.99,
303
+ "adam_epsilon": 0.0001,
304
  }
305
  )
306