winglian committed on
Commit
f7a2263
1 Parent(s): 1aa4007

support custom field for completion from yml (#580)

Browse files

* support custom field for completion from yml

* remove legacy completion check and add doc

* update README docs

README.md CHANGED
@@ -322,6 +322,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
322
  - path: EleutherAI/pile
323
  name: enron_emails
324
  type: completion # format from earlier
 
325
 
326
  # huggingface repo with multiple named configurations/subsets
327
  datasets:
@@ -444,6 +445,9 @@ datasets:
444
  # 'no_input_format' cannot include {input}
445
  no_input_format: "{instruction} "
446
 
 
 
 
447
  # axolotl attempts to save the dataset as an arrow after packing the data together so
448
  # subsequent training attempts load faster, relative path
449
  dataset_prepared_path: data/last_run_prepared
 
322
  - path: EleutherAI/pile
323
  name: enron_emails
324
  type: completion # format from earlier
325
+ field: text # Optional[str] default: text, field to use for completion data
326
 
327
  # huggingface repo with multiple named configurations/subsets
328
  datasets:
 
445
  # 'no_input_format' cannot include {input}
446
  no_input_format: "{instruction} "
447
 
448
+ # for completion datasets, use the provided field instead of the default `text`
449
+ field:
450
+
451
  # axolotl attempts to save the dataset as an arrow after packing the data together so
452
  # subsequent training attempts load faster, relative path
453
  dataset_prepared_path: data/last_run_prepared
src/axolotl/prompt_strategies/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
  """Module to load prompt strategies."""
2
 
3
  import importlib
 
4
 
5
  from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
6
 
@@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
16
  load_kwargs = {}
17
  if strategy == "user_defined":
18
  load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
 
 
 
 
19
  return func(tokenizer, cfg, **load_kwargs)
20
  except Exception: # pylint: disable=broad-exception-caught
21
  return None
 
1
  """Module to load prompt strategies."""
2
 
3
  import importlib
4
+ import inspect
5
 
6
  from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
7
 
 
17
  load_kwargs = {}
18
  if strategy == "user_defined":
19
  load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
20
+ else:
21
+ sig = inspect.signature(func)
22
+ if "ds_cfg" in sig.parameters:
23
+ load_kwargs["ds_cfg"] = ds_cfg
24
  return func(tokenizer, cfg, **load_kwargs)
25
  except Exception: # pylint: disable=broad-exception-caught
26
  return None
src/axolotl/prompt_strategies/completion.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic completion text
3
+ """
4
+ from typing import Any, Dict, Optional
5
+
6
+ from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
7
+ from axolotl.prompters import CompletionPrompter
8
+
9
+
10
+ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
11
+ strat = CompletionPromptTokenizingStrategy(
12
+ CompletionPrompter(),
13
+ tokenizer,
14
+ cfg.train_on_inputs,
15
+ cfg.sequence_len,
16
+ )
17
+ if ds_cfg and "field" in ds_cfg:
18
+ strat.field = ds_cfg["field"]
19
+
20
+ return strat
src/axolotl/prompt_tokenizers.py CHANGED
@@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
245
  Tokenizing strategy for Completion prompts.
246
  """
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  def tokenize_prompt(self, prompt):
249
- full_prompt = self._build_full_prompt(prompt["text"], None, None)
 
 
 
 
 
 
250
  tokenized_full_prompt = self._tokenize(full_prompt)
251
 
252
  return tokenized_full_prompt
 
245
  Tokenizing strategy for Completion prompts.
246
  """
247
 
248
+ _field: str = "text"
249
+
250
+ @property
251
+ def field(self) -> str:
252
+ return self._field
253
+
254
+ @field.setter
255
+ def field(self, new_field: str):
256
+ self._field = new_field
257
+
258
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
259
+ return (
260
+ prompt[self.field],
261
+ "",
262
+ "",
263
+ )
264
+
265
  def tokenize_prompt(self, prompt):
266
+ (
267
+ instruction,
268
+ _,
269
+ _,
270
+ ) = self.parse_instruction_fields(prompt)
271
+
272
+ full_prompt = self._build_full_prompt(instruction, None, None)
273
  tokenized_full_prompt = self._tokenize(full_prompt)
274
 
275
  return tokenized_full_prompt
src/axolotl/utils/data.py CHANGED
@@ -22,7 +22,6 @@ from axolotl.prompt_tokenizers import (
22
  AlpacaMultipleChoicePromptTokenizingStrategy,
23
  AlpacaPromptTokenizingStrategy,
24
  AlpacaReflectionPTStrategy,
25
- CompletionPromptTokenizingStrategy,
26
  GPTeacherPromptTokenizingStrategy,
27
  JeopardyPromptTokenizingStrategy,
28
  OpenAssistantPromptTokenizingStrategy,
@@ -31,7 +30,6 @@ from axolotl.prompt_tokenizers import (
31
  )
32
  from axolotl.prompters import (
33
  AlpacaPrompter,
34
- CompletionPrompter,
35
  GPTeacherPrompter,
36
  JeopardyPrompter,
37
  MultipleChoiceConcisePrompter,
@@ -327,15 +325,6 @@ def load_tokenized_prepared_datasets(
327
  )
328
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
329
  datasets.append(ds_wrapper)
330
- elif d_base_type == "completion":
331
- ds_strategy = CompletionPromptTokenizingStrategy(
332
- CompletionPrompter(),
333
- tokenizer,
334
- cfg.train_on_inputs,
335
- cfg.sequence_len,
336
- )
337
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
338
- datasets.append(ds_wrapper)
339
  else:
340
  suffix = ""
341
  if ":load_" in d.type:
 
22
  AlpacaMultipleChoicePromptTokenizingStrategy,
23
  AlpacaPromptTokenizingStrategy,
24
  AlpacaReflectionPTStrategy,
 
25
  GPTeacherPromptTokenizingStrategy,
26
  JeopardyPromptTokenizingStrategy,
27
  OpenAssistantPromptTokenizingStrategy,
 
30
  )
31
  from axolotl.prompters import (
32
  AlpacaPrompter,
 
33
  GPTeacherPrompter,
34
  JeopardyPrompter,
35
  MultipleChoiceConcisePrompter,
 
325
  )
326
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
327
  datasets.append(ds_wrapper)
 
 
 
 
 
 
 
 
 
328
  else:
329
  suffix = ""
330
  if ":load_" in d.type: