winglian committed
Commit: 4c834bf
Parent: 8056ecd

cleanup verbosity a bit

src/axolotl/train.py CHANGED
@@ -18,6 +18,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
 from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import zero_only
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -44,7 +45,10 @@ def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ):
     # load the tokenizer first
-    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+    with zero_only():
+        LOG.debug(
+            f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}"
+        )
     tokenizer = load_tokenizer(cfg)
 
     train_dataset = dataset_meta.train_dataset
@@ -52,7 +56,10 @@ def train(
     total_num_steps = dataset_meta.total_num_steps
 
     # Load the model and tokenizer
-    LOG.info("loading model and (optionally) peft_config...")
+    msg = "loading model"
+    if cfg.adapter:
+        msg += " and peft_config..."
+    LOG.debug(msg)
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
 
     safe_serialization = cfg.save_safetensors is True
src/axolotl/utils/distributed.py CHANGED
@@ -50,6 +50,17 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
+@contextmanager
+def zero_only():
+    """
+    Context manager that only runs the enclosed block on the main rank.
+    """
+    if is_main_process():
+        yield
+    else:
+        yield None
+
+
 @contextmanager
 def zero_first(is_main):
     """
src/axolotl/utils/trainer.py CHANGED
@@ -21,6 +21,7 @@ from axolotl.utils.distributed import (
     is_main_process,
     reduce_and_broadcast,
     zero_first,
+    zero_only,
 )
 
 LOG = logging.getLogger("axolotl")
@@ -153,14 +154,14 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
        # we have to drop anything longer then sequence len otherwise
        # flash attention with position ids fails
        if not cfg.total_num_tokens:
-            LOG.info("calculating total_num_tokens")
            total_num_tokens = np.sum(
                train_dataset.data.column("input_ids")
                .to_pandas()
                .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
                .values
            )
-            LOG.info(f"total_num_tokens: {total_num_tokens}")
+            with zero_only():
+                LOG.debug(f"total_num_tokens: {total_num_tokens}")
            cfg.total_num_tokens = total_num_tokens
 
        if not cfg.total_supervised_tokens:
@@ -170,7 +171,8 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                .apply(lambda x: np.sum(np.array(x) != -100))
                .sum()
            )
-            LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
+            with zero_only():
+                LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens}`")
            cfg.total_supervised_tokens = total_supervised_tokens
 
        if cfg.sample_packing_eff_est:
@@ -189,9 +191,10 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                )
                * cfg.num_epochs
            )
-            LOG.info(
-                f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
-            )
+            with zero_only():
+                LOG.debug(
+                    f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
+                )
        else:
            if cfg.world_size > 1 and is_distributed():
                sampler = DistributedSampler(
@@ -220,7 +223,8 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
            )
            data_loader_len = data_loader.len_w_stats()
            actual_eff = data_loader.efficiency()
-            LOG.info(f"data_loader_len: {data_loader_len}")
+            with zero_only():
+                LOG.debug(f"data_loader_len: {data_loader_len}")
            # FIXME: is there a bug here somewhere? the total num steps depends
            # on the agreed on value for sample_packing_eff_est
            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
@@ -237,12 +241,14 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
            )
            cfg.sample_packing_eff_est = sample_packing_eff_est
-            LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
+            with zero_only():
+                LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
    else:
        total_num_steps = int(
            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
        )
-        LOG.info(f"total_num_steps: {total_num_steps}")
+        with zero_only():
+            LOG.debug(f"total_num_steps: {total_num_steps}")
    return total_num_steps
 
 
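
Since the messages touched here move from info to debug, they stay hidden whenever the axolotl logger is configured at INFO or above. A minimal sketch of surfacing them again with the standard library logging module (not axolotl's own configure_logging):

import logging

# route log records to stderr with a simple format
logging.basicConfig(format="[%(levelname)s] [%(name)s] %(message)s")

# allow DEBUG records from the "axolotl" logger through
logging.getLogger("axolotl").setLevel(logging.DEBUG)

# this now prints; at the default level it would be suppressed
logging.getLogger("axolotl").debug("total_num_steps: %d", 1234)
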