casperhansen committed
Commit: e50ab07
Parent: 05bd6f1

Create preprocess CLI (#785)


* Create preprocess CLI

* Print prompt template if debugging

* Add print for unsupported prompters

* Formatting

* Formatting

* Refactor variables

* Formatting

* Formatting

* Formatting

* Formatting
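
Taken together, these commits replace the train CLI's `--prepare_ds_only` flag with a dedicated `axolotl.cli.preprocess` entry point. A minimal sketch of the equivalent programmatic flow, mirroring the `do_cli` added in `src/axolotl/cli/preprocess.py` below (the config path is illustrative):

```python
# Sketch only; mirrors src/axolotl/cli/preprocess.py from this commit.
from axolotl.cli import load_cfg, load_datasets
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH

cfg = load_cfg("your_config.yml")  # illustrative path
if not cfg.dataset_prepared_path:
    # the CLI warns and falls back to the default path in this case
    cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
load_datasets(cfg=cfg, cli_args=PreprocessCliArgs(debug=True))
```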

README.md CHANGED
@@ -32,7 +32,6 @@ Features:
   - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
   - [Config](#config)
   - [Train](#train)
-  - [Training w/ Deepspeed](#training-with-deepspeed)
   - [Inference](#inference)
   - [Merge LORA to Base](#merge-lora-to-base)
   - [Common Errors](#common-errors-)
@@ -824,14 +823,41 @@ Run
 accelerate launch -m axolotl.cli.train your_config.yml
 ```
 
-#### Multi-GPU
+#### Preprocess dataset
 
-You can optionally pre-tokenize dataset with the following before finetuning:
+You can optionally pre-tokenize dataset with the following before finetuning.
+This is recommended for large datasets.
+
+- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
+- Use `--debug` to see preprocessed examples.
+
 ```bash
-CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
+python -m axolotl.cli.preprocess your_config.yml
 ```
 
-##### Config
+#### Multi-GPU
+
+Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
+is the recommended multi-GPU option currently because FSDP may experience
+[loss instability](https://github.com/huggingface/transformers/issues/26498).
+
+##### DeepSpeed
+
+Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
+might typically be able to fit into your GPU's VRAM. More information about the various optimization types
+for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
+
+We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
+
+```yaml
+deepspeed: deepspeed/zero1.json
+```
+
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+```
+
+##### FSDP
 
 - llama FSDP
 ```yaml
@@ -856,24 +882,6 @@ wandb_run_id:
 wandb_log_model:
 ```
 
-### Training with Deepspeed
-
-Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
-might typically be able to fit into your GPU's VRAM. More information about the various optimization types
-for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
-
-We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
-
-```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
-```
-
-or
-
-```yaml
-deepspeed: deepspeed/zero1.json
-```
-
 ### Inference
 
 Pass the appropriate flag to the train command:
scripts/finetune.py CHANGED
@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-        if parsed_cli_args.prepare_ds_only:
-            return
         train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
 
 
src/axolotl/cli/__init__.py CHANGED
@@ -222,7 +222,9 @@ def load_datasets(
 ) -> TrainDatasetMeta:
     tokenizer = load_tokenizer(cfg)
 
-    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
+    train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
+        cfg, tokenizer
+    )
 
     if cli_args.debug or cfg.debug:
         LOG.info("check_dataset_labels...")
@@ -238,6 +240,10 @@ def load_datasets(
             text_only=cli_args.debug_text_only,
         )
 
+        LOG.info("printing prompters...")
+        for prompter in prompters:
+            LOG.info(prompter)
+
     return TrainDatasetMeta(
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
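
For illustration, the new debug branch means a `--debug` run now logs each prompter after the tokenized examples are checked; since logging stringifies the object, this prints the prompt template via the `__repr__` methods added in `src/axolotl/prompters.py` below. A small sketch (the prompter list is a placeholder, assuming `AlpacaPrompter`'s default prompt style):

```python
import logging

from axolotl.prompters import AlpacaPrompter, UnsupportedPrompter

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("axolotl.cli")

prompters = [AlpacaPrompter(), UnsupportedPrompter()]  # placeholder list
LOG.info("printing prompters...")
for prompter in prompters:
    LOG.info(prompter)  # emits the prompt template, or the unsupported notice
```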
src/axolotl/cli/preprocess.py ADDED
@@ -0,0 +1,53 @@
+"""
+CLI to run preprocessing of a dataset prior to training
+"""
+import logging
+from pathlib import Path
+
+import fire
+import transformers
+from colorama import Fore
+
+from axolotl.cli import (
+    check_accelerate_default_config,
+    check_user_token,
+    load_cfg,
+    load_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import PreprocessCliArgs
+from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+
+LOG = logging.getLogger("axolotl.cli.preprocess")
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    check_accelerate_default_config()
+    check_user_token()
+    parser = transformers.HfArgumentParser((PreprocessCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    if not parsed_cfg.dataset_prepared_path:
+        msg = (
+            Fore.RED
+            + "preprocess CLI called without dataset_prepared_path set, "
+            + f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+            + Fore.RESET
+        )
+        LOG.warning(msg)
+        parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
+
+    _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    LOG.info(
+        Fore.GREEN
+        + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+        + Fore.RESET
+    )
+
+
+if __name__ == "__main__":
+    fire.Fire(do_cli)
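
A note on `return_remaining_strings=True` above: `fire` consumes the config path, but `HfArgumentParser` still sees the full `argv`, so unrecognized entries must be returned rather than raising. A hedged sketch of that behavior (argument values are illustrative):

```python
import transformers

from axolotl.common.cli import PreprocessCliArgs

parser = transformers.HfArgumentParser(PreprocessCliArgs)
cli_args, remaining = parser.parse_args_into_dataclasses(
    args=["your_config.yml", "--debug", "True"],
    return_remaining_strings=True,
)
print(cli_args.debug)  # True
print(remaining)       # ["your_config.yml"]; ignored here, fire reads the path itself
```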
src/axolotl/cli/train.py CHANGED
@@ -6,7 +6,6 @@ from pathlib import Path
 
 import fire
 import transformers
-from colorama import Fore
 
 from axolotl.cli import (
     check_accelerate_default_config,
@@ -16,7 +15,6 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.train import train
 
 LOG = logging.getLogger("axolotl.cli.train")
@@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
-    if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
-        msg = (
-            Fore.RED
-            + "--prepare_ds_only called without dataset_prepared_path set."
-            + Fore.RESET
-        )
-        LOG.warning(msg)
-        parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
-
     dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    if parsed_cli_args.prepare_ds_only:
-        return
     train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
 
 
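
With `--prepare_ds_only` removed, the equivalent workflow is two steps: preprocess, then train. A sketch using the commands documented in the README changes above:

```python
# Sketch: the old `--prepare_ds_only` run becomes an explicit preprocess step.
import subprocess

subprocess.run(
    ["python", "-m", "axolotl.cli.preprocess", "your_config.yml"], check=True
)
subprocess.run(
    ["accelerate", "launch", "-m", "axolotl.cli.train", "your_config.yml"], check=True
)
```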
src/axolotl/common/cli.py CHANGED
@@ -25,11 +25,22 @@ class TrainerCliArgs:
     debug_num_examples: int = field(default=5)
     inference: bool = field(default=False)
     merge_lora: bool = field(default=False)
-    prepare_ds_only: bool = field(default=False)
     prompter: Optional[str] = field(default=None)
     shard: bool = field(default=False)
 
 
+@dataclass
+class PreprocessCliArgs:
+    """
+    dataclass representing arguments for preprocessing only
+    """
+
+    debug: bool = field(default=False)
+    debug_text_only: bool = field(default=False)
+    debug_num_examples: int = field(default=1)
+    prompter: Optional[str] = field(default=None)
+
+
 def load_model_and_tokenizer(
     *,
     cfg: DictDefault,
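
Comparing the two dataclasses, based only on the fields visible in this diff: preprocessing keeps the debug knobs with a smaller default sample and drops the train-only switches:

```python
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs

assert PreprocessCliArgs().debug_num_examples == 1  # preprocess default
assert TrainerCliArgs().debug_num_examples == 5     # trainer default
assert not hasattr(PreprocessCliArgs(), "shard")    # train-only flag omitted
```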
src/axolotl/prompt_tokenizers.py CHANGED
@@ -245,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         raise NotImplementedError
 
     def tokenize_prompt(self, prompt):
+        # pylint: disable=duplicate-code
         (
             instruction,
             input,  # pylint: disable=redefined-builtin
src/axolotl/prompters.py CHANGED
@@ -4,10 +4,12 @@ import logging
 from enum import Enum
 from typing import Generator, Optional, Union
 
+from colorama import Fore
 from fastchat.conversation import Conversation, get_conv_template
 
 LOG = logging.getLogger("axolotl")
 IGNORE_TOKEN_ID = -100
+REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
 
 
 class PromptStyle(Enum):
@@ -55,20 +57,15 @@ class AlpacaPrompter:
         )
         self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
 
-    def build_prompt(
-        self,
-        instruction: str,
-        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
-        output: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
+    def _build_result(self, instruction, input_text, output):
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
-        if input:
+        if input_text:
             res = (
                 self.system_format.format(system=self.system_prompt)
                 if self.system_prompt
                 else ""
-            ) + self.turn_format.format(instruction=instruction, input=input)
+            ) + self.turn_format.format(instruction=instruction, input=input_text)
         else:
             res = (
                 self.system_format.format(system=self.system_no_input_prompt)
@@ -77,7 +74,21 @@ class AlpacaPrompter:
             ) + self.turn_no_input_format.format(instruction=instruction)
         if output:
             res = f"{res}{output}"
-        yield res
+
+        return res
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        yield self._build_result(instruction, input, output)
+
+    def __repr__(self) -> str:
+        return REPR_TEMPLATE.format(
+            full_prompt=self._build_result("{instruction}", "{input}", "{output}")
+        )
 
 
 class UnpromptedPrompter(AlpacaPrompter):
@@ -191,14 +202,14 @@ class ReflectAlpacaPrompter:
         )
         self.response_split = "ASSISTANT:"
 
-    def build_prompt(
+    def _build_result(
         self,
         instruction: str,
         input: Union[None, str] = None,  # pylint: disable=redefined-builtin
         output: Union[None, str] = None,
         reflection: Union[None, str] = None,
         corrected: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
+    ):
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
         if input:
@@ -212,7 +223,30 @@ class ReflectAlpacaPrompter:
                 corrected=corrected,
             )
             res = f"{res}{label}"
-        yield res
+
+        return res
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+        reflection: Union[None, str] = None,
+        corrected: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        # pylint: disable=duplicate-code
+        yield self._build_result(
+            instruction,
+            input,
+            output,
+            reflection,
+            corrected,
+        )
+
+    def __repr__(self) -> str:
+        return REPR_TEMPLATE.format(
+            full_prompt=self._build_result("{instruction}", "{input}", "{output}")
+        )
 
 
 SHAREGPT_ASSERTION_FAILED_ROLE = (
@@ -247,7 +281,7 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
         if role_key_model:
             self.role_key_model = role_key_model
 
-    def build_prompt(self, source) -> Generator[str, None, None]:
+    def _build_result(self, source):
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
@@ -282,11 +316,20 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
             conv.append_message(role, sentence["value"])
 
-        for part in conv.get_turns():
+        return conv.get_turns()
+
+    def build_prompt(self, source) -> Generator[str, None, None]:
+        turns = self._build_result(source)
+
+        for part in turns:
             if part[0] and not part[1]:
                 LOG.warning(f"role with empty message: {part[0]}")
             yield part
 
+    def __repr__(self) -> str:
+        turns = self._build_result([{"from": "{from}", "value": "{value}"}])
+        return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
+
 
 class ShareGPTPrompterV2(ShareGPTPrompter):
     """
@@ -304,3 +347,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
             role_key_human=role_key_human,
             role_key_model=role_key_model,
         )
+
+
+class UnsupportedPrompter:
+    """
+    A dummy class for custom prompters
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __repr__(self):
+        return "Pre-tokenized or custom dataset types are unsupported for logging"
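
The trick behind the new `__repr__` methods: calling `_build_result` with literal `{instruction}`/`{input}`/`{output}` placeholders reproduces the raw template, which `REPR_TEMPLATE` wraps in colored `<start>`/`<end>` markers. A minimal sketch, assuming the default prompt style:

```python
from axolotl.prompters import AlpacaPrompter, UnsupportedPrompter

print(repr(AlpacaPrompter()))       # the alpaca template, cyan, between <start>/<end>
print(repr(UnsupportedPrompter()))  # the fixed "unsupported for logging" notice
```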
src/axolotl/utils/data.py CHANGED
@@ -3,7 +3,7 @@ import functools
 import hashlib
 import logging
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 from datasets import (
@@ -36,6 +36,7 @@ from axolotl.prompters import (
     MultipleChoiceExplainPrompter,
     ReflectAlpacaPrompter,
     SummarizeTLDRPrompter,
+    UnsupportedPrompter,
 )
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
@@ -55,9 +56,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
 
 
 def prepare_dataset(cfg, tokenizer):
+    prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset = load_prepare_datasets(
+            train_dataset, eval_dataset, prompters = load_prepare_datasets(
                 tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
             )
     else:
@@ -70,7 +72,7 @@ def prepare_dataset(cfg, tokenizer):
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
-        return train_dataset, eval_dataset, cfg.max_steps
+        return train_dataset, eval_dataset, cfg.max_steps, prompters
 
     with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
@@ -83,7 +85,7 @@ def prepare_dataset(cfg, tokenizer):
         LOG.info(f"Maximum number of steps set at {total_num_steps}")
     else:
         total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
-    return train_dataset, eval_dataset, total_num_steps
+    return train_dataset, eval_dataset, total_num_steps, prompters
 
 
 def load_tokenized_prepared_datasets(
@@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets(
         else Path(default_dataset_prepared_path) / ds_hash
     )
     dataset = None
+    prompters = []
    use_auth_token = cfg.hf_use_auth_token
    try:
        if cfg.push_dataset_to_hub:
@@ -147,13 +150,13 @@ def load_tokenized_prepared_datasets(
             yield dataset
 
     # pylint: disable=invalid-name
-    for d in for_d_in_datasets(cfg.datasets):
+    for config_dataset in for_d_in_datasets(cfg.datasets):
         ds: Union[Dataset, DatasetDict] = None
         ds_from_hub = False
         try:
             load_dataset(
-                d.path,
-                name=d.name,
+                config_dataset.path,
+                name=config_dataset.name,
                 streaming=True,
                 token=use_auth_token,
             )
@@ -162,33 +165,33 @@ def load_tokenized_prepared_datasets(
             pass
 
         # prefer local dataset, even if hub exists
-        local_path = Path(d.path)
+        local_path = Path(config_dataset.path)
         if local_path.exists():
             if local_path.is_dir():
                 # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                 ds = load_dataset(
-                    d.path,
-                    name=d.name,
-                    data_files=d.data_files,
+                    config_dataset.path,
+                    name=config_dataset.name,
+                    data_files=config_dataset.data_files,
                     streaming=False,
                     split=None,
                 )
             elif local_path.is_file():
                 ds_type = "json"
-                if d.ds_type:
-                    ds_type = d.ds_type
-                elif ".parquet" in d.path:
+                if config_dataset.ds_type:
+                    ds_type = config_dataset.ds_type
+                elif ".parquet" in config_dataset.path:
                     ds_type = "parquet"
-                elif ".arrow" in d.path:
+                elif ".arrow" in config_dataset.path:
                     ds_type = "arrow"
-                elif ".csv" in d.path:
+                elif ".csv" in config_dataset.path:
                     ds_type = "csv"
-                elif ".txt" in d.path:
+                elif ".txt" in config_dataset.path:
                     ds_type = "text"
                 ds = load_dataset(
                     ds_type,
-                    name=d.name,
-                    data_files=d.path,
+                    name=config_dataset.name,
+                    data_files=config_dataset.path,
                     streaming=False,
                     split=None,
                 )
@@ -198,25 +201,25 @@ def load_tokenized_prepared_datasets(
                 )
         elif ds_from_hub:
             ds = load_dataset(
-                d.path,
-                name=d.name,
+                config_dataset.path,
+                name=config_dataset.name,
                 streaming=False,
-                data_files=d.data_files,
+                data_files=config_dataset.data_files,
                 token=use_auth_token,
             )
         else:
-            if isinstance(d.data_files, str):
+            if isinstance(config_dataset.data_files, str):
                 fp = hf_hub_download(
-                    repo_id=d.path,
+                    repo_id=config_dataset.path,
                     repo_type="dataset",
-                    filename=d.data_files,
+                    filename=config_dataset.data_files,
                 )
-            elif isinstance(d.data_files, list):
+            elif isinstance(config_dataset.data_files, list):
                 fp = []
-                for file in d.data_files:
+                for file in config_dataset.data_files:
                     fp.append(
                         hf_hub_download(
-                            repo_id=d.path,
+                            repo_id=config_dataset.path,
                             repo_type="dataset",
                             filename=file,
                         )
@@ -226,21 +229,27 @@ def load_tokenized_prepared_datasets(
                     "data_files must be either a string or list of strings"
                 )
             ds = load_dataset(
-                "json", name=d.name, data_files=fp, streaming=False, split=None
+                "json",
+                name=config_dataset.name,
+                data_files=fp,
+                streaming=False,
+                split=None,
             )
         if not ds:
             raise ValueError("unhandled dataset load")
         # support for using a subset of the data
-        if d.shards:
+        if config_dataset.shards:
             if "train" in ds:
                 ds = ds.shuffle(seed=seed)["train"].shard(
-                    num_shards=d.shards, index=0
+                    num_shards=config_dataset.shards, index=0
                 )
             else:
-                ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
+                ds = ds.shuffle(seed=seed).shard(
+                    num_shards=config_dataset.shards, index=0
+                )
 
         d_base_type = d_prompt_style = None
-        d_type = d.type
+        d_type = config_dataset.type
         if isinstance(d_type, str):
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
@@ -249,108 +258,26 @@ def load_tokenized_prepared_datasets(
             ds = ds["train"]
         elif (
             isinstance(ds, DatasetDict)
-            and d.train_on_split
-            and d.train_on_split in ds
+            and config_dataset.train_on_split
+            and config_dataset.train_on_split in ds
         ):
-            ds = ds[d.train_on_split]
+            ds = ds[config_dataset.train_on_split]
         elif isinstance(ds, DatasetDict):
             raise ValueError(
-                f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
-            )
-        if (
-            "input_ids" in ds.features
-            and "attention_mask" in ds.features
-            and "labels" in ds.features
-        ):
-            # dataset is already tokenized, just drop it straight in
-            datasets.append(ds)
-        elif isinstance(d.type, DictDefault):
-            ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif ds_strategy := load(d.type, tokenizer, cfg, d):
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "alpaca":
-            ds_strategy = AlpacaPromptTokenizingStrategy(
-                AlpacaPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "explainchoice":
-            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceExplainPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "concisechoice":
-            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceConcisePrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "summarizetldr":
-            ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
-                SummarizeTLDRPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "jeopardy":
-            ds_strategy = JeopardyPromptTokenizingStrategy(
-                JeopardyPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "oasst":
-            ds_strategy = OpenAssistantPromptTokenizingStrategy(
-                AlpacaPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "gpteacher":
-            ds_strategy = GPTeacherPromptTokenizingStrategy(
-                GPTeacherPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        elif d_base_type == "reflection":
-            ds_strategy = AlpacaReflectionPTStrategy(
-                ReflectAlpacaPrompter(d_prompt_style),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
-        else:
-            suffix = ""
-            if ":load_" in d.type:
-                suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
-            LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
-            raise ValueError(
-                f"unhandled prompt tokenization strategy: {d.type} {suffix}"
+                f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
             )
+
+        dataset_wrapper, dataset_prompter = get_dataset_wrapper(
+            config_dataset=config_dataset,
+            dataset=ds,
+            tokenizer=tokenizer,
+            cfg=cfg,
+            d_base_type=d_base_type,
+            d_prompt_style=d_prompt_style,
+        )
+        datasets.append(dataset_wrapper)
+        prompters.append(dataset_prompter)
+
     LOG.info("merging datasets")
     dataset = concatenate_datasets(datasets)
 
@@ -368,14 +295,14 @@ def load_tokenized_prepared_datasets(
             f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
         )
 
-    return dataset
+    return dataset, prompters
 
 
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
-) -> Tuple[Dataset, Dataset]:
+) -> Tuple[Dataset, Dataset, List[Any]]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -384,6 +311,7 @@ def load_prepare_datasets(
     )  # make sure we don't accidentally set it larger than sequence_len
 
     tokenizer_name = tokenizer.__class__.__name__
+    prompters = []
     if cfg.max_packed_sequence_len is not None:
         # see if we can go ahead and load the stacked dataset
         seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -439,7 +367,7 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
             )
         else:
-            dataset = load_tokenized_prepared_datasets(
+            dataset, prompters = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path
             )
 
@@ -481,7 +409,7 @@ def load_prepare_datasets(
                 private=True,
             )
         else:
-            dataset = load_tokenized_prepared_datasets(
+            dataset, prompters = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path
            )
 
@@ -532,7 +460,124 @@ def load_prepare_datasets(
         train_dataset = dataset
         eval_dataset = None
 
-    return train_dataset, eval_dataset
+    return train_dataset, eval_dataset, prompters
+
+
+def get_dataset_wrapper(
+    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
+):
+    dataset_wrapper = None
+    dataset_prompter = None
+
+    if (
+        "input_ids" in dataset.features
+        and "attention_mask" in dataset.features
+        and "labels" in dataset.features
+    ):
+        # dataset is already tokenized, just drop it straight in
+        dataset_prompter = UnsupportedPrompter()
+        dataset_wrapper = dataset
+    elif isinstance(config_dataset.type, DictDefault):
+        ds_strategy = load(
+            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
+        )
+        dataset_prompter = UnsupportedPrompter()
+        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+    elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
+        dataset_prompter = UnsupportedPrompter()
+        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+    elif d_base_type == "alpaca":
+        dataset_prompter = AlpacaPrompter(d_prompt_style)
+        ds_strategy = AlpacaPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "explainchoice":
+        dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
+        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "concisechoice":
+        dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
+        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "summarizetldr":
+        dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
+        ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "jeopardy":
+        dataset_prompter = JeopardyPrompter(d_prompt_style)
+        ds_strategy = JeopardyPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "oasst":
+        dataset_prompter = AlpacaPrompter(d_prompt_style)
+        ds_strategy = OpenAssistantPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "gpteacher":
+        dataset_prompter = GPTeacherPrompter(d_prompt_style)
+        ds_strategy = GPTeacherPromptTokenizingStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    elif d_base_type == "reflection":
+        dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
+        ds_strategy = AlpacaReflectionPTStrategy(
+            dataset_prompter,
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = ds_wrapper
+    else:
+        suffix = ""
+        if ":load_" in config_dataset.type:
+            suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
+        LOG.error(
+            f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
+        )
+        raise ValueError(
+            f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
        )
+
+    return dataset_wrapper, dataset_prompter
 
 
 def encode_pretraining(
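
Net effect for callers of this module: `prepare_dataset` and `load_prepare_datasets` now return the prompters alongside the datasets. A hedged sketch of consuming the new shape; building `cfg` and `tokenizer` is elided since it depends on the config:

```python
from axolotl.utils.data import prepare_dataset

def preprocess_and_report(cfg, tokenizer):
    # prepare_dataset now returns a 4-tuple; prompters holds one entry per
    # configured dataset (AlpacaPrompter et al., or UnsupportedPrompter).
    train_ds, eval_ds, total_steps, prompters = prepare_dataset(cfg, tokenizer)
    for prompter in prompters:
        print(prompter)
    return train_ds, eval_ds, total_steps
```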