yuyan-10b / megatron /dist_signal_handler.py
import signal

import torch


def get_world_size():
    # Size of the default process group, or 1 when torch.distributed is
    # unavailable or uninitialized (e.g. single-process runs).
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    return world_size


def get_device(local_rank=None):
    # Map the active distributed backend to the device its tensors must
    # live on: NCCL communicates CUDA tensors, Gloo communicates CPU tensors.
    backend = torch.distributed.get_backend()
    if backend == 'nccl':
        if local_rank is None:
            device = torch.device('cuda')
        else:
            device = torch.device(f'cuda:{local_rank}')
    elif backend == 'gloo':
        device = torch.device('cpu')
    else:
        raise RuntimeError(f'Unsupported distributed backend: {backend}')
    return device


def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
    # Without an initialized process group there is nothing to gather from,
    # so just echo the local value back.
    if not torch.distributed.is_available() or \
       not torch.distributed.is_initialized():
        return [item]

    device = get_device(local_rank)

    if group is not None:
        group_size = group.size()
    else:
        group_size = get_world_size()

    # Wrap the scalar in a one-element tensor, gather one such tensor from
    # every rank, then unwrap back to plain Python scalars.
    tensor = torch.tensor([item], device=device, dtype=dtype)
    output_tensors = [
        torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
        for _ in range(group_size)
    ]
    torch.distributed.all_gather(output_tensors, tensor, group=group, async_op=async_op)
    output = [elem.item() for elem in output_tensors]
    return output
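
# Illustrative only (not part of the original file): gathering a per-rank
# boolean flag as int32. In a two-rank job where rank 0 passes True and
# rank 1 passes False, every rank gets back [1, 0].
#
#     flags = all_gather_item(True, dtype=torch.int32)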


class DistributedSignalHandler:
    """Context manager that latches a POSIX signal on each rank and lets
    callers poll which ranks have seen it."""

    def __init__(self, sig=signal.SIGTERM):
        self.sig = sig

    def signals_received(self):
        # Gather the local flag from all ranks; only valid inside the
        # `with` block, since __enter__ sets self._signal_received.
        all_received = all_gather_item(
            self._signal_received, dtype=torch.int32
        )
        return all_received

    def __enter__(self):
        self._signal_received = False
        self.released = False
        # Remember the original handler so release() can restore it.
        self.original_handler = signal.getsignal(self.sig)

        def handler(signum, frame):
            # Only record the signal; callers decide when to act on it.
            self._signal_received = True

        signal.signal(self.sig, handler)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def release(self):
        # Idempotent: restore the original handler exactly once.
        if self.released:
            return False

        signal.signal(self.sig, self.original_handler)
        self.released = True
        return True
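
# --- Usage sketch (not from the original file) ---
# A minimal illustration of wiring the handler into a training loop;
# `train_step` and `save_checkpoint_and_exit` are hypothetical placeholders.
#
#     with DistributedSignalHandler(signal.SIGTERM) as handler:
#         for batch in batches:
#             train_step(batch)
#             # Act only once the gathered flags show a signal somewhere,
#             # so every rank takes the shutdown path together.
#             if any(handler.signals_received()):
#                 save_checkpoint_and_exit()
#                 break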