winglian commited on
Commit
e2e68c3
1 Parent(s): a27d594

testing mpt triton

Browse files
src/axolotl/utils/models.py CHANGED
@@ -8,7 +8,7 @@ import transformers
8
  from transformers import (
9
  AutoModelForCausalLM,
10
  AutoTokenizer,
11
- PreTrainedModel,
12
  )
13
  try:
14
  from transformers import (
@@ -116,8 +116,14 @@ def load_model(
116
  trust_remote_code=True if cfg.trust_remote_code is True else False,
117
  )
118
  else:
 
 
 
 
 
119
  model = AutoModelForCausalLM.from_pretrained(
120
  base_model,
 
121
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
122
  torch_dtype=torch_dtype,
123
  device_map=cfg.device_map,
 
8
  from transformers import (
9
  AutoModelForCausalLM,
10
  AutoTokenizer,
11
+ PreTrainedModel, AutoConfig,
12
  )
13
  try:
14
  from transformers import (
 
116
  trust_remote_code=True if cfg.trust_remote_code is True else False,
117
  )
118
  else:
119
+ config = AutoConfig.from_pretrained(
120
+ base_model,
121
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
122
+ )
123
+ config.attn_config['attn_impl'] = 'triton'
124
  model = AutoModelForCausalLM.from_pretrained(
125
  base_model,
126
+ config=config,
127
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
128
  torch_dtype=torch_dtype,
129
  device_map=cfg.device_map,
src/axolotl/utils/wandb.py CHANGED
@@ -2,7 +2,9 @@ import os
2
 
3
 
4
  def setup_wandb_env_vars(cfg):
5
- if cfg.wandb_project and len(cfg.wandb_project) > 0:
 
 
6
  os.environ["WANDB_PROJECT"] = cfg.wandb_project
7
  cfg.use_wandb = True
8
  if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
 
2
 
3
 
4
  def setup_wandb_env_vars(cfg):
5
+ if cfg.wandb_mode and cfg.wandb_mode == "offline":
6
+ os.environ["WANDB_MODE"] = cfg.wandb_mode
7
+ elif cfg.wandb_project and len(cfg.wandb_project) > 0:
8
  os.environ["WANDB_PROJECT"] = cfg.wandb_project
9
  cfg.use_wandb = True
10
  if cfg.wandb_watch and len(cfg.wandb_watch) > 0: