winglian committed on
Commit 0f985e1
1 Parent(s): c1a7b3d

more fixes 20240228 (#1342) [skip ci]


* add missing evals_per_epoch setting

* more pydantic fixes

* more fixes

* move test from normalization to validation

* increase eval size for sample packing tests
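
For context, the settings named above can be sketched as the DictDefault config dict used throughout the tests in this change; the import path and values below are illustrative, not defaults:

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "learning_rate": "5e-5",          # string form, coerced to float by the new pydantic validator
        "evals_per_epoch": 4,             # newly declared input-config field
        "load_best_model_at_end": False,  # newly declared input-config field, defaults to False
    }
)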

src/axolotl/cli/__init__.py CHANGED
@@ -13,7 +13,6 @@ from threading import Thread
 from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlparse
 
-import gradio as gr
 import requests
 import torch
 import yaml
@@ -215,6 +214,8 @@ def do_inference_gradio(
     cfg: DictDefault,
     cli_args: TrainerCliArgs,
 ):
+    import gradio as gr
+
     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
     prompter = cli_args.prompter
     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
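
The change above defers the gradio import into do_inference_gradio, so gradio is only needed when that code path actually runs. A standalone sketch of the same deferred-import pattern, assuming gradio is installed; the demo function is illustrative and not axolotl code:

def run_gradio_demo():
    # importing inside the function keeps gradio an optional dependency
    # for every other entry point that imports this module
    import gradio as gr

    # tiny interface that just reverses the input text
    demo = gr.Interface(fn=lambda text: text[::-1], inputs="text", outputs="text")
    demo.launch()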
src/axolotl/utils/config/__init__.py CHANGED
@@ -164,9 +164,6 @@ def normalize_config(cfg):
         ]
     ) or cfg.is_qwen_derived_model
 
-    if isinstance(cfg.learning_rate, str):
-        cfg.learning_rate = float(cfg.learning_rate)
-
     if isinstance(cfg.pretraining_dataset, dict):
         cfg.pretraining_dataset = [cfg.pretraining_dataset]
 
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -302,6 +302,13 @@ class HyperparametersConfig(BaseModel):
         )
         return batch_size
 
+    @field_validator("learning_rate")
+    @classmethod
+    def convert_learning_rate(cls, learning_rate):
+        if learning_rate and isinstance(learning_rate, str):
+            learning_rate = float(learning_rate)
+        return learning_rate
+
 
 class ModelOutputConfig(BaseModel):
     """model save configuration subset"""
@@ -368,6 +375,7 @@ class AxolotlInputConfig(
     rl: Optional[RLType] = None
 
     datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
+    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
     dataset_prepared_path: Optional[str] = None
     dataset_shard_num: Optional[int] = None
     dataset_shard_idx: Optional[int] = None
@@ -456,6 +464,7 @@ class AxolotlInputConfig(
     warmup_steps: Optional[int] = None
     warmup_ratio: Optional[float] = None
     eval_steps: Optional[Union[int, float]] = None
+    evals_per_epoch: Optional[Union[int]] = None
     evaluation_strategy: Optional[str] = None
     save_steps: Optional[Union[int, float]] = None
     saves_per_epoch: Optional[int] = None
@@ -463,6 +472,7 @@ class AxolotlInputConfig(
     save_total_limit: Optional[int] = None
     logging_steps: Optional[int] = None
     early_stopping_patience: Optional[int] = None
+    load_best_model_at_end: Optional[bool] = False
 
     neftune_noise_alpha: Optional[float] = None
 
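
The learning-rate coercion removed from normalize_config now lives on the pydantic model as a field_validator, matching the hunk above. A self-contained sketch of the same pattern under pydantic v2; the model name and field set below are illustrative, not the real HyperparametersConfig:

from typing import Optional, Union

from pydantic import BaseModel, field_validator


class Hyperparameters(BaseModel):
    learning_rate: Optional[Union[str, float]] = None

    @field_validator("learning_rate")
    @classmethod
    def convert_learning_rate(cls, learning_rate):
        # accept a YAML-friendly string like "5e-5" and coerce it to a float
        if learning_rate and isinstance(learning_rate, str):
            learning_rate = float(learning_rate)
        return learning_rate


print(Hyperparameters(learning_rate="5e-5").learning_rate)  # 5e-05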
src/axolotl/utils/trainer.py CHANGED
@@ -255,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             train_dataset.remove_columns(["length"]),
             batch_sampler=sampler,
         )
-        data_loader_len = len(data_loader) // batch_size
+        data_loader_len = len(data_loader) // cfg.batch_size
        actual_eff = sampler.efficiency()
        LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
        # FIXME: is there a bug here somewhere? the total num steps depends
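
The one-line fix above takes the divisor from cfg.batch_size. A toy illustration of the integer-division estimate it produces; all numbers are made up and only cfg.batch_size corresponds to the diff:

num_packed_batches = 1024  # stands in for len(data_loader) over the packed dataset
cfg_batch_size = 4         # stands in for cfg.batch_size
data_loader_len = num_packed_batches // cfg_batch_size
print(data_loader_len)     # 256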
tests/e2e/patched/test_lora_llama_multipack.py CHANGED
@@ -43,7 +43,7 @@ class TestLoraLlama(unittest.TestCase):
                 "lora_alpha": 64,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
-                "val_set_size": 0.1,
+                "val_set_size": 0.2,
                 "special_tokens": {
                     "unk_token": "<unk>",
                     "bos_token": "<s>",
tests/test_normalize_config.py CHANGED
@@ -25,20 +25,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
             }
         )
 
-    def test_lr_as_float(self):
-        cfg = (
-            self._get_base_cfg()
-            | DictDefault(  # pylint: disable=unsupported-binary-operation
-                {
-                    "learning_rate": "5e-5",
-                }
-            )
-        )
-
-        normalize_config(cfg)
-
-        assert cfg.learning_rate == 0.00005
-
     def test_base_model_config_set_when_empty(self):
         cfg = self._get_base_cfg()
         del cfg.base_model_config
tests/test_validation.py CHANGED
@@ -176,6 +176,20 @@ class TestValidation(BaseValidation):
         with pytest.raises(ValueError, match=r".*At least two of*"):
             validate_config(cfg)
 
+    def test_lr_as_float(self, minimal_cfg):
+        cfg = (
+            DictDefault(  # pylint: disable=unsupported-binary-operation
+                {
+                    "learning_rate": "5e-5",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+
+        assert new_cfg.learning_rate == 0.00005
+
     def test_qlora(self, minimal_cfg):
         base_cfg = (
             DictDefault(