winglian committed
Commit: 1470650
Parent: 501b4d1

various bugfixes (#856)

* various bugfixes

use latest tinyllama release
check if val_set_size is empty first
update sdp and xformers llama patches for updated upstream transformers
fix system prompt when no input
calculate total and total supervised tokens even when not sample packing

* add fix for when eval size is estimated to be too small

* should be len 1 for dataset length

* add catchall kwargs

examples/llama-2/tiny-llama.yml CHANGED
@@ -1,4 +1,4 @@
-base_model: PY007/TinyLlama-1.1B-step-50K-105b
+base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
 
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
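
For context, a minimal smoke test (an assumption, not part of the commit: it presumes transformers is installed and the checkpoint is reachable on the Hugging Face Hub) that the updated base_model resolves before launching a full axolotl run:

from transformers import AutoTokenizer

# Hypothetical check: if this download fails, the tiny-llama.yml example
# will also fail at model/tokenizer load time.
tokenizer = AutoTokenizer.from_pretrained(
    "PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T"
)
print(type(tokenizer).__name__)  # expected to be a Llama tokenizer class, per tokenizer_type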
src/axolotl/core/trainer_builder.py CHANGED
@@ -543,16 +543,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
 
-        if self.cfg.eval_steps:
+        if self.cfg.val_set_size == 0:
+            # no eval set, so don't eval
+            training_arguments_kwargs["evaluation_strategy"] = "no"
+        elif self.cfg.eval_steps:
             training_arguments_kwargs["evaluation_strategy"] = "steps"
             training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
         elif self.cfg.evaluation_strategy:
             training_arguments_kwargs[
                 "evaluation_strategy"
             ] = self.cfg.evaluation_strategy
-        elif self.cfg.val_set_size == 0:
-            # no eval set, so don't eval
-            training_arguments_kwargs["evaluation_strategy"] = "no"
         else:
             # we have an eval set, but no steps defined, default to use epoch
             training_arguments_kwargs["evaluation_strategy"] = "epoch"
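
The reordering makes an empty validation split the highest-priority case. A standalone sketch of the resulting precedence (hypothetical helper function, not axolotl code):

def resolve_evaluation_strategy(val_set_size, eval_steps, evaluation_strategy):
    # mirrors the committed branch order
    if val_set_size == 0:
        return "no"  # no eval set, so don't eval
    if eval_steps:
        return "steps"
    if evaluation_strategy:
        return evaluation_strategy
    return "epoch"  # eval set present, but no steps defined

print(resolve_evaluation_strategy(0, 100, None))     # "no"
print(resolve_evaluation_strategy(0.05, 100, None))  # "steps"

Before this change, eval_steps was checked first, so a config with val_set_size: 0 plus eval_steps could ask the trainer to evaluate a split that was never built.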
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py CHANGED
@@ -25,6 +25,8 @@ def sdp_attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -29,6 +29,8 @@ def xformers_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
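
Both attention patches gain the same two parameters. A minimal, generic sketch (plain Python, not transformers code) of why the padding_mask parameter and catch-all **kwargs keep a monkeypatched forward compatible when upstream starts passing new keyword arguments:

def patched_forward_old(hidden_states, use_cache=False):
    return hidden_states

def patched_forward_new(hidden_states, use_cache=False, padding_mask=None, **kwargs):
    # padding_mask and any future keyword arguments are accepted and ignored
    return hidden_states

patched_forward_new("h", use_cache=True, padding_mask=None, future_arg=1)  # works
# patched_forward_old("h", padding_mask=None)  # TypeError: unexpected keyword argument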
src/axolotl/prompters.py CHANGED
@@ -75,7 +75,7 @@ class AlpacaPrompter(Prompter):
         else:
             res = (
                 self.system_format.format(system=self.system_no_input_prompt)
-                if self.system_prompt
+                if self.system_no_input_prompt
                 else ""
             ) + self.turn_no_input_format.format(instruction=instruction)
         if output:
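
A minimal sketch of the behavior the one-word change fixes, using made-up templates rather than the real AlpacaPrompter attributes: the old condition tested system_prompt, so a prompter whose no-input system prompt is empty would still emit an empty system block:

system_format = "### System:\n{system}\n\n"
system_prompt = "Use the input to answer."  # defined for the input case
system_no_input_prompt = ""                 # intentionally empty for no-input turns

# old check: system_prompt is truthy, so an empty system block is emitted anyway
old = system_format.format(system=system_no_input_prompt) if system_prompt else ""
# new check: system_no_input_prompt is falsy, so the system block is skipped
new = system_format.format(system=system_no_input_prompt) if system_no_input_prompt else ""

print(repr(old))  # '### System:\n\n\n'
print(repr(new))  # ''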
src/axolotl/utils/samplers/multipack.py CHANGED
@@ -181,13 +181,16 @@ class MultipackBatchSampler(BatchSampler):
         )
 
         # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
-        return (
-            world_size
-            * math.floor(
-                0.99
-                * lengths_sum_per_device
-                / self.packing_efficiency_estimate
-                // self.batch_max_len
-            )
-            - 1
+        return min(
+            1,
+            (
+                world_size
+                * math.floor(
+                    0.99
+                    * lengths_sum_per_device
+                    / self.packing_efficiency_estimate
+                    // self.batch_max_len
+                )
+                - 1
+            ),
         )
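
For a rough sense of scale, a worked example of the inner estimate with illustrative numbers (assumed values, not from the commit):

import math

world_size = 2                       # assumed number of data-parallel ranks
lengths_sum_per_device = 1_000_000   # assumed tokens per device
packing_efficiency_estimate = 0.9    # assumed packing efficiency
batch_max_len = 4096                 # assumed packed-sequence token budget

estimate = (
    world_size
    * math.floor(
        0.99 * lengths_sum_per_device / packing_efficiency_estimate // batch_max_len
    )
    - 1
)
print(estimate)  # 535

The committed change wraps this estimate in a min(...) guard, following the commit notes about the eval split being estimated too small ("should be len 1 for dataset length").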
src/axolotl/utils/trainer.py CHANGED
@@ -142,31 +142,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
 
 
 def calculate_total_num_steps(cfg, train_dataset):
+    if not cfg.total_num_tokens:
+        total_num_tokens = np.sum(
+            train_dataset.data.column("input_ids")
+            .to_pandas()
+            .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
+            .values
+        )
+        LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
+        cfg.total_num_tokens = total_num_tokens
+
+    if not cfg.total_supervised_tokens:
+        total_supervised_tokens = (
+            train_dataset.data.column("labels")
+            .to_pandas()
+            .apply(lambda x: np.sum(np.array(x) != -100))
+            .sum()
+        )
+        LOG.debug(
+            f"`total_supervised_tokens: {total_supervised_tokens}`",
+            main_process_only=True,
+        )
+        cfg.total_supervised_tokens = total_supervised_tokens
+
     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails
-        if not cfg.total_num_tokens:
-            total_num_tokens = np.sum(
-                train_dataset.data.column("input_ids")
-                .to_pandas()
-                .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
-                .values
-            )
-            LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
-            cfg.total_num_tokens = total_num_tokens
-
-        if not cfg.total_supervised_tokens:
-            total_supervised_tokens = (
-                train_dataset.data.column("labels")
-                .to_pandas()
-                .apply(lambda x: np.sum(np.array(x) != -100))
-                .sum()
-            )
-            LOG.debug(
-                f"`total_supervised_tokens: {total_supervised_tokens}`",
-                main_process_only=True,
-            )
-            cfg.total_supervised_tokens = total_supervised_tokens
 
     if cfg.sample_packing_eff_est:
         total_num_steps = (
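
The token-counting block now runs regardless of sample_packing, so the totals are available for non-packed runs as well. A toy sketch (made-up data, not the axolotl dataset API) of what the supervised-token count measures: labels left at -100 are masked out of the loss, so only the remaining labels count as supervised tokens:

import numpy as np
import pandas as pd

# two toy examples; -100 marks prompt tokens excluded from the loss
labels = pd.Series([[-100, -100, 5, 7], [-100, 9, 9, 9, -100]])
total_supervised_tokens = labels.apply(lambda x: np.sum(np.array(x) != -100)).sum()
print(total_supervised_tokens)  # 5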