winglian committed on
Commit 641e6f7
1 Parent(s): 6dc68a6

multipack w batch sampler (#795)

* test batch sampler w varying batch lens

* wip

* multipack batchsampler wip

* wip

* fix for prepare data loader to get correct # of steps based on GPUs

* lint and clean up

* calculate len estimate

* fix total num steps calc

* add options for dataloader_num_workers and dataloader_pin_memory

* remove gitbook

* support prefetch_factor for dataloader optimization

* fix the kwarg
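
The last few bullets surface dataloader options that are forwarded to the underlying `torch.utils.data.DataLoader`. A minimal sketch of what these knobs control, using plain PyTorch and toy data rather than axolotl itself:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset; illustrative only.
dataset = TensorDataset(torch.arange(100).view(50, 2))

loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=2,      # cfg.dataloader_num_workers -> worker processes loading batches
    pin_memory=True,    # cfg.dataloader_pin_memory  -> page-locked host memory for faster host-to-device copies
    prefetch_factor=4,  # cfg.dataloader_prefetch_factor -> batches prefetched per worker (only valid when num_workers > 0)
)

for (batch,) in loader:
    pass  # each batch has shape (4, 2)
```
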

gitbook/README.md DELETED
@@ -1 +0,0 @@
-# Page

gitbook/SUMMARY.md DELETED
@@ -1,4 +0,0 @@
-# Table of contents
-
-* [Page](README.md)
-* [Small dev details](small-dev-details.md)

gitbook/small-dev-details.md DELETED
@@ -1,3 +0,0 @@
-# Small dev details
-
-/

src/axolotl/core/trainer_builder.py CHANGED
@@ -6,7 +6,6 @@ import abc
 import importlib
 import logging
 import math
-import os
 import sys
 from abc import abstractmethod
 from dataclasses import dataclass, field
@@ -18,9 +17,9 @@ import torch
 import transformers
 from datasets import Dataset
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import SequentialDistributedSampler
+from transformers.trainer_utils import seed_worker
 
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
@@ -31,8 +30,9 @@ from axolotl.utils.callbacks import (
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
-from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
+from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 try:
@@ -102,6 +102,10 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None,
+        metadata={"help": "prefetch_factor argument to the dataloader"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -145,46 +149,69 @@ class AxolotlTrainer(Trainer):
         return self.lr_scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.world_size > 1 and self.args.sample_packing:
-            return DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.args.world_size,
-                rank=self.args.process_index,
-                seed=self.args.seed,
+        if self.args.sample_packing:
+            return MultipackBatchSampler(
+                RandomSampler(self.train_dataset),
+                self.args.train_batch_size,
+                drop_last=True,
+                batch_max_len=self._train_batch_size * self.args.max_seq_length,
+                lengths=(
+                    self.train_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
+                ),
+                packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_train_sampler()
 
     def _get_eval_sampler(
         self, eval_dataset: Dataset
     ) -> Optional[torch.utils.data.Sampler]:
-        if (
-            self.args.world_size > 1
-            and self.args.sample_packing
-            and self.args.eval_sample_packing is not False
-        ):
-            return SequentialDistributedSampler(
-                eval_dataset,
-                num_replicas=self.args.world_size,
-                rank=self.args.process_index,
-                batch_size=self.args.per_device_eval_batch_size,
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            return MultipackBatchSampler(
+                SequentialSampler(eval_dataset),
+                self.args.per_device_eval_batch_size,
+                drop_last=True,
+                batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
+                lengths=(
+                    eval_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
+                ),
+                packing_efficiency_estimate=self.args.sample_packing_efficiency,
            )
         return super()._get_eval_sampler(eval_dataset)
 
-    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
+    def get_train_dataloader(self) -> DataLoader:
         if self.args.sample_packing:
-            train_sampler = self._get_train_sampler()
-            return self.accelerator.prepare(
-                MultipackDistributedDataloader(
-                    self.train_dataset,
-                    batch_size=self._train_batch_size,
-                    seq_max_length=self.args.max_seq_length,
-                    collate_fn=self.data_collator,
-                    sampler=train_sampler,
-                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
-                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                    num_epochs=self.num_epochs,
-                )
+            train_dataset = self.train_dataset
+            train_dataset = train_dataset.remove_columns(["length"])
+            data_collator = self.data_collator
+            dataloader_params = {
+                "batch_size": self._train_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+            if self.args.dataloader_prefetch_factor:
+                dataloader_params[
+                    "prefetch_factor"
+                ] = self.args.dataloader_prefetch_factor
+
+            sampler = self._get_train_sampler()
+            if isinstance(sampler, BatchSampler):
+                dataloader_params["batch_sampler"] = sampler
+                del dataloader_params["batch_size"]
+            else:
+                dataloader_params["sampler"] = sampler
+                dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+
+            self.accelerator.even_batches = False
+            return self.accelerator.prepare_data_loader(
+                DataLoader(train_dataset, **dataloader_params)
             )
         return super().get_train_dataloader()
 
@@ -197,18 +224,29 @@ class AxolotlTrainer(Trainer):
             )
 
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            return self.accelerator.prepare(
-                MultipackDistributedDataloader(
-                    eval_dataset,
-                    batch_size=self.args.eval_batch_size,
-                    seq_max_length=self.args.max_seq_length,
-                    collate_fn=self.data_collator,
-                    sampler=eval_sampler,
-                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
-                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                    num_epochs=self.num_epochs,
-                )
+            eval_dataset = eval_dataset.remove_columns(["length"])
+            data_collator = self.data_collator
+            dataloader_params = {
+                "batch_size": self.args.eval_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+            if self.args.dataloader_prefetch_factor:
+                dataloader_params[
+                    "prefetch_factor"
+                ] = self.args.dataloader_prefetch_factor
+
+            if isinstance(eval_sampler, BatchSampler):
+                dataloader_params["batch_sampler"] = eval_sampler
+                del dataloader_params["batch_size"]
+            else:
+                dataloader_params["sampler"] = eval_sampler
+                dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+            self.accelerator.even_batches = False
+            return self.accelerator.prepare_data_loader(
+                DataLoader(eval_dataset, **dataloader_params)
             )
         return super().get_eval_dataloader(eval_dataset)
 
@@ -229,6 +267,8 @@ class AxolotlTrainer(Trainer):
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
         }
+        if self.args.dataloader_prefetch_factor:
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
         if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
@@ -493,6 +533,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "sample_packing_efficiency"
             ] = self.cfg.sample_packing_eff_est
 
+        if self.cfg.dataloader_pin_memory is not None:
+            training_arguments_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
+        if self.cfg.dataloader_num_workers is not None:
+            training_arguments_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
+        if self.cfg.dataloader_prefetch_factor is not None:
+            training_arguments_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
+
         if self.cfg.eval_steps:
             training_arguments_kwargs["evaluation_strategy"] = "steps"
             training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
@@ -672,7 +725,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=DataCollatorForSeq2Seq(
+            data_collator=BatchSamplerDataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
                 **data_collator_kwargs,
@@ -690,4 +743,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         for callback in self.get_post_trainer_create_callbacks(trainer):
             trainer.add_callback(callback)
 
+        if self.cfg.deepspeed and self.cfg.sample_packing:
+            trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "train_micro_batch_size_per_gpu"
+            ] = self.cfg.micro_batch_size
+
         return trainer

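
The `isinstance(sampler, BatchSampler)` branch above deletes `"batch_size"` from the dataloader kwargs because PyTorch treats `batch_sampler` as mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`: each list of indices yielded by the batch sampler becomes exactly one batch. A small sketch with toy data (plain PyTorch, not the trainer code itself):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# A batch sampler yields lists of indices; the DataLoader turns each list into one batch.
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# batch_size/shuffle/sampler/drop_last must stay at their defaults when batch_sampler is set,
# which is why the trainer pops "batch_size" from dataloader_params in that branch.
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (batch,) in loader:
    print(batch)  # tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])
```
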
src/axolotl/utils/collators.py CHANGED
@@ -119,3 +119,30 @@ class DataCollatorForSeq2Seq:
         features["decoder_input_ids"] = decoder_input_ids
 
         return features
+
+
+@dataclass
+class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack, specific to using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features[0].keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [
+                    (1) * np.array(item[feature])
+                    for item in features
+                    if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [
+                    np.array(item[feature]) for item in features if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)

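
To make the new collator's behavior concrete, here is a framework-free toy sketch of the same flattening step (made-up feature values, not real tokenized data): per-feature arrays from every sample handed over by the batch sampler are concatenated into a single packed row, the `length` key is dropped, and the merged example is then passed on to the ordinary seq2seq collation.

```python
import numpy as np

# Two packed samples as they might arrive from the batch sampler (toy values).
features = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "length": 3},
    {"input_ids": [8, 9], "attention_mask": [1, 1], "length": 2},
]

chunked = {}
for key in features[0]:
    if key == "length":  # dropped, just like in BatchSamplerDataCollatorForSeq2Seq
        continue
    chunked[key] = np.concatenate([np.array(f[key]) for f in features if key in f])

print(chunked["input_ids"])       # [5 6 7 8 9]
print(chunked["attention_mask"])  # [1 1 1 1 1]
# The real collator then wraps this single merged example in a list and calls the parent collator.
```
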
src/axolotl/utils/data.py CHANGED
@@ -80,11 +80,11 @@ def prepare_dataset(cfg, tokenizer):
     )
     if cfg.max_steps:
         total_num_steps = min(
-            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
         )
         LOG.info(f"Maximum number of steps set at {total_num_steps}")
     else:
-        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset)
     return train_dataset, eval_dataset, total_num_steps, prompters
 
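
`calculate_total_num_steps` can drop the `tokenizer` argument because, with sample packing, per-sample lengths are now read off the pre-tokenized `position_ids` column (last position id plus one), as the sampler construction in `trainer.py` below shows. A toy sketch of that derivation on made-up data:

```python
import pandas as pd

# Toy "position_ids" column for three tokenized samples (lengths 4, 2, 3).
position_ids = pd.Series([[0, 1, 2, 3], [0, 1], [0, 1, 2]])

# Same derivation the samplers use: length = last position id + 1.
lengths = position_ids.apply(lambda x: x[-1] + 1).values
print(lengths)  # [4 2 3]
```
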
 
src/axolotl/utils/samplers/__init__.py ADDED
@@ -0,0 +1,4 @@
+"""
+axolotl samplers module
+"""
+from .multipack import MultipackBatchSampler  # noqa: F401

src/axolotl/utils/samplers/multipack.py ADDED
@@ -0,0 +1,193 @@
+# pylint: skip-file
+"""
+Multipack Batch Sampler
+"""
+import logging
+import math
+import os
+from typing import Any, Iterable, List, Union
+
+import numba
+import numpy as np
+from torch.utils.data import BatchSampler, Sampler
+
+LOG = logging.getLogger("axolotl.utils.samplers.multipack")
+
+
+@numba.njit
+def ffd_check(a: np.ndarray, c: int, n: int):
+    # First-fit-decreasing bin packing
+    # Check if a[] could fit in n bins with capacity c
+    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+
+    a = np.sort(a)[::-1]
+    bins = np.full((n,), c, dtype=a.dtype)
+    for size in a:
+        not_found = True
+        for idx in range(n):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                not_found = False
+                break
+
+        if not_found:
+            return False
+
+    return True
+
+
+@numba.njit
+def ffd_with_result(a: np.ndarray, c: int, start_index: int):
+    # First-fit-decreasing bin packing (with result return)
+
+    indices = np.argsort(a)[::-1]
+    a = a[indices]
+
+    bins: List[Any] = []
+    bins_result: List[Any] = []
+    for a_id, size in enumerate(a):
+        add_new = True
+        for idx in range(len(bins)):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                bins_result[idx].append(indices[a_id] + start_index)
+                add_new = False
+                break
+
+        if add_new:
+            bins.append(c - size)
+            bins_result.append([indices[a_id] + start_index])
+
+    return bins_result
+
+
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    # Dynamic batch allocator, similar to Multifit
+    # https://en.wikipedia.org/wiki/Multifit_algorithm
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    s = 0
+    start_index = 0
+    result = []
+
+    while True:
+        # binary search [l, r)
+        left = 1
+        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while right - left > 1:
+            mid = (left + right) // 2
+            if ffd_check(lengths[start_index : start_index + mid], c, n):
+                left = mid
+            else:
+                right = mid
+
+        # use length l
+        batch = ffd_with_result(
+            lengths[start_index : start_index + left], c, start_index
+        )
+        assert len(batch) <= n
+        if len(batch) < n:
+            break
+
+        start_index += left
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch[rank])
+
+    return result, s, len(result) * c * n
+
+
+class MultipackBatchSampler(BatchSampler):
+    """
+    Batch Sampler class for multipack
+    """
+
+    def __init__(
+        self,
+        sampler: Union[Sampler[int], Iterable[int]],
+        batch_size: int,
+        drop_last: bool,
+        batch_max_len: int,
+        lengths: np.ndarray,
+        packing_efficiency_estimate: float = 1.0,
+    ):
+        super().__init__(sampler, batch_size, drop_last)
+        self.batch_size = None
+        self.batch_max_len = batch_max_len
+        self.lengths: np.ndarray = lengths
+        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+
+        assert isinstance(self.lengths, np.ndarray)
+
+        self.epoch = 0
+
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def generate_batches(self, set_stats=False):
+        indices = [idx for idx in self.sampler]
+
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        batches, total_used, total_slots = allocate(
+            lengths=lengths,
+            lengths_cumsum=lengths_cumsum,
+            rank=0,
+            c=self.batch_max_len,
+            n=1,
+        )
+
+        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
+
+        # statistics
+        if set_stats:
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots
+
+        return batches
+
+    def __iter__(self):
+        batches = self.generate_batches(set_stats=True)
+        return iter(batches)
+
+    def num_batches(self):
+        batches = self.generate_batches(set_stats=True)
+        return len(batches)
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots
+
+    def __len__(self):
+        self.num_batches()
+        return self._len_est()
+
+    def _len_est(self):
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
+        lengths_sum = np.sum(self.lengths)
+        lengths_sum_per_device = lengths_sum // world_size
+        LOG.info(
+            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+            f"total_num_tokens per device: {lengths_sum_per_device}"
+        )
+
+        # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
+        return (
+            world_size
+            * math.floor(
+                0.99
+                * lengths_sum_per_device
+                / self.packing_efficiency_estimate
+                // self.batch_max_len
+            )
+            - 1
+        )

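
A standalone usage sketch of the sampler above, with made-up lengths rather than a real dataset: each yielded batch is a list of dataset indices whose token lengths pack into at most `batch_max_len` tokens.

```python
import numpy as np
from torch.utils.data import SequentialSampler

from axolotl.utils.samplers import MultipackBatchSampler

# Ten toy samples with made-up token lengths.
lengths = np.array([512, 1024, 256, 2048, 128, 768, 1536, 64, 896, 320])

sampler = MultipackBatchSampler(
    SequentialSampler(range(len(lengths))),
    batch_size=1,
    drop_last=True,
    batch_max_len=2048,  # capacity of one packed "bin"
    lengths=lengths,
)

for batch in sampler:
    # Each batch is a list of indices whose summed lengths fit within batch_max_len.
    print(batch, lengths[batch].sum())
```
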
src/axolotl/utils/trainer.py CHANGED
@@ -8,20 +8,13 @@ from typing import List
 import numpy as np
 import torch
 import torch.cuda
-import torch.distributed as dist
 from accelerate.logging import get_logger
 from datasets import set_caching_enabled
-from torch.utils.data import DistributedSampler, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder
-from axolotl.utils.collators import DataCollatorForSeq2Seq
-from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.distributed import (
-    is_distributed,
-    is_main_process,
-    reduce_and_broadcast,
-    zero_first,
-)
+from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
+from axolotl.utils.samplers import MultipackBatchSampler
 
 LOG = get_logger("axolotl")
 
@@ -148,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset
 
 
-def calculate_total_num_steps(cfg, train_dataset, tokenizer):
+def calculate_total_num_steps(cfg, train_dataset):
     if cfg.sample_packing:
         # we have to drop anything longer than sequence len otherwise
        # flash attention with position ids fails
@@ -196,37 +189,36 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                main_process_only=True,
            )
        else:
-            if cfg.world_size > 1 and is_distributed():
-                sampler = DistributedSampler(
-                    train_dataset,
-                    num_replicas=cfg.world_size,
-                    rank=dist.get_rank(),
-                    seed=cfg.seed or 42,
-                )
-            else:
-                sampler = RandomSampler(train_dataset)
-
-            data_loader = MultipackDistributedDataloader(
-                train_dataset,
+            sampler = MultipackBatchSampler(
+                sampler=RandomSampler(train_dataset),
                batch_size=cfg.micro_batch_size,
-                seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
-                collate_fn=DataCollatorForSeq2Seq(
-                    tokenizer,
-                    return_tensors="pt",
-                    padding="longest",
+                drop_last=True,
+                batch_max_len=cfg.micro_batch_size
+                * (cfg.max_packed_sequence_len or cfg.sequence_len),
+                lengths=(
+                    train_dataset.data.column("position_ids")
+                    .to_pandas()
+                    .apply(lambda x: x[-1] + 1)
+                    .values
                ),
-                sampler=sampler,
-                packing_efficiency_estimate=cfg.sample_packing_eff_est,
-                sample_packing_seq_len_multiplier=cfg.micro_batch_size,
-                device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                num_epochs=cfg.num_epochs,
            )
-            data_loader_len = data_loader.len_w_stats()
-            actual_eff = data_loader.efficiency()
+
+            data_loader = DataLoader(
+                train_dataset.remove_columns(["length"]),
+                batch_sampler=sampler,
+            )
+            data_loader_len = len(data_loader)
+            actual_eff = sampler.efficiency()
            LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
            # FIXME: is there a bug here somewhere? the total num steps depends
            # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
+            total_num_steps = int(
+                math.floor(
+                    data_loader_len
+                    * cfg.num_epochs
+                    / int(os.environ.get("WORLD_SIZE", 1))
+                )
+            )
 
            def calc_sample_packing_eff_est(estimates: List[float]):
                LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -246,7 +238,12 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
        )
    else:
        total_num_steps = int(
-            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+            math.ceil(
+                len(train_dataset)
+                * cfg.num_epochs
+                / int(os.environ.get("WORLD_SIZE", 1))
+                / cfg.batch_size
+            )
        )
    LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
    return total_num_steps
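
As a worked example of the revised per-device step count (numbers are made up): a packed dataloader that yields 1,000 batches per epoch, trained for 3 epochs across 4 GPUs, gives `floor(1000 * 3 / 4) = 750` steps.

```python
import math

data_loader_len = 1000  # packed batches per epoch (made-up)
num_epochs = 3
world_size = 4          # number of GPUs

total_num_steps = int(math.floor(data_loader_len * num_epochs / world_size))
print(total_num_steps)  # 750
```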