winglian committed
Commit 09f1543
1 parent: 1991946

No gather single gpu (#523)


* don't attempt to gather when not running multi-gpu

* also check distributed status in bench callback

src/axolotl/utils/callbacks.py CHANGED
@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
     barrier,
     gather_scalar_from_all_ranks,
     get_world_size,
+    is_distributed,
     is_main_process,
     zero_first,
 )
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 lambda: len(data_loader), get_world_size()
             )
 
-            if not is_main_process():
+            if is_distributed() and not is_main_process():
                 dist.gather_object(local_bench_names, dst=0)
             else:
-                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                if is_distributed():
+                    dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                else:
+                    gathered_bench_names = [local_bench_names]
             bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
             results = {f"{bench_split}_bench_loss": bench_loss}
 
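Note: the new guard uses is_distributed(), which the first hunk imports from axolotl.utils.distributed but whose definition is not part of this diff. A minimal sketch of such a helper, assuming it only needs to report whether a torch.distributed process group is active:

import torch.distributed as dist


def is_distributed():
    # True only when the distributed backend is available and a process group
    # has actually been initialized, i.e. a real multi-process (multi-GPU) run.
    return dist.is_available() and dist.is_initialized()

With that check, dist.gather_object is only reached when a process group exists; on a single GPU the callback simply wraps the local list instead of issuing a collective.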
src/axolotl/utils/distributed.py CHANGED
@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
+    if not is_distributed():
+        return [value_scalar]
     value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
 
     if not is_main_process():
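With the early return in place, gather_scalar_from_all_ranks never touches a torch.distributed collective in a single-process run. A rough usage sketch (hypothetical values; assumes the helpers are importable as in the callback diff above):

from axolotl.utils.distributed import gather_scalar_from_all_ranks, get_world_size

# Single GPU: the new early return wraps the locally computed value in a list
# and no collective is issued.
losses = gather_scalar_from_all_ranks(lambda: 0.42, get_world_size())
# losses == [0.42]

# Multi-GPU (e.g. launched via torchrun): per the docstring above, rank 0 gets a
# list with one value per rank and every other rank returns None.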