winglian committed on
Commit
1c412c7
1 Parent(s): 490923f

improve handling of the prepared ds path and other cfg defaults (#701)

Browse files
src/axolotl/cli/inference.py CHANGED
@@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
14
  # pylint: disable=duplicate-code
15
  print_axolotl_text_art()
16
  parsed_cfg = load_cfg(config, **kwargs)
 
17
  parser = transformers.HfArgumentParser((TrainerCliArgs))
18
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
19
  return_remaining_strings=True
 
14
  # pylint: disable=duplicate-code
15
  print_axolotl_text_art()
16
  parsed_cfg = load_cfg(config, **kwargs)
17
+ parsed_cfg.sample_packing = False
18
  parser = transformers.HfArgumentParser((TrainerCliArgs))
19
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
20
  return_remaining_strings=True
src/axolotl/cli/train.py CHANGED
@@ -1,10 +1,12 @@
1
  """
2
  CLI to run training on a model
3
  """
 
4
  from pathlib import Path
5
 
6
  import fire
7
  import transformers
 
8
 
9
  from axolotl.cli import (
10
  check_accelerate_default_config,
@@ -14,8 +16,11 @@ from axolotl.cli import (
14
  print_axolotl_text_art,
15
  )
16
  from axolotl.common.cli import TrainerCliArgs
 
17
  from axolotl.train import train
18
 
 
 
19
 
20
  def do_cli(config: Path = Path("examples/"), **kwargs):
21
  # pylint: disable=duplicate-code
@@ -27,6 +32,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
27
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
28
  return_remaining_strings=True
29
  )
 
 
 
 
 
 
 
 
30
 
31
  dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
32
  if parsed_cli_args.prepare_ds_only:
 
1
  """
2
  CLI to run training on a model
3
  """
4
+ import logging
5
  from pathlib import Path
6
 
7
  import fire
8
  import transformers
9
+ from colorama import Fore
10
 
11
  from axolotl.cli import (
12
  check_accelerate_default_config,
 
16
  print_axolotl_text_art,
17
  )
18
  from axolotl.common.cli import TrainerCliArgs
19
+ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
20
  from axolotl.train import train
21
 
22
+ LOG = logging.getLogger("axolotl.cli.train")
23
+
24
 
25
  def do_cli(config: Path = Path("examples/"), **kwargs):
26
  # pylint: disable=duplicate-code
 
32
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
33
  return_remaining_strings=True
34
  )
35
+ if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
36
+ msg = (
37
+ Fore.RED
38
+ + "--prepare_ds_only called without dataset_prepared_path set."
39
+ + Fore.RESET
40
+ )
41
+ LOG.warning(msg)
42
+ parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
43
 
44
  dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
45
  if parsed_cli_args.prepare_ds_only:
src/axolotl/common/const.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ Various shared constants
3
+ """
4
+
5
+ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
src/axolotl/utils/data.py CHANGED
@@ -16,6 +16,7 @@ from datasets import (
16
  from huggingface_hub import hf_hub_download
17
  from transformers import PreTrainedTokenizerBase
18
 
 
19
  from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
20
  from axolotl.prompt_strategies import load
21
  from axolotl.prompt_tokenizers import (
@@ -44,7 +45,6 @@ from axolotl.utils.trainer import (
44
  )
45
 
46
  LOG = logging.getLogger("axolotl")
47
- DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
48
 
49
 
50
  def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -357,7 +357,7 @@ def load_tokenized_prepared_datasets(
357
  if len(datasets) > 1:
358
  LOG.info("shuffle merged datasets")
359
  dataset = dataset.shuffle(seed=seed)
360
- if cfg.local_rank == 0 and cfg.dataset_prepared_path:
361
  LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
362
  dataset.save_to_disk(prepared_ds_path)
363
  if cfg.push_dataset_to_hub:
 
16
  from huggingface_hub import hf_hub_download
17
  from transformers import PreTrainedTokenizerBase
18
 
19
+ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
20
  from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
21
  from axolotl.prompt_strategies import load
22
  from axolotl.prompt_tokenizers import (
 
45
  )
46
 
47
  LOG = logging.getLogger("axolotl")
 
48
 
49
 
50
  def md5(to_hash: str, encoding: str = "utf-8") -> str:
 
357
  if len(datasets) > 1:
358
  LOG.info("shuffle merged datasets")
359
  dataset = dataset.shuffle(seed=seed)
360
+ if cfg.local_rank == 0:
361
  LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
362
  dataset.save_to_disk(prepared_ds_path)
363
  if cfg.push_dataset_to_hub: