"""
Test classes for checking functionality of the cfg normalization
"""
import unittest

from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault


class NormalizeConfigTestCase(unittest.TestCase):
    """
    Test class for normalize_config checks.

    Each test builds a small config, runs ``normalize_config`` on it (which
    mutates the config in place), and checks the normalized field.
    """

    def _get_base_cfg(self):
        """Return a fresh minimal DictDefault config for a test to extend or mutate."""
        return DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "base_model_config": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )

    def test_lr_as_float(self):
        """A learning rate supplied as the string "5e-5" should equal the float 5e-5."""
        cfg = (
            self._get_base_cfg()
            | DictDefault(  # pylint: disable=unsupported-binary-operation
                {
                    "learning_rate": "5e-5",
                }
            )
        )

        normalize_config(cfg)

        # Use the TestCase assertion instead of a bare `assert`, which is
        # stripped under `python -O` and gives no diff on failure.
        self.assertEqual(cfg.learning_rate, 0.00005)

    def test_base_model_config_set_when_empty(self):
        """With base_model_config removed, it should be filled in from base_model."""
        cfg = self._get_base_cfg()
        del cfg.base_model_config

        normalize_config(cfg)

        self.assertEqual(cfg.base_model_config, cfg.base_model)