winglian committed on
Commit
cdc71f7
β€’
1 Parent(s): 6459ac7

update table for rwkv4 support, fix process count for dataset (#822)

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. src/axolotl/datasets.py +8 -2
  3. src/axolotl/utils/data.py +30 -10
README.md CHANGED
@@ -74,6 +74,7 @@ Features:
74
  | gpt-j | βœ… | βœ… | βœ… | ❌ | ❌ | ❓ | ❓ |
75
  | XGen | βœ… | ❓ | βœ… | ❓ | ❓ | ❓ | βœ… |
76
  | phi | βœ… | βœ… | βœ… | ❓ | ❓ | ❓ | ❓ |
 
77
 
78
 
79
  ## Quickstart ⚑
 
74
  | gpt-j | βœ… | βœ… | βœ… | ❌ | ❌ | ❓ | ❓ |
75
  | XGen | βœ… | ❓ | βœ… | ❓ | ❓ | ❓ | βœ… |
76
  | phi | βœ… | βœ… | βœ… | ❓ | ❓ | ❓ | ❓ |
77
+ | RWKV | βœ… | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
78
 
79
 
80
  ## Quickstart ⚑
src/axolotl/datasets.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  import logging
4
  import os
5
- from typing import List
6
 
7
  import torch
8
  from datasets import Dataset, IterableDataset
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
30
  self,
31
  prompt_tokenizer: PromptTokenizingStrategy,
32
  dataset: IterableDataset,
 
33
  **kwargs,
34
  ):
35
  self.prompt_tokenizer = prompt_tokenizer
 
36
  super().__init__(self.process(dataset).data, **kwargs)
37
 
38
  def process(self, dataset):
39
  features = dataset.features.keys()
40
- num_proc = min(64, os.cpu_count())
 
 
 
 
41
  map_kwargs = {}
42
  if self.prompt_tokenizer.supports_batched:
43
  map_kwargs["batched"] = True
 
2
 
3
  import logging
4
  import os
5
+ from typing import List, Optional
6
 
7
  import torch
8
  from datasets import Dataset, IterableDataset
 
30
  self,
31
  prompt_tokenizer: PromptTokenizingStrategy,
32
  dataset: IterableDataset,
33
+ process_count: Optional[int] = None,
34
  **kwargs,
35
  ):
36
  self.prompt_tokenizer = prompt_tokenizer
37
+ self.process_count = process_count
38
  super().__init__(self.process(dataset).data, **kwargs)
39
 
40
  def process(self, dataset):
41
  features = dataset.features.keys()
42
+ num_proc = (
43
+ min(64, self.process_count)
44
+ if self.process_count
45
+ else min(64, os.cpu_count())
46
+ )
47
  map_kwargs = {}
48
  if self.prompt_tokenizer.supports_batched:
49
  map_kwargs["batched"] = True
src/axolotl/utils/data.py CHANGED
@@ -482,10 +482,14 @@ def get_dataset_wrapper(
482
  "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
483
  )
484
  dataset_prompter = UnsupportedPrompter()
485
- dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
486
  elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
487
  dataset_prompter = UnsupportedPrompter()
488
- dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
489
  elif d_base_type == "alpaca":
490
  dataset_prompter = AlpacaPrompter(d_prompt_style)
491
  ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -494,7 +498,9 @@ def get_dataset_wrapper(
494
  cfg.train_on_inputs,
495
  cfg.sequence_len,
496
  )
497
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
498
  dataset_wrapper = ds_wrapper
499
  elif d_base_type == "explainchoice":
500
  dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
@@ -504,7 +510,9 @@ def get_dataset_wrapper(
504
  cfg.train_on_inputs,
505
  cfg.sequence_len,
506
  )
507
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
508
  dataset_wrapper = ds_wrapper
509
  elif d_base_type == "concisechoice":
510
  dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
@@ -514,7 +522,9 @@ def get_dataset_wrapper(
514
  cfg.train_on_inputs,
515
  cfg.sequence_len,
516
  )
517
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
518
  dataset_wrapper = ds_wrapper
519
  elif d_base_type == "summarizetldr":
520
  dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
@@ -524,7 +534,9 @@ def get_dataset_wrapper(
524
  cfg.train_on_inputs,
525
  cfg.sequence_len,
526
  )
527
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
528
  dataset_wrapper = ds_wrapper
529
  elif d_base_type == "jeopardy":
530
  dataset_prompter = JeopardyPrompter(d_prompt_style)
@@ -534,7 +546,9 @@ def get_dataset_wrapper(
534
  cfg.train_on_inputs,
535
  cfg.sequence_len,
536
  )
537
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
538
  dataset_wrapper = ds_wrapper
539
  elif d_base_type == "oasst":
540
  dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -544,7 +558,9 @@ def get_dataset_wrapper(
544
  cfg.train_on_inputs,
545
  cfg.sequence_len,
546
  )
547
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
548
  dataset_wrapper = ds_wrapper
549
  elif d_base_type == "gpteacher":
550
  dataset_prompter = GPTeacherPrompter(d_prompt_style)
@@ -554,7 +570,9 @@ def get_dataset_wrapper(
554
  cfg.train_on_inputs,
555
  cfg.sequence_len,
556
  )
557
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
558
  dataset_wrapper = ds_wrapper
559
  elif d_base_type == "reflection":
560
  dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
@@ -564,7 +582,9 @@ def get_dataset_wrapper(
564
  cfg.train_on_inputs,
565
  cfg.sequence_len,
566
  )
567
- ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
 
 
568
  dataset_wrapper = ds_wrapper
569
  else:
570
  suffix = ""
 
482
  "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
483
  )
484
  dataset_prompter = UnsupportedPrompter()
485
+ dataset_wrapper = TokenizedPromptDataset(
486
+ ds_strategy, dataset, process_count=cfg.dataset_processes
487
+ )
488
  elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
489
  dataset_prompter = UnsupportedPrompter()
490
+ dataset_wrapper = TokenizedPromptDataset(
491
+ ds_strategy, dataset, process_count=cfg.dataset_processes
492
+ )
493
  elif d_base_type == "alpaca":
494
  dataset_prompter = AlpacaPrompter(d_prompt_style)
495
  ds_strategy = AlpacaPromptTokenizingStrategy(
 
498
  cfg.train_on_inputs,
499
  cfg.sequence_len,
500
  )
501
+ ds_wrapper = TokenizedPromptDataset(
502
+ ds_strategy, dataset, process_count=cfg.dataset_processes
503
+ )
504
  dataset_wrapper = ds_wrapper
505
  elif d_base_type == "explainchoice":
506
  dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
 
510
  cfg.train_on_inputs,
511
  cfg.sequence_len,
512
  )
513
+ ds_wrapper = TokenizedPromptDataset(
514
+ ds_strategy, dataset, process_count=cfg.dataset_processes
515
+ )
516
  dataset_wrapper = ds_wrapper
517
  elif d_base_type == "concisechoice":
518
  dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
 
522
  cfg.train_on_inputs,
523
  cfg.sequence_len,
524
  )
525
+ ds_wrapper = TokenizedPromptDataset(
526
+ ds_strategy, dataset, process_count=cfg.dataset_processes
527
+ )
528
  dataset_wrapper = ds_wrapper
529
  elif d_base_type == "summarizetldr":
530
  dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
 
534
  cfg.train_on_inputs,
535
  cfg.sequence_len,
536
  )
537
+ ds_wrapper = TokenizedPromptDataset(
538
+ ds_strategy, dataset, process_count=cfg.dataset_processes
539
+ )
540
  dataset_wrapper = ds_wrapper
541
  elif d_base_type == "jeopardy":
542
  dataset_prompter = JeopardyPrompter(d_prompt_style)
 
546
  cfg.train_on_inputs,
547
  cfg.sequence_len,
548
  )
549
+ ds_wrapper = TokenizedPromptDataset(
550
+ ds_strategy, dataset, process_count=cfg.dataset_processes
551
+ )
552
  dataset_wrapper = ds_wrapper
553
  elif d_base_type == "oasst":
554
  dataset_prompter = AlpacaPrompter(d_prompt_style)
 
558
  cfg.train_on_inputs,
559
  cfg.sequence_len,
560
  )
561
+ ds_wrapper = TokenizedPromptDataset(
562
+ ds_strategy, dataset, process_count=cfg.dataset_processes
563
+ )
564
  dataset_wrapper = ds_wrapper
565
  elif d_base_type == "gpteacher":
566
  dataset_prompter = GPTeacherPrompter(d_prompt_style)
 
570
  cfg.train_on_inputs,
571
  cfg.sequence_len,
572
  )
573
+ ds_wrapper = TokenizedPromptDataset(
574
+ ds_strategy, dataset, process_count=cfg.dataset_processes
575
+ )
576
  dataset_wrapper = ds_wrapper
577
  elif d_base_type == "reflection":
578
  dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
 
582
  cfg.train_on_inputs,
583
  cfg.sequence_len,
584
  )
585
+ ds_wrapper = TokenizedPromptDataset(
586
+ ds_strategy, dataset, process_count=cfg.dataset_processes
587
+ )
588
  dataset_wrapper = ds_wrapper
589
  else:
590
  suffix = ""