# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model and data parallel groups."""

import torch

from .utils import ensure_divisibility


# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks.
_DATA_PARALLEL_GLOBAL_RANKS = None


def is_unitialized():
    """Useful for code segments that may be accessed with or without mpu initialization."""
    return _DATA_PARALLEL_GROUP is None


def initialize_model_parallel(tensor_model_parallel_size_=1,
                              pipeline_model_parallel_size_=1,
                              virtual_pipeline_model_parallel_size_=None,
                              pipeline_model_parallel_split_rank_=None):
    """Initialize model and data parallel groups.

    Arguments:
        tensor_model_parallel_size_: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size_: number of GPUs used for pipeline model parallelism.
        virtual_pipeline_model_parallel_size_: number of virtual stages (interleaved pipeline).
        pipeline_model_parallel_split_rank_: for models with both an encoder and a decoder,
            rank in the pipeline at which the split occurs.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will create 8 tensor
    model-parallel groups, 4 pipeline model-parallel groups and 8
    data-parallel groups as:
        8 data-parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    if torch.distributed.get_rank() == 0:
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size_))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size_))

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    ensure_divisibility(world_size,
                        tensor_model_parallel_size * pipeline_model_parallel_size)
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size

    if virtual_pipeline_model_parallel_size_ is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

    if pipeline_model_parallel_split_rank_ is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GLOBAL_RANKS
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank,
                          tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group
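
    # Illustration only (derived from the 16-GPU example in the docstring, with
    # tensor_model_parallel_size=2 and pipeline_model_parallel_size=4): for
    # pipeline block i=0, start_rank=0 and end_rank=4, so j=0 yields the
    # data-parallel group range(0, 4, 2) -> [g0, g2] and j=1 yields
    # range(1, 4, 2) -> [g1, g3]. The model-parallel groups then take the i-th
    # entry of every data-parallel group, e.g. i=0 -> [g0, g1, g4, g5, g8, g9, g12, g13].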

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size,
                      num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank_ is not None:
                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                    embedding_ranks = [ranks[0],
                                       ranks[pipeline_model_parallel_split_rank_],
                                       ranks[-1]]
                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0],
                                                ranks[pipeline_model_parallel_split_rank_]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
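

# A minimal usage sketch (illustrative only; assumes torch.distributed has
# already been initialized with init_process_group on a 16-process job, and
# "some_tensor" stands for any local tensor):
#
#     initialize_model_parallel(tensor_model_parallel_size_=2,
#                               pipeline_model_parallel_size_=4)
#     tp_group = get_tensor_model_parallel_group()    # 2 ranks per group
#     pp_group = get_pipeline_model_parallel_group()  # 4 ranks per group
#     dp_group = get_data_parallel_group()            # 2 ranks per group
#     torch.distributed.all_reduce(some_tensor, group=tp_group)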


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
        _DATA_PARALLEL_GROUP is None:
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        'intra_layer_model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
        'pipeline_model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_embedding_group():
    """Get the embedding group the caller rank belongs to."""
    assert _EMBEDDING_GROUP is not None, \
        'embedding group is not initialized'
    return _EMBEDDING_GROUP


def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to."""
    assert _POSITION_EMBEDDING_GROUP is not None, \
        'position embedding group is not initialized'
    return _POSITION_EMBEDDING_GROUP


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())


def set_tensor_model_parallel_rank(rank):
    """Set tensor model parallel rank."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Set pipeline model parallel rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
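

# Illustration (not executed): the set_* helpers above override what the
# corresponding get_* helpers report, without touching torch.distributed.
# For example:
#
#     set_tensor_model_parallel_world_size(8)
#     get_tensor_model_parallel_world_size()  # -> 8, even before any group exists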


def get_num_layers(args, is_encoder_and_decoder_model):
    """Compute the number of transformer layers resident on the current rank."""
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None

            # When a standalone embedding stage is used, a rank is taken from
            # the encoder's ranks, to be used for the encoder's embedding
            # layer. This way, the rank referenced by the 'split rank' remains
            # the same whether or not a standalone embedding stage is used.
            num_ranks_in_encoder = (
                args.pipeline_model_parallel_split_rank - 1
                if args.standalone_embedding_stage else
                args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            assert args.num_layers % num_ranks_in_encoder == 0, \
                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
            assert args.num_layers % num_ranks_in_decoder == 0, \
                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
            if is_pipeline_stage_before_split():
                num_layers = (
                    0
                    if args.standalone_embedding_stage
                    and get_pipeline_model_parallel_rank() == 0 else
                    args.num_layers // num_ranks_in_encoder
                )
            else:
                num_layers = args.num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by transformer_pipeline_model_parallel_size'

            # When a standalone embedding stage is used, all transformer layers
            # are divided among pipeline rank >= 1, while on pipeline rank 0,
            # ranks either contain the input embedding layer (virtual pp rank 0),
            # or no layers at all (virtual pp rank >= 1).
            num_layers = (
                0
                if args.standalone_embedding_stage
                and get_pipeline_model_parallel_rank() == 0 else
                args.num_layers // args.transformer_pipeline_model_parallel_size
            )
    else:
        num_layers = args.num_layers
    return num_layers
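

# Worked example (illustrative only; the args values are assumed): for an
# encoder-decoder model with args.num_layers = 24,
# args.transformer_pipeline_model_parallel_size = 4,
# args.pipeline_model_parallel_split_rank = 2 and
# args.standalone_embedding_stage = False, get_num_layers() computes
# num_ranks_in_encoder = 2 and num_ranks_in_decoder = 2, so stages before the
# split hold 24 // 2 = 12 layers each and stages after the split hold
# 24 // 2 = 12 layers each. If instead args.standalone_embedding_stage = True
# and args.transformer_pipeline_model_parallel_size = 3, then
# num_ranks_in_encoder = 1: pipeline rank 0 holds 0 transformer layers, the
# remaining encoder stage holds all 24, and each decoder stage holds 12.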


def get_num_layers_decoder(args, is_encoder_and_decoder_model):
    """Compute the number of transformer layers (based on args.num_layers_decoder)
    resident on the current rank."""
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None

            # When a standalone embedding stage is used, a rank is taken from
            # the encoder's ranks, to be used for the encoder's embedding
            # layer. This way, the rank referenced by the 'split rank' remains
            # the same whether or not a standalone embedding stage is used.
            num_ranks_in_encoder = (
                args.pipeline_model_parallel_split_rank - 1
                if args.standalone_embedding_stage else
                args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            assert args.num_layers_decoder % num_ranks_in_encoder == 0, \
                'num_layers_decoder (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers_decoder, num_ranks_in_encoder)
            assert args.num_layers_decoder % num_ranks_in_decoder == 0, \
                'num_layers_decoder (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers_decoder, num_ranks_in_decoder)
            if is_pipeline_stage_before_split():
                num_layers = (
                    0
                    if args.standalone_embedding_stage
                    and get_pipeline_model_parallel_rank() == 0 else
                    args.num_layers_decoder // num_ranks_in_encoder
                )
            else:
                num_layers = args.num_layers_decoder // num_ranks_in_decoder
        else:
            assert args.num_layers_decoder % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers_decoder must be divisible by transformer_pipeline_model_parallel_size'

            # When a standalone embedding stage is used, all transformer layers
            # are divided among pipeline rank >= 1, while on pipeline rank 0,
            # ranks either contain the input embedding layer (virtual pp rank 0),
            # or no layers at all (virtual pp rank >= 1).
            num_layers = (
                0
                if args.standalone_embedding_stage
                and get_pipeline_model_parallel_rank() == 0 else
                args.num_layers_decoder // args.transformer_pipeline_model_parallel_size
            )
    else:
        num_layers = args.num_layers_decoder
    return num_layers


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
            get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
            get_virtual_pipeline_model_parallel_rank() != (
                virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)


def is_rank_in_embedding_group(ignore_virtual=False):
    """Return True if current rank is in the embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
            return is_pipeline_first_stage(ignore_virtual=False)
        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
            return is_pipeline_last_stage(ignore_virtual=False)
        else:
            return True
    return False


def is_rank_in_position_embedding_group():
    """Return True if current rank is in the position embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS


def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_at_split():
    """Return True if pipeline stage executes the encoder block and the next
    stage executes the decoder block for a model with both encoder and decoder."""
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and \
        is_pipeline_stage_after_split(rank + 1)


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
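

# Illustration (not executed): with a 4-stage pipeline and
# _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = 2, pipeline ranks 0 and 1 satisfy
# is_pipeline_stage_before_split() (encoder side), ranks 2 and 3 satisfy
# is_pipeline_stage_after_split() (decoder side), and only rank 1 satisfies
# is_pipeline_stage_at_split(), since its successor (rank 2) is the first
# decoder stage.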


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size


def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group."""
    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
        "Data parallel group is not initialized"
    return _DATA_PARALLEL_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def destroy_model_parallel():
    """Set the groups to None."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
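

# Illustration (not executed): with the 16-GPU example from
# initialize_model_parallel's docstring (tensor parallel size 2, pipeline
# parallel size 4), global rank g5 belongs to the pipeline group
# [g1, g5, g9, g13], so:
#
#     get_tensor_model_parallel_src_rank()      # -> 4, i.e. (5 // 2) * 2
#     get_pipeline_model_parallel_first_rank()  # -> 1
#     get_pipeline_model_parallel_next_rank()   # -> 9
#     get_pipeline_model_parallel_prev_rank()   # -> 1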