winglian committed
Commit 3553172
1 Parent(s): 7f2027d

fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728)

src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -1,6 +1,6 @@
-"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
+"""Module for Alpaca prompt strategy classes"""
 
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
 from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
-def load(tokenizer, cfg):
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    prompt_style = PromptStyle.CHAT.value
+    if ds_cfg and "conversation" in ds_cfg:
+        prompt_style = ds_cfg["conversation"]
+
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaPrompter(prompt_style),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
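
With this change load() also receives the per-dataset config, so an Alpaca-style dataset can be rendered with a named conversation format (the chatml case from the commit title) instead of the default PromptStyle.CHAT. A minimal usage sketch follows; the tokenizer name, the SimpleNamespace cfg fields, and the "chatml" value are illustrative assumptions, not part of the commit:

# Minimal usage sketch (not part of the commit): forward the dataset entry's
# config so the Alpaca strategy uses a named conversation style.
from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.prompt_strategies.alpaca_chat import load

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # any HF tokenizer
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)  # assumed minimal cfg fields

# "conversation" mirrors a key on one entry under `datasets:` in the axolotl
# config; "chatml" is an assumption matching the commit title.
strategy = load(tokenizer, cfg, ds_cfg={"conversation": "chatml"})

# Omitting ds_cfg (or the "conversation" key) keeps the previous behaviour.
default_strategy = load(tokenizer, cfg)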
src/axolotl/utils/trainer.py CHANGED
@@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     )
 
     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__:
+    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
+        cfg.is_mistral_derived_model and cfg.flash_attention
+    ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
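
For a Mistral-derived model trained with flash attention, the packed datasets now lose their attention_mask column, the same treatment Phi's CodeGenTokenizer already received. A small self-contained sketch of the effect on a toy tokenized dataset (the example data is made up; only the remove_columns call mirrors the code above):

# Sketch (toy data, assumed shapes): dropping attention_mask before packing,
# as the new branch does for Mistral + flash attention.
from datasets import Dataset

train_dataset = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3]],
        "labels": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
    }
)

train_dataset = train_dataset.remove_columns("attention_mask")
print(train_dataset.column_names)  # ['input_ids', 'labels']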