diff --git a/.gitattributes b/.gitattributes index 890040e1f717fb4d46cd3ed141bec5e4b48a4f63..62b3222bf8a6723f48f99da383503d6abaf5e868 100644 --- a/.gitattributes +++ b/.gitattributes @@ -36,3 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text clue_data/csl/train.json filter=lfs diff=lfs merge=lfs -text clue_data/iflytek/train.json filter=lfs diff=lfs merge=lfs -text clue_data/ocnli/train.json filter=lfs diff=lfs merge=lfs -text +megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text +megatron/fused_kernels/build/scaled_masked_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text +megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text +megatron/fused_kernels/build/scaled_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text +megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text diff --git a/megatron/__init__.py b/megatron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e195f969e34c2c95fd712b0fef9b54328f740f85 --- /dev/null +++ b/megatron/__init__.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from .global_vars import get_args +from .global_vars import get_current_global_batch_size +from .global_vars import get_num_microbatches +from .global_vars import get_signal_handler +from .global_vars import update_num_microbatches +from .global_vars import get_tokenizer +from .global_vars import get_tensorboard_writer +from .global_vars import get_adlr_autoresume +from .global_vars import get_timers +from .global_vars import get_global_memory_buffer +from .initialize import initialize_megatron + +from .utils import (print_rank_0, + is_last_rank, + print_rank_last) diff --git a/megatron/arguments.py b/megatron/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..262e26548b1ca189837e12187191d08cd6814e93 --- /dev/null +++ b/megatron/arguments.py @@ -0,0 +1,1018 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron arguments.""" + +import argparse +import os + +import torch + +def parse_args(extra_args_provider=None, ignore_unknown_args=False): + """Parse all arguments.""" + parser = argparse.ArgumentParser(description='Megatron-LM Arguments', + allow_abbrev=False) + + # Standard arguments. + parser = _add_network_size_args(parser) + parser = _add_regularization_args(parser) + parser = _add_training_args(parser) + parser = _add_initialization_args(parser) + parser = _add_learning_rate_args(parser) + parser = _add_checkpointing_args(parser) + parser = _add_mixed_precision_args(parser) + parser = _add_distributed_args(parser) + parser = _add_validation_args(parser) + parser = _add_data_args(parser) + parser = _add_autoresume_args(parser) + parser = _add_biencoder_args(parser) + parser = _add_vision_args(parser) + parser = _add_logging_args(parser) + parser = _add_inference_args(parser) + + # Custom arguments. + if extra_args_provider is not None: + parser = extra_args_provider(parser) + + # Parse. + if ignore_unknown_args: + args, _ = parser.parse_known_args() + else: + args = parser.parse_args() + + # Args from environment + args.rank = int(os.getenv('RANK', '0')) + args.world_size = int(os.getenv("WORLD_SIZE", '1')) + + return args + +def validate_args(args, defaults={}): + # Tensor model parallel size. + args.tensor_model_parallel_size = min( + args.tensor_model_parallel_size, args.world_size) + assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ + ' ({}) is not divisible by tensor model parallel size ({})'.format( + args.world_size, args.tensor_model_parallel_size) + # Pipeline model parallel size. + args.pipeline_model_parallel_size = min( + args.pipeline_model_parallel_size, + (args.world_size // args.tensor_model_parallel_size)) + args.transformer_pipeline_model_parallel_size = ( + args.pipeline_model_parallel_size - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_size + ) + # Checks. + model_parallel_size = args.pipeline_model_parallel_size * \ + args.tensor_model_parallel_size + assert args.world_size % model_parallel_size == 0, 'world size is not'\ + ' divisible by tensor parallel size ({}) times pipeline parallel ' \ + 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, + args.pipeline_model_parallel_size) + args.data_parallel_size = args.world_size // model_parallel_size + if args.rank == 0: + print('using world size: {}, data-parallel-size: {}, ' + 'tensor-model-parallel size: {}, ' + 'pipeline-model-parallel size: {} '.format( + args.world_size, args.data_parallel_size, + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size), flush=True) + if args.pipeline_model_parallel_size > 1: + if args.pipeline_model_parallel_split_rank is not None: + assert args.pipeline_model_parallel_split_rank < \ + args.pipeline_model_parallel_size, 'split rank needs'\ + ' to be less than pipeline model parallel size ({})'.format( + args.pipeline_model_parallel_size) + if args.data_path: + # Dataset arguments + data_path = args.data_path + processed_data_path = [] + for path in data_path: + files = os.listdir(path) + idx_files = [fn[:-4] for fn in files if fn.endswith('.idx')] + bin_files = [fn[:-4] for fn in files if fn.endswith('.bin')] + for idx_fn in idx_files: + if idx_fn in bin_files: + # add weight and data path + processed_data_path.append('1') + processed_data_path.append(os.path.join(path, idx_fn)) + args.raw_data_path = data_path + args.data_path = processed_data_path + + + # Deprecated arguments + assert args.batch_size is None, '--batch-size argument is no longer ' \ + 'valid, use --micro-batch-size instead' + del args.batch_size + assert args.warmup is None, '--warmup argument is no longer valid, use ' \ + '--lr-warmup-fraction instead' + del args.warmup + assert args.model_parallel_size is None, '--model-parallel-size is no ' \ + 'longer valid, use --tensor-model-parallel-size instead' + del args.model_parallel_size + + if args.checkpoint_activations: + args.recompute_granularity = 'full' + args.recompute_method = 'uniform' + if args.rank == 0: + print('--checkpoint-activations is no longer valid, ' + 'use --recompute-granularity and --recompute-method instead. ' + 'Defaulting to recompute-granularity=full and recompute-method=uniform.') + del args.checkpoint_activations + + if args.recompute_activations: + args.recompute_granularity = 'selective' + del args.recompute_activations + + # Set input defaults. + for key in defaults: + # For default to be valid, it should not be provided in the + # arguments that are passed to the program. We check this by + # ensuring the arg is set to None. + if getattr(args, key) is not None: + if args.rank == 0: + print('WARNING: overriding default arguments for {key}:{v} \ + with {key}:{v2}'.format(key=key, v=defaults[key], + v2=getattr(args, key)), + flush=True) + else: + setattr(args, key, defaults[key]) + + # Batch size. + assert args.micro_batch_size is not None + assert args.micro_batch_size > 0 + if args.global_batch_size is None: + args.global_batch_size = args.micro_batch_size * args.data_parallel_size + if args.rank == 0: + print('setting global batch size to {}'.format( + args.global_batch_size), flush=True) + assert args.global_batch_size > 0 + if args.num_layers_per_virtual_pipeline_stage is not None: + assert args.pipeline_model_parallel_size > 2, \ + 'pipeline-model-parallel size should be greater than 2 with ' \ + 'interleaved schedule' + assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \ + 'number of layers is not divisible by number of layers per virtual ' \ + 'pipeline stage' + args.virtual_pipeline_model_parallel_size = \ + (args.num_layers // args.transformer_pipeline_model_parallel_size) // \ + args.num_layers_per_virtual_pipeline_stage + else: + args.virtual_pipeline_model_parallel_size = None + + # Parameters dtype. + args.params_dtype = torch.float + if args.fp16: + assert not args.bf16 + args.params_dtype = torch.half + if args.bf16: + assert not args.fp16 + args.params_dtype = torch.bfloat16 + # bfloat16 requires gradient accumulation and all-reduce to + # be done in fp32. + if not args.accumulate_allreduce_grads_in_fp32: + args.accumulate_allreduce_grads_in_fp32 = True + if args.rank == 0: + print('accumulate and all-reduce gradients in fp32 for ' + 'bfloat16 data type.', flush=True) + + if args.rank == 0: + print('using {} for parameters ...'.format(args.params_dtype), + flush=True) + + # If we do accumulation and all-reduces in fp32, we need to have local DDP + # and we should make sure use-contiguous-buffers-in-local-ddp is not off. + if args.accumulate_allreduce_grads_in_fp32: + assert args.DDP_impl == 'local' + assert args.use_contiguous_buffers_in_local_ddp + else: + if args.gradient_accumulation_fusion: + args.gradient_accumulation_fusion = False + if args.rank == 0: + print('Gradient accumulation fusion to linear layer weight ' + 'gradient computation is supported only with fp32 ' + 'gradient accumulation. Setting gradient_accumulation_fusion ' + 'to False', flush=True) + + # If we use the distributed optimizer, we need to have local DDP + # and we should make sure use-contiguous-buffers-in-local-ddp is on. + if args.use_distributed_optimizer: + assert args.DDP_impl == 'local' + assert args.use_contiguous_buffers_in_local_ddp + + # For torch DDP, we do not use contiguous buffer + if args.DDP_impl == 'torch': + args.use_contiguous_buffers_in_local_ddp = False + + if args.dataloader_type is None: + args.dataloader_type = 'single' + + # Consumed tokens. + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + + # Iteration-based training. + if args.train_iters: + # If we use iteration-based training, make sure the + # sample-based options are off. + assert args.train_samples is None, \ + 'expected iteration-based training' + assert args.lr_decay_samples is None, \ + 'expected iteration-based learning rate decay' + assert args.lr_warmup_samples == 0, \ + 'expected iteration-based learning rate warmup' + assert args.rampup_batch_size is None, \ + 'expected no batch-size rampup for iteration-based training' + if args.lr_warmup_fraction is not None: + assert args.lr_warmup_iters == 0, \ + 'can only specify one of lr-warmup-fraction and lr-warmup-iters' + + # Sample-based training. + if args.train_samples: + # If we use sample-based training, make sure the + # iteration-based options are off. + assert args.train_iters is None, \ + 'expected sample-based training' + assert args.lr_decay_iters is None, \ + 'expected sample-based learning rate decay' + assert args.lr_warmup_iters == 0, \ + 'expected sample-based learnig rate warmup' + if args.lr_warmup_fraction is not None: + assert args.lr_warmup_samples == 0, \ + 'can only specify one of lr-warmup-fraction ' \ + 'and lr-warmup-samples' + + # Check required arguments. + required_args = ['num_layers', 'hidden_size', 'num_attention_heads', + 'max_position_embeddings'] + for req_arg in required_args: + _check_arg_is_not_none(args, req_arg) + + # Checks. + if args.ffn_hidden_size is None: + args.ffn_hidden_size = 4 * args.hidden_size + + if args.kv_channels is None: + assert args.hidden_size % args.num_attention_heads == 0 + args.kv_channels = args.hidden_size // args.num_attention_heads + + if args.seq_length is not None: + assert args.encoder_seq_length is None + args.encoder_seq_length = args.seq_length + else: + assert args.encoder_seq_length is not None + args.seq_length = args.encoder_seq_length + + if args.seq_length is not None: + assert args.max_position_embeddings >= args.seq_length + if args.decoder_seq_length is not None: + assert args.max_position_embeddings >= args.decoder_seq_length + if args.lr is not None: + assert args.min_lr <= args.lr + if args.save is not None: + assert args.save_interval is not None + # Mixed precision checks. + if args.fp16_lm_cross_entropy: + assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' + if args.fp32_residual_connection: + assert args.fp16 or args.bf16, \ + 'residual connection in fp32 only supported when using fp16 or bf16.' + + if args.weight_decay_incr_style == 'constant': + assert args.start_weight_decay is None + assert args.end_weight_decay is None + args.start_weight_decay = args.weight_decay + args.end_weight_decay = args.weight_decay + else: + assert args.start_weight_decay is not None + assert args.end_weight_decay is not None + + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + # Persistent fused layer norm. + if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11): + args.no_persist_layer_norm = True + if args.rank == 0: + print('Persistent fused layer norm kernel is supported from ' + 'pytorch v1.11 (nvidia pytorch container paired with v1.11). ' + 'Defaulting to no_persist_layer_norm=True') + + # Activation recomputing. + if args.distribute_saved_activations: + assert args.tensor_model_parallel_size > 1, 'can distribute ' \ + 'recomputed activations only across tensor model ' \ + 'parallel groups' + assert args.recompute_granularity == 'full', \ + 'distributed recompute activations is only '\ + 'application to full recompute granularity' + assert args.recompute_method is not None, \ + 'for distributed recompute activations to work you '\ + 'need to use a recompute method ' + assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \ + 'distributed recompute activations are supported for pytorch ' \ + 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \ + 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR) + + if args.recompute_granularity == 'selective': + assert args.recompute_method is None, \ + 'recompute method is not yet supported for ' \ + 'selective recomputing granularity' + + # disable sequence parallelism when tp=1 + # to avoid change in numerics when + # sequence_parallelism is enabled. + if args.tensor_model_parallel_size == 1: + args.sequence_parallel = False + + # disable async_tensor_model_parallel_allreduce when + # model parallel memory optimization is enabled + if args.sequence_parallel: + args.async_tensor_model_parallel_allreduce = False + + _print_args(args) + return args + + +def _print_args(args): + """Print arguments.""" + if args.rank == 0: + print('------------------------ arguments ------------------------', + flush=True) + str_list = [] + for arg in vars(args): + dots = '.' * (48 - len(arg)) + str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) + for arg in sorted(str_list, key=lambda x: x.lower()): + print(arg, flush=True) + print('-------------------- end of arguments ---------------------', + flush=True) + + +def _check_arg_is_not_none(args, arg): + assert getattr(args, arg) is not None, '{} argument is None'.format(arg) + + +def _add_inference_args(parser): + group = parser.add_argument_group(title='inference') + + group.add_argument('--inference-batch-times-seqlen-threshold', + type=int, default=512, + help='During inference, if batch-size times ' + 'sequence-length is smaller than this threshold ' + 'then we will not use pipelining, otherwise we will.') + + return parser + + +def _add_network_size_args(parser): + group = parser.add_argument_group(title='network size') + + group.add_argument('--num-layers', type=int, default=None, + help='Number of transformer layers.') + group.add_argument('--num-layers-decoder', type=int, default=None, + help='Number of transformer layers decoder.') + group.add_argument('--hidden-size', type=int, default=None, + help='Tansformer hidden size.') + group.add_argument('--ffn-hidden-size', type=int, default=None, + help='Transformer Feed-Forward Network hidden size. ' + 'This is set to 4*hidden-size if not provided') + group.add_argument('--num-attention-heads', type=int, default=None, + help='Number of transformer attention heads.') + group.add_argument('--kv-channels', type=int, default=None, + help='Projection weights dimension in multi-head ' + 'attention. This is set to ' + ' args.hidden_size // args.num_attention_heads ' + 'if not provided.') + group.add_argument('--max-position-embeddings', type=int, default=None, + help='Maximum number of position embeddings to use. ' + 'This is the size of position embedding.') + group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, + help='Pad the vocab size to be divisible by this value.' + 'This is added for computational efficieny reasons.') + group.add_argument('--layernorm-epsilon', type=float, default=1e-5, + help='Layer norm epsilon.') + group.add_argument('--apply-residual-connection-post-layernorm', + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. This option' + 'should not be used unless for backward compatibility' + 'reasons.') + group.add_argument('--onnx-safe', type=bool, required=False, + help='Use workarounds for known problems with ' + 'Torch ONNX exporter') + group.add_argument('--bert-no-binary-head', action='store_false', + help='Disable BERT binary head.', + dest='bert_binary_head') + group.add_argument('--num-experts', type=int, default=None, + help='Number of Experts in Switch Transformer (None means no Switch)') + return parser + + +def _add_logging_args(parser): + group = parser.add_argument_group(title='logging') + + group.add_argument('--log-params-norm', action='store_true', + help='If set, calculate and log parameters norm.') + group.add_argument('--log-num-zeros-in-grad', action='store_true', + help='If set, calculate and log the number of zeros in gradient.') + group.add_argument('--tensorboard-log-interval', type=int, default=1, + help='Report to tensorboard interval.') + group.add_argument('--tensorboard-queue-size', type=int, default=1000, + help='Size of the tensorboard queue for pending events ' + 'and summaries before one of the ‘add’ calls forces a ' + 'flush to disk.') + group.add_argument('--log-timers-to-tensorboard', action='store_true', + help='If set, write timers to tensorboard.') + group.add_argument('--log-batch-size-to-tensorboard', action='store_true', + help='If set, write batch-size to tensorboard.') + group.add_argument('--no-log-learnig-rate-to-tensorboard', + action='store_false', + help='Disable learning rate logging to tensorboard.', + dest='log_learning_rate_to_tensorboard') + group.add_argument('--no-log-loss-scale-to-tensorboard', + action='store_false', + help='Disable loss-scale logging to tensorboard.', + dest='log_loss_scale_to_tensorboard') + group.add_argument('--log-validation-ppl-to-tensorboard', + action='store_true', + help='If set, write validation perplexity to ' + 'tensorboard.') + group.add_argument('--log-memory-to-tensorboard', + action='store_true', + help='Enable memory logging to tensorboard.') + group.add_argument('--log-world-size-to-tensorboard', + action='store_true', + help='Enable world size logging to tensorboard.') + + return parser + + +def _add_regularization_args(parser): + group = parser.add_argument_group(title='regularization') + + group.add_argument('--attention-dropout', type=float, default=0.1, + help='Post attention dropout probability.') + group.add_argument('--hidden-dropout', type=float, default=0.1, + help='Dropout probability for hidden state transformer.') + group.add_argument('--weight-decay', type=float, default=0.01, + help='Weight decay coefficient for L2 regularization.') + group.add_argument('--start-weight-decay', type=float, + help='Initial weight decay coefficient for L2 regularization.') + group.add_argument('--end-weight-decay', type=float, + help='End of run weight decay coefficient for L2 regularization.') + group.add_argument('--weight-decay-incr-style', type=str, default='constant', + choices=['constant', 'linear', 'cosine'], + help='Weight decay increment function.') + group.add_argument('--clip-grad', type=float, default=1.0, + help='Gradient clipping based on global L2 norm.') + group.add_argument('--adam-beta1', type=float, default=0.9, + help='First coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--adam-beta2', type=float, default=0.999, + help='Second coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--adam-eps', type=float, default=1e-08, + help='Term added to the denominator to improve' + 'numerical stability') + group.add_argument('--sgd-momentum', type=float, default=0.9, + help='Momentum factor for sgd') + + return parser + + +def _add_training_args(parser): + group = parser.add_argument_group(title='training') + + group.add_argument('--micro-batch-size', type=int, default=None, + help='Batch size per model instance (local batch size). ' + 'Global batch size is local batch size times data ' + 'parallel size times number of micro batches.') + group.add_argument('--batch-size', type=int, default=None, + help='Old batch size parameter, do not use. ' + 'Use --micro-batch-size instead') + group.add_argument('--global-batch-size', type=int, default=None, + help='Training batch size. If set, it should be a ' + 'multiple of micro-batch-size times data-parallel-size. ' + 'If this value is None, then ' + 'use micro-batch-size * data-parallel-size as the ' + 'global batch size. This choice will result in 1 for ' + 'number of micro-batches.') + group.add_argument('--rampup-batch-size', nargs='*', default=None, + help='Batch size ramp up with the following values:' + ' --rampup-batch-size ' + ' ' + ' ' + 'For example:' + ' --rampup-batch-size 16 8 300000 \ ' + ' --global-batch-size 1024' + 'will start with global batch size 16 and over ' + ' (1024 - 16) / 8 = 126 intervals will increase' + 'the batch size linearly to 1024. In each interval' + 'we will use approximately 300000 / 126 = 2380 samples.') + group.add_argument('--recompute-activations', action='store_true', + help='recompute activation to allow for training ' + 'with larger models, sequences, and batch sizes.') + group.add_argument('--recompute-granularity', type=str, default=None, + choices=['full', 'selective'], + help='Checkpoint activations to allow for training ' + 'with larger models, sequences, and batch sizes. ' + 'It is supported at two granularities 1) full: ' + 'whole transformer layer is recomputed, ' + '2) selective: core attention part of the transformer ' + 'layer is recomputed.') + group.add_argument('--distribute-saved-activations', + action='store_true', + help='If set, distribute recomputed activations ' + 'across model parallel group.') + group.add_argument('--recompute-method', type=str, default=None, + choices=['uniform', 'block'], + help='1) uniform: uniformly divide the total number of ' + 'Transformer layers and recompute the input activation of ' + 'each divided chunk at specified granularity, ' + '2) recompute the input activations of only a set number of ' + 'individual Transformer layers per pipeline stage and do the ' + 'rest without any recomputing at specified granularity' + 'default) do not apply activations recompute to any layers') + group.add_argument('--recompute-num-layers', type=int, default=1, + help='1) uniform: the number of Transformer layers in each ' + 'uniformly divided recompute unit, ' + '2) block: the number of individual Transformer layers ' + 'to recompute within each pipeline stage.') + + # deprecated + group.add_argument('--checkpoint-activations', action='store_true', + help='Checkpoint activation to allow for training ' + 'with larger models, sequences, and batch sizes.') + group.add_argument('--train-iters', type=int, default=None, + help='Total number of iterations to train over all ' + 'training runs. Note that either train-iters or ' + 'train-samples should be provided.') + group.add_argument('--train-samples', type=int, default=None, + help='Total number of samples to train over all ' + 'training runs. Note that either train-iters or ' + 'train-samples should be provided.') + group.add_argument('--log-interval', type=int, default=100, + help='Report loss and timing interval.') + group.add_argument('--exit-interval', type=int, default=None, + help='Exit the program after the iteration is divisible ' + 'by this value.') + group.add_argument('--exit-duration-in-mins', type=int, default=None, + help='Exit the program after this many minutes.') + group.add_argument('--exit-signal-handler', action='store_true', + help='Dynamically save the checkpoint and shutdown the ' + 'training if SIGTERM is received') + group.add_argument('--tensorboard-dir', type=str, default=None, + help='Write TensorBoard logs to this directory.') + group.add_argument('--no-masked-softmax-fusion', + action='store_false', + help='Disable fusion of query_key_value scaling, ' + 'masking, and softmax.', + dest='masked_softmax_fusion') + group.add_argument('--no-bias-gelu-fusion', action='store_false', + help='Disable bias and gelu fusion.', + dest='bias_gelu_fusion') + group.add_argument('--no-bias-dropout-fusion', action='store_false', + help='Disable bias and dropout fusion.', + dest='bias_dropout_fusion') + group.add_argument('--optimizer', type=str, default='adam', + choices=['adam', 'sgd'], + help='Optimizer function') + group.add_argument('--dataloader-type', type=str, default=None, + choices=['single', 'cyclic'], + help='Single pass vs multiple pass data loader') + group.add_argument('--no-async-tensor-model-parallel-allreduce', + action='store_false', + help='Disable asynchronous execution of ' + 'tensor-model-parallel all-reduce with weight ' + 'gradient compuation of a column-linear layer.', + dest='async_tensor_model_parallel_allreduce') + group.add_argument('--no-persist-layer-norm', action='store_true', + help='Disable using persistent fused layer norm kernel. ' + 'This kernel supports only a set of hidden sizes. Please ' + 'check persist_ln_hidden_sizes if your hidden ' + 'size is supported.') + group.add_argument('--sequence-parallel', action='store_true', + help='Enable sequence parallel optimization.') + group.add_argument('--no-gradient-accumulation-fusion', + action='store_false', + help='Disable fusing gradient accumulation to weight ' + 'gradient computation of linear layers', + dest='gradient_accumulation_fusion') + return parser + + +def _add_initialization_args(parser): + group = parser.add_argument_group(title='initialization') + + group.add_argument('--seed', type=int, default=1234, + help='Random seed used for python, numpy, ' + 'pytorch, and cuda.') + group.add_argument('--data-parallel-random-init', action='store_true', + help='Enable random initialization of params ' + 'across data parallel ranks') + group.add_argument('--init-method-std', type=float, default=0.02, + help='Standard deviation of the zero mean normal ' + 'distribution used for weight initialization.') + group.add_argument('--init-method-xavier-uniform', action='store_true', + help='Enable Xavier uniform parameter initialization') + + return parser + + +def _add_learning_rate_args(parser): + group = parser.add_argument_group(title='learning rate') + + group.add_argument('--lr', type=float, default=None, + help='Initial learning rate. Depending on decay style ' + 'and initial warmup, the learing rate at each ' + 'iteration would be different.') + group.add_argument('--lr-decay-style', type=str, default='linear', + choices=['constant', 'linear', 'cosine'], + help='Learning rate decay function.') + group.add_argument('--lr-decay-iters', type=int, default=None, + help='number of iterations to decay learning rate over,' + ' If None defaults to `--train-iters`') + group.add_argument('--lr-decay-samples', type=int, default=None, + help='number of samples to decay learning rate over,' + ' If None defaults to `--train-samples`') + group.add_argument('--lr-warmup-fraction', type=float, default=None, + help='fraction of lr-warmup-(iters/samples) to use ' + 'for warmup (as a float)') + group.add_argument('--lr-warmup-iters', type=int, default=0, + help='number of iterations to linearly warmup ' + 'learning rate over.') + group.add_argument('--lr-warmup-samples', type=int, default=0, + help='number of samples to linearly warmup ' + 'learning rate over.') + group.add_argument('--warmup', type=int, default=None, + help='Old lr warmup argument, do not use. Use one of the' + '--lr-warmup-* arguments above') + group.add_argument('--min-lr', type=float, default=0.0, + help='Minumum value for learning rate. The scheduler' + 'clip values below this threshold.') + group.add_argument('--override-opt_param-scheduler', action='store_true', + help='Reset the values of the scheduler (learning rate,' + 'warmup iterations, minimum learning rate, maximum ' + 'number of iterations, and decay style from input ' + 'arguments and ignore values from checkpoints. Note' + 'that all the above values will be reset.') + group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true', + help='Use checkpoint to set the values of the scheduler ' + '(learning rate, warmup iterations, minimum learning ' + 'rate, maximum number of iterations, and decay style ' + 'from checkpoint and ignore input arguments.') + + return parser + + +def _add_checkpointing_args(parser): + group = parser.add_argument_group(title='checkpointing') + + group.add_argument('--save', type=str, default=None, + help='Output directory to save checkpoints to.') + group.add_argument('--save-interval', type=int, default=None, + help='Number of iterations between checkpoint saves.') + group.add_argument('--no-save-optim', action='store_true', default=None, + help='Do not save current optimizer.') + group.add_argument('--no-save-rng', action='store_true', default=None, + help='Do not save current rng state.') + group.add_argument('--load', type=str, default=None, + help='Directory containing a model checkpoint.') + group.add_argument('--no-load-optim', action='store_true', default=None, + help='Do not load optimizer when loading checkpoint.') + group.add_argument('--no-load-rng', action='store_true', default=None, + help='Do not load rng state when loading checkpoint.') + group.add_argument('--finetune', action='store_true', + help='Load model for finetuning. Do not load optimizer ' + 'or rng state from checkpoint and set iteration to 0. ' + 'Assumed when loading a release checkpoint.') + group.add_argument('--no-initialization', action='store_false', + help='Do not perform initialization when building model, ' + 'can reduce startup time when definitely loading from a ' + 'checkpoint', + dest='perform_initialization') + group.add_argument('--use-checkpoint-args', action='store_true', + help='Override any command line arguments with arguments ' + 'from the checkpoint') + + return parser + + +def _add_mixed_precision_args(parser): + group = parser.add_argument_group(title='mixed precision') + + group.add_argument('--fp16', action='store_true', + help='Run model in fp16 mode.') + group.add_argument('--bf16', action='store_true', + help='Run model in bfloat16 mode.') + group.add_argument('--loss-scale', type=float, default=None, + help='Static loss scaling, positive power of 2 ' + 'values can improve fp16 convergence. If None, dynamic' + 'loss scaling is used.') + group.add_argument('--initial-loss-scale', type=float, default=2**32, + help='Initial loss-scale for dynamic loss scaling.') + group.add_argument('--min-loss-scale', type=float, default=1.0, + help='Minimum loss scale for dynamic loss scale.') + group.add_argument('--loss-scale-window', type=float, default=1000, + help='Window over which to raise/lower dynamic scale.') + group.add_argument('--hysteresis', type=int, default=2, + help='hysteresis for dynamic loss scaling') + group.add_argument('--fp32-residual-connection', action='store_true', + help='Move residual connections to fp32.') + group.add_argument('--no-query-key-layer-scaling', action='store_false', + help='Do not scale Q * K^T by 1 / layer-number.', + dest='apply_query_key_layer_scaling') + group.add_argument('--attention-softmax-in-fp32', action='store_true', + help='Run attention masking and softmax in fp32. ' + 'This flag is ignored unless ' + '--no-query-key-layer-scaling is specified.') + group.add_argument('--accumulate-allreduce-grads-in-fp32', + action='store_true', + help='Gradient accumulation and all-reduce in fp32.') + group.add_argument('--fp16-lm-cross-entropy', action='store_true', + help='Move the cross entropy unreduced loss calculation' + 'for lm head to fp16.') + + return parser + + +def _add_distributed_args(parser): + group = parser.add_argument_group(title='distributed') + + group.add_argument('--tensor-model-parallel-size', type=int, default=1, + help='Degree of tensor model parallelism.') + group.add_argument('--pipeline-model-parallel-size', type=int, default=1, + help='Degree of pipeline model parallelism.') + group.add_argument('--pipeline-model-parallel-split-rank', + type=int, default=None, + help='Rank where encoder and decoder should be split.') + group.add_argument('--model-parallel-size', type=int, default=None, + help='Old model parallel argument, do not use. Use ' + '--tensor-model-parallel-size instead.') + group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, + help='Number of layers per virtual pipeline stage') + group.add_argument('--distributed-backend', default='nccl', + choices=['nccl', 'gloo'], + help='Which backend to use for distributed training.') + group.add_argument('--DDP-impl', default='local', + choices=['local', 'torch'], + help='which DistributedDataParallel implementation ' + 'to use.') + group.add_argument('--no-contiguous-buffers-in-local-ddp', + action='store_false', help='If set, dont use ' + 'contiguous buffer in local DDP.', + dest='use_contiguous_buffers_in_local_ddp') + group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', + help='Use scatter/gather to optimize communication of tensors in pipeline', + dest='scatter_gather_tensors_in_pipeline') + group.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher.') + group.add_argument('--lazy-mpu-init', type=bool, required=False, + help='If set to True, initialize_megatron() ' + 'skips DDP initialization and returns function to ' + 'complete it instead.Also turns on ' + '--use-cpu-initialization flag. This is for ' + 'external DDP manager.' ) + group.add_argument('--use-cpu-initialization', action='store_true', + default=None, help='If set, affine parallel weights ' + 'initialization uses CPU' ) + group.add_argument('--empty-unused-memory-level', default=0, type=int, + choices=[0, 1, 2], + help='Call torch.cuda.empty_cache() each iteration ' + '(training and eval), to reduce fragmentation.' + '0=off, 1=moderate, 2=aggressive.') + group.add_argument('--standalone-embedding-stage', action='store_true', + default=False, help='If set, *input* embedding layer ' + 'is placed on its own pipeline stage, without any ' + 'transformer layers. (For T5, this flag currently only ' + 'affects the encoder embedding.)') + group.add_argument('--use-distributed-optimizer', action='store_true', + help='Use distributed optimizer.') + + return parser + + +def _add_validation_args(parser): + group = parser.add_argument_group(title='validation') + + group.add_argument('--eval-iters', type=int, default=100, + help='Number of iterations to run for evaluation' + 'validation/test for.') + group.add_argument('--eval-interval', type=int, default=1000, + help='Interval between running evaluation on ' + 'validation set.') + + return parser + + +def _add_data_args(parser): + group = parser.add_argument_group(title='data and dataloader') + + group.add_argument('--data-path', nargs='*', default=None, + help='Path to the training dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ...') + group.add_argument('--split', type=str, default='969, 30, 1', + help='Comma-separated list of proportions for training,' + ' validation, and test split. For example the split ' + '`90,5,5` will use 90%% of data for training, 5%% for ' + 'validation and 5%% for test.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file.') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file.') + group.add_argument('--vocab-extra-ids', type=int, default=0, + help='Number of additional vocabulary tokens. ' + 'They are used for span masking in the T5 model') + group.add_argument('--seq-length', type=int, default=None, + help='Maximum sequence length to process.') + group.add_argument('--encoder-seq-length', type=int, default=None, + help='Maximum encoder sequence length to process.' + 'This should be exclusive of --seq-length') + group.add_argument('--decoder-seq-length', type=int, default=None, + help="Maximum decoder sequence length to process.") + group.add_argument('--retriever-seq-length', type=int, default=256, + help='Maximum sequence length for the biencoder model ' + ' for retriever') + group.add_argument('--sample-rate', type=float, default=1.0, + help='sample rate for training data. Supposed to be 0 ' + ' < sample_rate < 1') + group.add_argument('--mask-prob', type=float, default=0.15, + help='Probability of replacing a token with mask.') + group.add_argument('--short-seq-prob', type=float, default=0.1, + help='Probability of producing a short sequence.') + group.add_argument('--mmap-warmup', action='store_true', + help='Warm up mmap files.') + group.add_argument('--num-workers', type=int, default=2, + help="Dataloader number of workers.") + group.add_argument('--tokenizer-type', type=str, + default=None, + choices=['BertWordPieceLowerCase', + 'BertWordPieceCase', + 'GPT2BPETokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--data-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer'], + help='Implementation of indexed datasets.') + group.add_argument('--reset-position-ids', action='store_true', + help='Reset posistion ids after end-of-document token.') + group.add_argument('--reset-attention-mask', action='store_true', + help='Reset self attention maske after ' + 'end-of-document token.') + group.add_argument('--eod-mask-loss', action='store_true', + help='Mask loss for the end of document tokens.') + + return parser + + +def _add_autoresume_args(parser): + group = parser.add_argument_group(title='autoresume') + + group.add_argument('--adlr-autoresume', action='store_true', + help='Enable autoresume on adlr cluster.') + group.add_argument('--adlr-autoresume-interval', type=int, default=1000, + help='Intervals over which check for autoresume' + 'termination signal') + + return parser + + +def _add_biencoder_args(parser): + group = parser.add_argument_group(title='biencoder') + + # network size + group.add_argument('--ict-head-size', type=int, default=None, + help='Size of block embeddings to be used in ICT and ' + 'REALM (paper default: 128)') + group.add_argument('--biencoder-projection-dim', type=int, default=0, + help='Size of projection head used in biencoder (paper' + ' default: 128)') + group.add_argument('--biencoder-shared-query-context-model', action='store_true', + help='Whether to share the parameters of the query ' + 'and context models or not') + + # checkpointing + group.add_argument('--ict-load', type=str, default=None, + help='Directory containing an ICTBertModel checkpoint') + group.add_argument('--bert-load', type=str, default=None, + help='Directory containing an BertModel checkpoint ' + '(needed to start ICT and REALM)') + + # data + group.add_argument('--titles-data-path', type=str, default=None, + help='Path to titles dataset used for ICT') + group.add_argument('--query-in-block-prob', type=float, default=0.1, + help='Probability of keeping query in block for ' + 'ICT dataset') + group.add_argument('--use-one-sent-docs', action='store_true', + help='Whether to use one sentence documents in ICT') + group.add_argument('--evidence-data-path', type=str, default=None, + help='Path to Wikipedia Evidence frm DPR paper') + + # training + group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int, + default=[], help="Which top-k accuracies to report " + "(e.g. '1 5 20')") + group.add_argument('--retriever-score-scaling', action='store_true', + help='Whether to scale retriever scores by inverse ' + 'square root of hidden size') + + # faiss index + group.add_argument('--block-data-path', type=str, default=None, + help='Where to save/load BlockData to/from') + group.add_argument('--embedding-path', type=str, default=None, + help='Where to save/load Open-Retrieval Embedding' + ' data to/from') + + # indexer + group.add_argument('--indexer-batch-size', type=int, default=128, + help='How large of batches to use when doing indexing ' + 'jobs') + group.add_argument('--indexer-log-interval', type=int, default=1000, + help='After how many batches should the indexer ' + 'report progress') + return parser + + +def _add_vision_args(parser): + group = parser.add_argument_group(title="vision") + + # general vision arguements + group.add_argument('--num-classes', type=int, default=1000, + help='num of classes in vision classificaiton task') + group.add_argument('--img-h', type=int, default=224, + help='Image height for vision classification task') + group.add_argument('--img-w', type=int, default=224, + help='Image height for vision classification task') + group.add_argument('--num-channels', type=int, default=3, + help='Number of channels in input image data') + group.add_argument('--patch-dim', type=int, default=16, + help='patch dimension') + group.add_argument('--classes-fraction', type=float, default=1.0, + help='training with fraction of classes.') + group.add_argument('--data-per-class-fraction', type=float, default=1.0, + help='training with fraction of data per class.') + group.add_argument('--no-data-sharding', action='store_false', + help='Disable data sharding.', + dest='data_sharding') + group.add_argument('--head-lr-mult', type=float, default=1.0, + help='learning rate multiplier for head during finetuning') + + # pretraining type and backbone selection` + group.add_argument('--vision-pretraining', action='store_true', + help='flag to indicate vision pretraining') + group.add_argument('--vision-pretraining-type', type=str, default='classify', + choices=['classify', 'inpaint', 'dino'], + help='pretraining objectives') + group.add_argument('--vision-backbone-type', type=str, default='vit', + choices=['vit', 'mit', 'swin'], + help='backbone types types') + group.add_argument('--swin-backbone-type', type=str, default='tiny', + choices=['tiny', 'base', 'h3'], + help='pretraining objectives') + + # inpainting arguments + group.add_argument('--mask-type', type=str, default='random', + choices=['random', 'row'], + help='mask types') + group.add_argument('--mask-factor', type=float, default=1.0, + help='mask size scaling parameter') + + # dino arguments + group.add_argument('--iter-per-epoch', type=int, default=1250, + help='iterations per epoch') + group.add_argument('--dino-local-img-size', type=int, default=96, + help='Image size for vision classification task') + group.add_argument('--dino-local-crops-number', type=int, default=10, + help='Number of local crops') + group.add_argument('--dino-head-hidden-size', type=int, default=2048, + help='Hidden dimension size in dino head') + group.add_argument('--dino-bottleneck-size', type=int, default=256, + help='Bottle neck dimension in dino head ') + group.add_argument('--dino-freeze-last-layer', type=float, default=1, + help='Freezing last layer weights') + group.add_argument('--dino-norm-last-layer', action='store_true', + help='Disable Norm in last layer.') + group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04, + help='warump teacher temperature') + group.add_argument('--dino-teacher-temp', type=float, default=0.07, + help='teacher temperature') + group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30, + help='warmup teacher temperaure epochs') + + return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca95a17085884543ab0d18656d3b0f149aa2815 --- /dev/null +++ b/megatron/checkpointing.py @@ -0,0 +1,675 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Input/output checkpointing.""" + +import os +import random +import sys +import numpy as np + +import torch + +from megatron import (mpu, + update_num_microbatches) +from .global_vars import get_args +from .utils import (unwrap_model, + print_rank_0) + + +_CHECKPOINT_VERSION = None + +def set_checkpoint_version(value): + global _CHECKPOINT_VERSION + if _CHECKPOINT_VERSION is not None: + assert _CHECKPOINT_VERSION == value, \ + "checkpoint versions do not match" + _CHECKPOINT_VERSION = value + +def get_checkpoint_version(): + global _CHECKPOINT_VERSION + return _CHECKPOINT_VERSION + +def check_checkpoint_args(checkpoint_args): + """Ensure fixed arguments for a model are the same for the input + arguments and the one retrieved from checkpoint.""" + args = get_args() + + def _compare(arg_name, old_arg_name=None): + if old_arg_name is not None: + checkpoint_value = getattr(checkpoint_args, old_arg_name) + else: + checkpoint_value = getattr(checkpoint_args, arg_name) + args_value = getattr(args, arg_name) + error_message = '{} value from checkpoint ({}) is not equal to the ' \ + 'input argument value ({}).'.format( + arg_name, checkpoint_value, args_value) + assert checkpoint_value == args_value, error_message + + _compare('num_layers') + _compare('hidden_size') + _compare('num_attention_heads') + if args.vocab_file: + _compare('max_position_embeddings') + _compare('make_vocab_size_divisible_by') + _compare('padded_vocab_size') + _compare('tokenizer_type') + if args.data_parallel_random_init: + _compare('data_parallel_random_init') + if get_checkpoint_version() < 3.0: + _compare('tensor_model_parallel_size', + old_arg_name='model_parallel_size') + if get_checkpoint_version() >= 3.0: + _compare('tensor_model_parallel_size') + _compare('pipeline_model_parallel_size') + +def ensure_directory_exists(filename): + """Build filename's path if it does not already exists.""" + dirname = os.path.dirname(filename) + if not os.path.exists(dirname): + os.makedirs(dirname) + + +def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False, + pipeline_parallel=None, tensor_rank=None, pipeline_rank=None): + """Determine the directory name for this rank's checkpoint.""" + if release: + directory = 'release' + else: + directory = 'iter_{:07d}'.format(iteration) + + # Use both the tensor and pipeline MP rank. + if pipeline_parallel is None: + pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) + if tensor_rank is None: + tensor_rank = mpu.get_tensor_model_parallel_rank() + if pipeline_rank is None: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + + # Use both the tensor and pipeline MP rank. If using the distributed + # optimizer, then the optimizer's path must additionally include the + # data parallel rank. + if not pipeline_parallel: + common_path = os.path.join(checkpoints_path, directory, + f'mp_rank_{tensor_rank:02d}') + else: + common_path = os.path.join(checkpoints_path, directory, + f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}') + + if use_distributed_optimizer: + model_name = os.path.join(common_path, "model_rng.pt") + optim_name = os.path.join( + common_path + "_%03d" % mpu.get_data_parallel_rank(), + "optim.pt") + else: + model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt") + return model_name, optim_name + +def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False): + """Finds the checkpoint for rank 0 without knowing if we are using + pipeline parallelism or not. + + Since the checkpoint naming scheme changes if pipeline parallelism + is present, we need to look for both naming schemes if we don't + know if the checkpoint has pipeline parallelism. + + """ + + # Look for checkpoint with no pipelining + filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release, + pipeline_parallel=False, + tensor_rank=0, pipeline_rank=0) + if os.path.isfile(filenames[0]): + return filenames + + # Look for checkpoint with pipelining + filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release, + pipeline_parallel=True, + tensor_rank=0, pipeline_rank=0) + if os.path.isfile(filenames[0]): + return filenames + + return None, None + +def get_checkpoint_tracker_filename(checkpoints_path): + + """Tracker file rescords the latest chckpoint during + training to restart from.""" + return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') + + +def read_metadata(tracker_filename): + # Read the tracker file and either set the iteration or + # mark it as a release checkpoint. + iteration = 0 + release = False + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + try: + iteration = int(metastring) + except ValueError: + release = metastring == 'release' + if not release: + print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( + tracker_filename)) + sys.exit() + assert iteration > 0 or release, 'error parsing metadata file {}'.format( + tracker_filename) + + # Get the max iteration retrieved across the ranks. + if torch.distributed.is_initialized(): + iters_cuda = torch.cuda.LongTensor([iteration]) + torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) + max_iter = iters_cuda[0].item() + + # We should now have all the same iteration. + # If not, print a warning and chose the maximum + # iteration across all ranks. + if iteration != max_iter: + print('WARNING: on rank {} found iteration {} in the ' + 'metadata while max iteration across the ranks ' + 'is {}, replacing it with max iteration.'.format( + rank, iteration, max_iter), flush=True) + else: + # When loading a checkpoint outside of training (for example, + # when editing it), we might not have torch distributed + # initialized, in this case, just assume we have the latest + max_iter = iteration + return max_iter, release + + +def get_rng_state(): + """ collect rng state across data parallel ranks """ + args = get_args() + rng_state = { + 'random_rng_state': random.getstate(), + 'np_rng_state': np.random.get_state(), + 'torch_rng_state': torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state(), + 'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()} + + rng_state_list = None + if torch.distributed.is_initialized() and \ + mpu.get_data_parallel_world_size() > 1 and \ + args.data_parallel_random_init: + rng_state_list = \ + [None for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + rng_state_list, + rng_state, + group=mpu.get_data_parallel_group()) + else: + rng_state_list = [rng_state] + + return rng_state_list + + +def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): + """Save a model checkpoint.""" + args = get_args() + + # Only rank zero of the data parallel writes to the disk. + model = unwrap_model(model) + + print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( + iteration, args.save)) + + # Collect rng state across data parallel ranks. + rng_state = get_rng_state() + + # Checkpoint file names. + model_checkpoint_name, optim_checkpoint_name = \ + get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer) + + # Collect args, model, RNG. + model_state_dict = {} + if not torch.distributed.is_initialized() \ + or mpu.get_data_parallel_rank() == 0: + + # Arguments, iteration, and model. + model_state_dict['args'] = args + model_state_dict['checkpoint_version'] = 3.0 + model_state_dict['iteration'] = iteration + if len(model) == 1: + model_state_dict['model'] = model[0].state_dict_for_save_checkpoint() + else: + for i in range(len(model)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + model_state_dict['model%d' % i] = \ + model[i].state_dict_for_save_checkpoint() + + # RNG states. + if not args.no_save_rng: + model_state_dict["rng_state"] = rng_state + + # Collect optimizer state. (Optimizer is saved separately from the model, due + # to the conflicting data pattern when using the distributed optimizer.) + optim_state_dict = {} + if not args.no_save_optim \ + and (not torch.distributed.is_initialized() + or mpu.get_data_parallel_rank() == 0 + or args.use_distributed_optimizer): + + # Optimizer stuff. + if optimizer is not None: + optim_state_dict['optimizer'] = optimizer.state_dict() + if opt_param_scheduler is not None: + optim_state_dict['opt_param_scheduler'] = \ + opt_param_scheduler.state_dict() + + # Save. + if args.use_distributed_optimizer: + # Save model separate from optimizer. + if model_state_dict: + ensure_directory_exists(model_checkpoint_name) + torch.save(model_state_dict, model_checkpoint_name) + if optim_state_dict: + ensure_directory_exists(optim_checkpoint_name) + torch.save(optim_state_dict, optim_checkpoint_name) + else: + # Save model and optimizer together. + state_dict = {**model_state_dict, **optim_state_dict} + if state_dict: # only saves if populated (i.e., inherits conditions above) + ensure_directory_exists(model_checkpoint_name) + torch.save(state_dict, model_checkpoint_name) + + # Wait so everyone is done (necessary) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format( + iteration, args.save)) + + # And update the latest iteration + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + tracker_filename = get_checkpoint_tracker_filename(args.save) + with open(tracker_filename, 'w') as f: + f.write(str(iteration)) + + # Wait so everyone is done (not necessary) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + +def _transpose_first_dim(t, num_splits, num_splits_first, model): + input_shape = t.size() + # We use a self_attention module but the values extracted aren't + # specific to self attention so should work for cross attention as well + while hasattr(model, 'module'): + model = model.module + attention_module = model.language_model.encoder.layers[0].self_attention + hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head + num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition + if num_splits_first: + """[num_splits * np * hn, h] + -->(view) [num_splits, np, hn, h] + -->(tranpose) [np, num_splits, hn, h] + -->(view) [np * num_splits * hn, h] """ + + intermediate_shape = \ + (num_splits, num_attention_heads_per_partition, + hidden_size_per_attention_head) + input_shape[1:] + + t = t.view(*intermediate_shape) + t = t.transpose(0, 1).contiguous() + else: + """[np * hn * num_splits, h] + -->(view) [np, hn, num_splits, h] + -->(tranpose) [np, num_splits, hn, h] + -->(view) [np * num_splits * hn, h] """ + + intermediate_shape = \ + (num_attention_heads_per_partition, + hidden_size_per_attention_head, num_splits) +\ + input_shape[1:] + + t = t.view(*intermediate_shape) + t = t.transpose(1, 2).contiguous() + t = t.view(*input_shape) + + return t + +def fix_query_key_value_ordering(model, checkpoint_version): + """Fix up query/key/value matrix ordering if checkpoint + version is smaller than 2.0 + """ + if checkpoint_version < 2.0: + if isinstance(model, list): + assert len(model)==1 + model = model[0] + for name, param in model.named_parameters(): + if name.endswith(('.query_key_value.weight', '.query_key_value.bias')): + if checkpoint_version == 0: + fixed_param = _transpose_first_dim(param.data, 3, True, model) + elif checkpoint_version == 1.0: + fixed_param = _transpose_first_dim(param.data, 3, False, model) + else: + print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") + sys.exit() + param.data.copy_(fixed_param) + if name.endswith(('.key_value.weight', '.key_value.bias')): + if checkpoint_version == 0: + fixed_param = _transpose_first_dim(param.data, 2, True, model) + elif checkpoint_version == 1.0: + fixed_param = _transpose_first_dim(param.data, 2, False, model) + else: + print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") + sys.exit() + param.data.copy_(fixed_param) + print_rank_0(" succesfully fixed query-key-values ordering for" + " checkpoint version {}".format(checkpoint_version)) + +def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False): + """ Load the base state_dict from the given directory + + If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. + """ + + + # Read the tracker file and set the iteration. + tracker_filename = get_checkpoint_tracker_filename(load_dir) + + # If no tracker file, return nothing + if not os.path.isfile(tracker_filename): + if not rank0: + print_rank_0('WARNING: could not find the metadata file {} '.format( + tracker_filename)) + print_rank_0(' will not load any checkpoints and will start from ' + 'random') + return None, None, False + + # Otherwise, read the tracker file and either set the iteration or + # mark it as a release checkpoint. + iteration, release = read_metadata(tracker_filename) + + # Checkpoint. + if rank0: + checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer, + release) + else: + checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer, + release) + if release: + print_rank_0(f' loading release checkpoint from {load_dir}') + else: + print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}') + + model_checkpoint_name, optim_checkpoint_name = checkpoint_names + + # Load the checkpoint. + try: + model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') + if use_distributed_optimizer: + optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu') + else: + optim_state_dict = model_state_dict + except ModuleNotFoundError: + from megatron.fp16_deprecated import loss_scaler + # For backward compatibility. + if not rank0: + print_rank_0(' > deserializing using the old code structure ...') + sys.modules['fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') + optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu') + sys.modules.pop('fp16.loss_scaler', None) + sys.modules.pop('megatron.fp16.loss_scaler', None) + except BaseException as e: + print_rank_0('could not load the checkpoint') + print_rank_0(e) + sys.exit() + + return model_state_dict, optim_state_dict, release + +def load_args_from_checkpoint(args, load_arg='load'): + """Set required arguments from the checkpoint specified in the + arguments. + + Will overwrite arguments that have a non-None default value, but + will leave any arguments that default to None as set. + + Returns the same args NameSpace with the new values added/updated. + + If no checkpoint is specified in args, or if the checkpoint is + there but invalid, the arguments will not be modified + + """ + load_dir = getattr(args, load_arg) + + if load_dir is None: + print_rank_0('No load directory specified, using provided arguments.') + return args + + model_state_dict, optim_state_dict, release = \ + _load_base_checkpoint(load_dir, + use_distributed_optimizer=args.use_distributed_optimizer, + rank0=True) + + # For args we only care about model state dict + state_dict = model_state_dict + + if not state_dict: + print_rank_0('Checkpoint not found to provide arguments, using provided arguments.') + return args + + if 'args' not in state_dict: + print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.') + return args + + checkpoint_args = state_dict['args'] + checkpoint_version = state_dict.get('checkpoint_version', 0) + args.iteration = state_dict['iteration'] + + def _set_arg(arg_name, old_arg_name=None, force=False): + if not force and getattr(args, arg_name, None) is not None: + return + + if old_arg_name is not None: + checkpoint_value = getattr(checkpoint_args, old_arg_name, None) + else: + checkpoint_value = getattr(checkpoint_args, arg_name, None) + + if checkpoint_value is not None: + print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint") + setattr(args, arg_name, checkpoint_value) + + _set_arg('num_layers') + _set_arg('hidden_size') + _set_arg('ffn_hidden_size') + _set_arg('seq_length') + _set_arg('num_attention_heads') + _set_arg('kv_channels') + _set_arg('max_position_embeddings') + _set_arg('tokenizer_type') + _set_arg('padded_vocab_size') + if checkpoint_version < 3.0: + _set_arg('tensor_model_parallel_size', + 'model_parallel_size') + else: + _set_arg('tensor_model_parallel_size', force=True) + _set_arg('pipeline_model_parallel_size', force=True) + _set_arg('num_layers_per_virtual_pipeline_stage') + return args + + +def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): + """Load a model checkpoint and return the iteration. + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` of the checkpoint match the names of + parameters and buffers in model. + """ + args = get_args() + load_dir = getattr(args, load_arg) + + model = unwrap_model(model) + + model_state_dict, optim_state_dict, release = \ + _load_base_checkpoint(load_dir, + use_distributed_optimizer=args.use_distributed_optimizer, + rank0=False) + + if model_state_dict is None: + return 0 + + # set checkpoint version + set_checkpoint_version(model_state_dict.get('checkpoint_version', 0)) + + # Set iteration. + if args.finetune or release: + iteration = 0 + else: + try: + iteration = model_state_dict['iteration'] + except KeyError: + try: # Backward compatible with older checkpoints + iteration = model_state_dict['total_iters'] + except KeyError: + print_rank_0('A metadata file exists but unable to load ' + 'iteration from checkpoint {}, exiting'.format( + checkpoint_name)) + sys.exit() + + # Check arguments. + assert args.consumed_train_samples == 0 + assert args.consumed_valid_samples == 0 + if 'args' in model_state_dict: + checkpoint_args = model_state_dict['args'] + check_checkpoint_args(checkpoint_args) + args.consumed_train_samples = getattr(checkpoint_args, + 'consumed_train_samples', 0) + update_num_microbatches(consumed_samples=args.consumed_train_samples) + args.consumed_valid_samples = getattr(checkpoint_args, + 'consumed_valid_samples', 0) + else: + print_rank_0('could not find arguments in the checkpoint ...') + + # Model. + if len(model) == 1: + model[0].load_state_dict(model_state_dict['model'], strict=strict) + else: + for i in range(len(model)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict) + + # Fix up query/key/value matrix ordering if needed + checkpoint_version = get_checkpoint_version() + print_rank_0(f' checkpoint version {checkpoint_version}') + fix_query_key_value_ordering(model, checkpoint_version) + + # Optimizer. + if not release and not args.finetune and not args.no_load_optim: + try: + if optimizer is not None: + optimizer.load_state_dict(optim_state_dict['optimizer']) + if opt_param_scheduler is not None: + if 'lr_scheduler' in optim_state_dict: # backward compatbility + opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler']) + else: + opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler']) + except KeyError: + print_rank_0('Unable to load optimizer from checkpoint {}. ' + 'Specify --no-load-optim or --finetune to prevent ' + 'attempting to load the optimizer state, ' + 'exiting ...'.format(checkpoint_name)) + sys.exit() + + # rng states. + if not release and not args.finetune and not args.no_load_rng: + try: + if 'rng_state' in model_state_dict: + # access rng_state for data parallel rank + if args.data_parallel_random_init: + + rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()] + else: + rng_state = model_state_dict['rng_state'][0] + random.setstate(rng_state['random_rng_state']) + np.random.set_state(rng_state['np_rng_state']) + torch.set_rng_state(rng_state['torch_rng_state']) + torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + # Check for empty states array + if not rng_state['rng_tracker_states']: + raise KeyError + mpu.get_cuda_rng_tracker().set_states( + rng_state['rng_tracker_states']) + else: # backward compatability + random.setstate(model_state_dict['random_rng_state']) + np.random.set_state(model_state_dict['np_rng_state']) + torch.set_rng_state(model_state_dict['torch_rng_state']) + torch.cuda.set_rng_state(model_state_dict['cuda_rng_state']) + # Check for empty states array + if not model_state_dict['rng_tracker_states']: + raise KeyError + mpu.get_cuda_rng_tracker().set_states( + model_state_dict['rng_tracker_states']) + except KeyError: + print_rank_0('Unable to load rng state from checkpoint {}. ' + 'Specify --no-load-rng or --finetune to prevent ' + 'attempting to load the rng state, ' + 'exiting ...'.format(checkpoint_name)) + sys.exit() + + # Some utilities want to load a checkpoint without distributed being initialized + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + print_rank_0(f' successfully loaded checkpoint from {args.load} ' + f'at iteration {iteration}') + + return iteration + + +def load_biencoder_checkpoint(model, only_query_model=False, + only_context_model=False, custom_load_path=None): + """ + selectively load retrieval models for indexing/retrieving + from saved checkpoints + """ + + args = get_args() + + model = unwrap_model(model) + + load_path = custom_load_path if custom_load_path is not None else args.load + + tracker_filename = get_checkpoint_tracker_filename(load_path) + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + + checkpoint_name, _ = get_checkpoint_names(load_path, iteration, + args.use_distributed_optimizer, + release=False) + + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + state_dict = torch.load(model_checkpoint_name, map_location='cpu') + ret_state_dict = state_dict['model'] + + if only_query_model: + ret_state_dict.pop('context_model') + if only_context_model: + ret_state_dict.pop('query_model') + + assert len(model) == 1 + model[0].load_state_dict(ret_state_dict) + torch.distributed.barrier() + + if mpu.get_data_parallel_rank() == 0: + print(' successfully loaded {}'.format(checkpoint_name)) + + return model diff --git a/megatron/data/Makefile b/megatron/data/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..8f9db7686696fbea6c94b998db4b40ef426c748d --- /dev/null +++ b/megatron/data/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = helpers +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/megatron/data/__init__.py b/megatron/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5f898c6bdf89c6cf0243af102d04f6efed86b8 --- /dev/null +++ b/megatron/data/__init__.py @@ -0,0 +1 @@ +from . import indexed_dataset diff --git a/megatron/data/autoaugment.py b/megatron/data/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..585a4fa6a51214be7ecdf8cc58a81caea700582d --- /dev/null +++ b/megatron/data/autoaugment.py @@ -0,0 +1,320 @@ +"""AutoAugment data augmentation policy for ImageNet. + +-- Begin license text. + +MIT License + +Copyright (c) 2018 Philip Popien + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +-- End license text. + +Code adapted from https://github.com/DeepVoltaire/AutoAugment. + +This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in +Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation +policies. + +Reference: +[1] https://arxiv.org/abs/1805.09501 +""" + +import random + +import numpy as np +from PIL import Image +from PIL import ImageEnhance +from PIL import ImageOps + +_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable. + + +class ImageNetPolicy: + """Definition of an ImageNetPolicy. + + Implements a fixed AutoAugment data augmentation policy targeted at + ImageNet training by randomly applying at runtime one of the 25 pre-defined + data augmentation sub-policies provided in Reference [1]. + + Usage example as a Pytorch Transform: + >>> transform=transforms.Compose([transforms.Resize(256), + >>> ImageNetPolicy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + """Initialize an ImageNetPolicy. + + Args: + fillcolor (tuple): RGB color components of the color to be used for + filling when needed (default: (128, 128, 128), which + corresponds to gray). + """ + # Instantiate a list of sub-policies. + # Each entry of the list is a SubPolicy which consists of + # two augmentation operations, + # each of those parametrized as operation, probability, magnitude. + # Those two operations are applied sequentially on the image upon call. + self.policies = [ + SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor), + SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor), + SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor), + SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor), + SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor), + SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor), + SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor), + SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor), + SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor), + SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor), + SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor), + SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor), + SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor), + SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor), + SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor), + SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor), + SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor), + SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor), + SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor), + SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor), + SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor), + SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor), + SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor), + SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor), + SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor), + ] + + def __call__(self, img): + """Define call method for ImageNetPolicy class.""" + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + """Define repr method for ImageNetPolicy class.""" + return "ImageNetPolicy" + + +class SubPolicy: + """Definition of a SubPolicy. + + A SubPolicy consists of two augmentation operations, + each of those parametrized as operation, probability, magnitude. + The two operations are applied sequentially on the image upon call. + """ + + def __init__( + self, + operation1, + probability1, + magnitude_idx1, + operation2, + probability2, + magnitude_idx2, + fillcolor, + ): + """Initialize a SubPolicy. + + Args: + operation1 (str): Key specifying the first augmentation operation. + There are fourteen key values altogether (see supported_ops below + listing supported operations). probability1 (float): Probability + within [0., 1.] of applying the first augmentation operation. + magnitude_idx1 (int): Integer specifiying the strength of the first + operation as an index further used to derive the magnitude from a + range of possible values. + operation2 (str): Key specifying the second augmentation operation. + probability2 (float): Probability within [0., 1.] of applying the + second augmentation operation. + magnitude_idx2 (int): Integer specifiying the strength of the + second operation as an index further used to derive the magnitude + from a range of possible values. + fillcolor (tuple): RGB color components of the color to be used for + filling. + Returns: + """ + # List of supported operations for operation1 and operation2. + supported_ops = [ + "shearX", + "shearY", + "translateX", + "translateY", + "rotate", + "color", + "posterize", + "solarize", + "contrast", + "sharpness", + "brightness", + "autocontrast", + "equalize", + "invert", + ] + assert (operation1 in supported_ops) and ( + operation2 in supported_ops + ), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation." + + assert ( + 0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0 + ), "SubPolicy: prob1 and prob2 should be within [0., 1.]." + + assert ( + isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10 + ), "SubPolicy: idx1 should be specified as an integer within [0, 10]." + + assert ( + isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10 + ), "SubPolicy: idx2 should be specified as an integer within [0, 10]." + + # Define a dictionary where each key refers to a specific type of + # augmentation and the corresponding value is a range of ten possible + # magnitude values for that augmentation. + num_levels = _MAX_LEVEL + 1 + ranges = { + "shearX": np.linspace(0, 0.3, num_levels), + "shearY": np.linspace(0, 0.3, num_levels), + "translateX": np.linspace(0, 150 / 331, num_levels), + "translateY": np.linspace(0, 150 / 331, num_levels), + "rotate": np.linspace(0, 30, num_levels), + "color": np.linspace(0.0, 0.9, num_levels), + "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype( + np.int + ), + "solarize": np.linspace(256, 0, num_levels), # range [0, 256] + "contrast": np.linspace(0.0, 0.9, num_levels), + "sharpness": np.linspace(0.0, 0.9, num_levels), + "brightness": np.linspace(0.0, 0.9, num_levels), + "autocontrast": [0] + * num_levels, # This augmentation doesn't use magnitude parameter. + "equalize": [0] + * num_levels, # This augmentation doesn't use magnitude parameter. + "invert": [0] + * num_levels, # This augmentation doesn't use magnitude parameter. + } + + def rotate_with_fill(img, magnitude): + """Define rotation transformation with fill. + + The input image is first rotated, then it is blended together with + a gray mask of the same size. Note that fillcolor as defined + elsewhere in this module doesn't apply here. + + Args: + magnitude (float): rotation angle in degrees. + Returns: + rotated_filled (PIL Image): rotated image with gray filling for + disoccluded areas unveiled by the rotation. + """ + rotated = img.convert("RGBA").rotate(magnitude) + rotated_filled = Image.composite( + rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated + ) + return rotated_filled.convert(img.mode) + + # Define a dictionary of augmentation functions where each key refers + # to a specific type of augmentation and the corresponding value defines + # the augmentation itself using a lambda function. + # pylint: disable=unnecessary-lambda + func_dict = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=fillcolor, + ), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, + fillcolor=fillcolor, + ), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + ( + 1, + 0, + magnitude * img.size[0] * random.choice([-1, 1]), + 0, + 1, + 0, + ), + fillcolor=fillcolor, + ), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + ( + 1, + 0, + 0, + 0, + 1, + magnitude * img.size[1] * random.choice([-1, 1]), + ), + fillcolor=fillcolor, + ), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "posterize": lambda img, magnitude: ImageOps.posterize( + img, magnitude + ), + "solarize": lambda img, magnitude: ImageOps.solarize( + img, magnitude + ), + "contrast": lambda img, magnitude: ImageEnhance.Contrast( + img + ).enhance(1 + magnitude * random.choice([-1, 1])), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness( + img + ).enhance(1 + magnitude * random.choice([-1, 1])), + "brightness": lambda img, magnitude: ImageEnhance.Brightness( + img + ).enhance(1 + magnitude * random.choice([-1, 1])), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img), + } + + # Store probability, function and magnitude of the first augmentation + # for the sub-policy. + self.probability1 = probability1 + self.operation1 = func_dict[operation1] + self.magnitude1 = ranges[operation1][magnitude_idx1] + + # Store probability, function and magnitude of the second augmentation + # for the sub-policy. + self.probability2 = probability2 + self.operation2 = func_dict[operation2] + self.magnitude2 = ranges[operation2][magnitude_idx2] + + def __call__(self, img): + """Define call method for SubPolicy class.""" + # Randomly apply operation 1. + if random.random() < self.probability1: + img = self.operation1(img, self.magnitude1) + + # Randomly apply operation 2. + if random.random() < self.probability2: + img = self.operation2(img, self.magnitude2) + + return img diff --git a/megatron/data/bert_dataset.py b/megatron/data/bert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d31b098ec37b3fbeb3a24368ad55034e18fcacea --- /dev/null +++ b/megatron/data/bert_dataset.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BERT Style dataset.""" + +import numpy as np +import torch + +from megatron import ( + get_args, + get_tokenizer, + mpu, + print_rank_0 +) +from megatron.data.dataset_utils import ( + get_samples_mapping, + get_a_and_b_segments, + truncate_segments, + create_tokens_and_tokentypes, + create_masked_lm_predictions +) + +class DummyBertDataset(torch.utils.data.Dataset): + def __init__(self, name, num_samples, max_seq_length): + self.name = name + self.num_samples = num_samples + self.max_seq_length = max_seq_length + self.np_rng = np.random.RandomState(seed=0) + # self.token_nps = np_rng.randint(1000, 2000, (self.num_samples, 512)) + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + tokens = self.np_rng.randint(1000, 2000, self.max_seq_length) + masked_position = np.arange(int(tokens.shape[0] * 0.15)) + tokens = tokens.astype(np.int64) + labels = tokens[masked_position] + label_np = np.full_like(tokens, -1) + label_np[masked_position] = labels + tokens[masked_position] = self.mask_id + loss_mask_np = np.zeros_like(tokens) + loss_mask_np[masked_position] = 1 + train_sample = { + 'text': tokens, + 'types': np.zeros_like(tokens), + 'labels': label_np, + 'is_random': 0, + 'loss_mask': loss_mask_np, + 'padding_mask': np.ones_like(tokens), + 'truncated': 0 + } + return train_sample + +class BertDataset(torch.utils.data.Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, + num_epochs, max_num_samples, masked_lm_prob, + max_seq_length, short_seq_prob, seed, binary_head): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.binary_head = binary_head + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping(self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 3, # account for added tokens + short_seq_prob, + self.seed, + self.name, + self.binary_head) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + start_idx, end_idx, seq_length = self.samples_mapping[idx] + sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + # We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1 + np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) + return build_training_sample(sample, seq_length, + self.max_seq_length, # needed for padding + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, self.sep_id, + self.mask_id, self.pad_id, + self.masked_lm_prob, np_rng, + self.binary_head) + + + + +def build_training_sample(sample, + target_seq_length, max_seq_length, + vocab_id_list, vocab_id_to_token_dict, + cls_id, sep_id, mask_id, pad_id, + masked_lm_prob, np_rng, binary_head): + """Biuld training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + """ + + if binary_head: + # We assume that we have at least two sentences in the sample + assert len(sample) > 1 + assert target_seq_length <= max_seq_length + + # Divide sample into two segments (A and B). + if binary_head: + tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, + np_rng) + else: + tokens_a = [] + for j in range(len(sample)): + tokens_a.extend(sample[j]) + tokens_b = [] + is_next_random = False + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), + len(tokens_b), max_num_tokens, np_rng) + + # Build tokens and toketypes. + tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, + cls_id, sep_id) + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions( + tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, + cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng) + + # Padding. + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ + = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length) + + train_sample = { + 'text': tokens_np, + 'types': tokentypes_np, + 'labels': labels_np, + 'is_random': int(is_next_random), + 'loss_mask': loss_mask_np, + 'padding_mask': padding_mask_np, + 'truncated': int(truncated)} + return train_sample + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b3b961b8cb28905eed8ba543721117cc9e0c81 --- /dev/null +++ b/megatron/data/biencoder_dataset_utils.py @@ -0,0 +1,208 @@ +import os +import time + +import numpy as np +import torch + +from megatron import get_args, get_tokenizer, mpu, print_rank_0 +from megatron.data.dataset_utils import create_masked_lm_predictions, \ + pad_and_convert_to_numpy +from megatron.data.data_samplers import MegatronPretrainingSampler + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + +def get_one_epoch_dataloader(dataset, micro_batch_size=None): + """Specifically one epoch to be used in an indexing job.""" + args = get_args() + + if micro_batch_size is None: + micro_batch_size = args.micro_batch_size + num_workers = args.num_workers + + # Use megatron's sampler with consumed samples set to 0 as + # this is only for evaluation and don't intend to resume half way. + # Also, set the drop last to false as don't intend to remove + # the last batch + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=0, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + drop_last=False) + + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=True) + + +def get_ict_batch(data_iterator): + # Items and their type. + keys = ['query_tokens', 'query_mask', + 'context_tokens', 'context_mask', 'block_data'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is None: + data = None + else: + data = next(data_iterator) + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + query_tokens = data_b['query_tokens'].long() + query_mask = data_b['query_mask'] < 0.5 + context_tokens = data_b['context_tokens'].long() + context_mask = data_b['context_mask'] < 0.5 + block_indices = data_b['block_data'].long() + + return query_tokens, query_mask,\ + context_tokens, context_mask, block_indices + + +def join_str_list(str_list): + """Join a list of strings, handling spaces appropriately""" + result = "" + for s in str_list: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + return result + + +class BlockSampleData(object): + """A struct for fully describing a fixed-size block of data as used in REALM + + :param start_idx: for first sentence of the block + :param end_idx: for last sentence of the block (may be partially truncated in sample construction) + :param doc_idx: the index of the document from which the block comes in the original indexed dataset + :param block_idx: a unique integer identifier given to every block. + """ + def __init__(self, start_idx, end_idx, doc_idx, block_idx): + self.start_idx = start_idx + self.end_idx = end_idx + self.doc_idx = doc_idx + self.block_idx = block_idx + + def as_array(self): + return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64) + + def as_tuple(self): + return self.start_idx, self.end_idx, self.doc_idx, self.block_idx + + +class BlockSamplesMapping(object): + def __init__(self, mapping_array): + # make sure that the array is compatible with BlockSampleData + assert mapping_array.shape[1] == 4 + self.mapping_array = mapping_array + + def __len__(self): + return self.mapping_array.shape[0] + + def __getitem__(self, idx): + """Get the data associated with an indexed sample.""" + sample_data = BlockSampleData(*self.mapping_array[idx]) + return sample_data + + +def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False): + """Get samples mapping for a dataset over fixed size blocks. This function also requires + a dataset of the titles for the source documents since their lengths must be taken into account. + + :return: samples_mapping (BlockSamplesMapping) + """ + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{}s'.format(seed) + if use_one_sent_docs: + indexmap_filename += '_1sentok' + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if mpu.get_data_parallel_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert block_dataset.doc_idx.dtype == np.int64 + assert block_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(' > building samples index mapping for {} ...'.format( + name)) + + from megatron.data import helpers + mapping_array = helpers.build_blocks_mapping( + block_dataset.doc_idx, + block_dataset.sizes, + title_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length - 3, # account for added tokens + seed, + verbose, + use_one_sent_docs) + + + print_rank_0(' > done building samples index mapping') + np.save(indexmap_filename, mapping_array, allow_pickle=True) + print_rank_0(' > saved the index mapping in {}'.format( + indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0(' > elapsed time to build and save samples mapping ' + '(seconds): {:4f}'.format( + time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + + # Load indexed dataset. + print_rank_0(' > loading indexed mapping from {}'.format( + indexmap_filename)) + start_time = time.time() + + mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + samples_mapping = BlockSamplesMapping(mapping_array) + + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + mapping_array.shape[0])) + + return samples_mapping diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba4b98aa4a5c638605c904bfdcd68e2907f3d39 --- /dev/null +++ b/megatron/data/blendable_dataset.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time + +import numpy as np +import torch + +from megatron import print_rank_0 +from megatron import mpu + + +class BlendableDataset(torch.utils.data.Dataset): + + + def __init__(self, datasets, weights): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = 0 + for dataset in self.datasets: + self.size += len(dataset) + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indecies. + start_time = time.time() + assert num_datasets < 255 + self.dataset_index = np.zeros(self.size, dtype=np.uint8) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from megatron.data import helpers + helpers.build_blending_indices(self.dataset_index, + self.dataset_sample_index, + weights, num_datasets, self.size, + torch.distributed.get_rank() == 0) + print_rank_0('> elapsed time for building blendable dataset indices: ' + '{:.2f} (sec)'.format(time.time() - start_time)) + + + def __len__(self): + return self.size + + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..2efef42bf415caed48db18a431140fd54a625059 --- /dev/null +++ b/megatron/data/data_samplers.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataloaders.""" + + +import random +import torch +import numpy as np +from torch.utils.data import Dataset +from megatron import get_args +from megatron import mpu + + +def build_pretraining_data_loader(dataset, consumed_samples): + """Buld dataloader given an input dataset.""" + + if dataset is None: + return None + args = get_args() + + # Megatron sampler + if args.dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) + elif args.dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler( + dataset, + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + data_sharding=args.data_sharding) + else: + raise Exception('{} dataloader type is not supported.'.format( + args.dataloader_type)) + + # Torch dataloader. + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=args.num_workers, + pin_memory=True) + +class MegatronPretrainingSampler: + + def __init__(self, total_samples, consumed_samples, micro_batch_size, + data_parallel_rank, data_parallel_size, drop_last=True): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, \ + 'no samples left to consume: {}, {}'.format(self.consumed_samples, + self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class RandomSeedDataset(Dataset): + + def __init__(self, dataset): + args = get_args() + self.base_seed = args.seed + self.curr_seed = args.seed + self.dataset = dataset + + def __len__(self): + return len(self.dataset) + + def set_epoch(self, epoch): + self.curr_seed = self.base_seed + epoch + + def __getitem__(self, idx): + seed = idx + self.curr_seed + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + return self.dataset[idx] + + +class MegatronPretrainingRandomSampler: + + def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size, + data_parallel_rank, data_parallel_size, data_sharding): + # Keep a copy of input params for later use. + self.dataset = dataset + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.data_sharding = data_sharding + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.last_batch_size = \ + self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + if isinstance(self.dataset, RandomSeedDataset): + self.dataset.set_epoch(self.epoch) + + # data sharding and random sampling + if self.data_sharding: + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ + * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + else: + full_bucket_size = (self.total_samples // self.micro_batch_size) \ + * self.micro_batch_size + full_bucket_offset = current_epoch_samples + g = torch.Generator() + g.manual_seed(self.epoch) + idx_range_total = \ + torch.randperm(full_bucket_size, generator=g).tolist() + idx_range_active = idx_range_total[full_bucket_offset:] + idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..552782121059083a0ca226e222d0c29e3f342c1c --- /dev/null +++ b/megatron/data/dataset_utils.py @@ -0,0 +1,938 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, and NVIDIA. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Most of the code here has been copied from: +# https://github.com/google-research/albert/blob/master/create_pretraining_data.py +# with some modifications. + +import math +import os +import time +import collections + +import numpy as np +import torch +import random + +from megatron import ( + get_tokenizer, + get_args, + mpu, + print_rank_0 +) +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + +DSET_TYPE_BERT = 'standard_bert' +DSET_TYPE_ICT = 'ict' +DSET_TYPE_T5 = 't5' +DSET_TYPE_GLM = 'glm' + +DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_GLM] + + +def get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0]*num_datasets + prefixes = [0]*num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2*i]) + prefixes[i] = (data_prefix[2*i+1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [int(math.ceil(val * weight * 1.005)) + for val in train_valid_test_num_samples]) + + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + import os + import subprocess + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(['make', '-C', path]) + if ret.returncode != 0: + print("Making C++ dataset helpers module failed, exiting.") + import sys + sys.exit(1) + + +def get_a_and_b_segments(sample, np_rng): + """Divide sample into a and b segments.""" + + # Number of sentences in the sample. + n_sentences = len(sample) + # Make sure we always have two sentences. + assert n_sentences > 1, 'make sure each sample has at least two sentences.' + + # First part: + # `a_end` is how many sentences go into the `A`. + a_end = 1 + if n_sentences >= 3: + # Note that randin in numpy is exclusive. + a_end = np_rng.randint(1, n_sentences) + tokens_a = [] + for j in range(a_end): + tokens_a.extend(sample[j]) + + # Second part: + tokens_b = [] + for j in range(a_end, n_sentences): + tokens_b.extend(sample[j]) + + # Random next: + is_next_random = False + if np_rng.random() < 0.5: + is_next_random = True + tokens_a, tokens_b = tokens_b, tokens_a + + return tokens_a, tokens_b, is_next_random + + +def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): + """Truncates a pair of sequences to a maximum sequence length.""" + #print(len_a, len_b, max_num_tokens) + assert len_a > 0 + if len_a + len_b <= max_num_tokens: + return False + while len_a + len_b > max_num_tokens: + if len_a > len_b: + len_a -= 1 + tokens = tokens_a + else: + len_b -= 1 + tokens = tokens_b + if np_rng.random() < 0.5: + del tokens[0] + else: + tokens.pop() + return True + + +def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + tokentypes = [] + # [CLS]. + tokens.append(cls_id) + tokentypes.append(0) + # Segment A. + for token in tokens_a: + tokens.append(token) + tokentypes.append(0) + # [SEP]. + tokens.append(sep_id) + tokentypes.append(0) + # Segment B. + for token in tokens_b: + tokens.append(token) + tokentypes.append(1) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + tokentypes.append(1) + + return tokens, tokentypes + +def create_tokens(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + # [CLS]. + tokens.append(cls_id) + # Segment A. + for token in tokens_a: + tokens.append(token) + # [SEP]. + tokens.append(sep_id) + # Segment B. + for token in tokens_b: + tokens.append(token) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + + return tokens + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) + + +def is_start_piece(piece): + """Check if the current word piece is the starting piece (BERT).""" + # When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + return not piece.startswith("##") + + +def create_masked_lm_predictions(tokens, + vocab_id_list, vocab_id_to_token_dict, + masked_lm_prob, + cls_id, sep_id, mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False, + geometric_dist=False, + masking_style="bert"): + """Creates the predictions for the masked LM objective. + Note: Tokens here are vocab ids and not text tokens.""" + + cand_indexes = [] + # Note(mingdachen): We create a list for recording if the piece is + # the starting piece of current token, where 1 means true, so that + # on-the-fly whole word masking is possible. + token_boundary = [0] * len(tokens) + + for (i, token) in enumerate(tokens): + if token == cls_id or token == sep_id: + token_boundary[i] = 1 + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (do_whole_word_mask and len(cand_indexes) >= 1 and + not is_start_piece(vocab_id_to_token_dict[token])): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + if is_start_piece(vocab_id_to_token_dict[token]): + token_boundary[i] = 1 + + output_tokens = list(tokens) + + masked_lm_positions = [] + masked_lm_labels = [] + + if masked_lm_prob == 0: + return (output_tokens, masked_lm_positions, + masked_lm_labels, token_boundary) + + num_to_predict = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + + ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) + if not geometric_dist: + # Note(mingdachen): + # By default, we set the probilities to favor shorter ngram sequences. + pvals = 1. / np.arange(1, max_ngrams + 1) + pvals /= pvals.sum(keepdims=True) + if favor_longer_ngram: + pvals = pvals[::-1] + + ngram_indexes = [] + for idx in range(len(cand_indexes)): + ngram_index = [] + for n in ngrams: + ngram_index.append(cand_indexes[idx:idx + n]) + ngram_indexes.append(ngram_index) + + np_rng.shuffle(ngram_indexes) + + (masked_lms, masked_spans) = ([], []) + covered_indexes = set() + for cand_index_set in ngram_indexes: + if len(masked_lms) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes: + continue + + if not geometric_dist: + n = np_rng.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + else: + # Sampling "n" from the geometric distribution and clipping it to + # the max_ngrams. Using p=0.2 default from the SpanBERT paper + # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1) + n = min(np_rng.geometric(0.2), max_ngrams) + + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # Note(mingdachen): + # Repeatedly looking for a candidate that does not exceed the + # maximum number of predictions by trying shorter ngrams. + while len(masked_lms) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + masked_token = None + if masking_style == "bert": + # 80% of the time, replace with [MASK] + if np_rng.random() < 0.8: + masked_token = mask_id + else: + # 10% of the time, keep original + if np_rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] + elif masking_style == "t5": + masked_token = mask_id + else: + raise ValueError("invalid value of masking style") + + output_tokens[index] = masked_token + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + + masked_spans.append(MaskedLmInstance( + index=index_set, + label=[tokens[index] for index in index_set])) + + assert len(masked_lms) <= num_to_predict + np_rng.shuffle(ngram_indexes) + + select_indexes = set() + if do_permutation: + for cand_index_set in ngram_indexes: + if len(select_indexes) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes or index in select_indexes: + continue + + n = np.random.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + + while len(select_indexes) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(select_indexes) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes or index in select_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + select_indexes.add(index) + assert len(select_indexes) <= num_to_predict + + select_indexes = sorted(select_indexes) + permute_indexes = list(select_indexes) + np_rng.shuffle(permute_indexes) + orig_token = list(output_tokens) + + for src_i, tgt_i in zip(select_indexes, permute_indexes): + output_tokens[src_i] = orig_token[tgt_i] + masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + # Sort the spans by the index of the first span + masked_spans = sorted(masked_spans, key=lambda x: x.index[0]) + + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans) + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, short_seq_prob, seed, + skip_warmup, binary_head=False, + max_seq_length_dec=None, + dataset_type='standard_bert'): + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + short_seq_prob, seed, + skip_warmup, + binary_head, + max_seq_length_dec, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, short_seq_prob, + seed, skip_warmup, binary_head, max_seq_length_dec, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, short_seq_prob, seed, + skip_warmup, binary_head, + max_seq_length_dec, + dataset_type='standard_bert'): + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + title_dataset = get_indexed_dataset_(args.titles_data_path, + data_impl, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is desinged to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + print_rank_0(' sentence indices in [{}, {}) total of {} ' + 'sentences'.format(start_index, end_index, + end_index - start_index)) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + from megatron.data.bert_dataset import BertDataset + from megatron.data.ict_dataset import ICTDataset + from megatron.data.t5_dataset import T5Dataset + from megatron.data.glm_dataset import GlmDataset + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + ) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + dataset = ICTDataset( + block_dataset=indexed_dataset, + title_dataset=title_dataset, + query_in_block_prob=args.query_in_block_prob, + use_one_sent_docs=args.use_one_sent_docs, + binary_head=binary_head, + **kwargs + ) + elif dataset_type == DSET_TYPE_T5: + dataset = T5Dataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + max_seq_length_dec=max_seq_length_dec, + short_seq_prob=short_seq_prob, + **kwargs + ) + elif dataset_type == DSET_TYPE_BERT: + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + **kwargs + ) + elif dataset_type == DSET_TYPE_GLM: + dataset = GlmDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + **kwargs + ) + else: + raise NotImplementedError("Dataset type not fully implemented.") + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + + print_rank_0(' > building dataset index ...') + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + + print_rank_0(' > indexed dataset stats:') + print_rank_0(' number of documents: {}'.format( + indexed_dataset.doc_idx.shape[0] - 1)) + print_rank_0(' number of sentences: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """ Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(',') != -1: + splits = [float(s) for s in splits_string.split(',')] + elif splits_string.find('/') != -1: + splits = [float(s) for s in splits_string.split('/')] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + +def get_samples_mapping(indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + name, + binary_head): + """Get a list that maps a sample index to a starting sentence index, end sentence index, and length""" + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) + indexmap_filename += '_{}s'.format(seed) + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(' > building samples index mapping for {} ...'.format( + name)) + # First compile and then import. + from megatron.data import helpers + samples_mapping = helpers.build_mapping( + indexed_dataset.doc_idx, + indexed_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose, + 2 if binary_head else 1) + print_rank_0(' > done building samples index maping') + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + print_rank_0(' > saved the index mapping in {}'.format( + indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0(' > elasped time to build and save samples mapping ' + '(seconds): {:4f}'.format( + time.time() - start_time)) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load indexed dataset. + print_rank_0(' > loading indexed mapping from {}'.format( + indexmap_filename)) + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + samples_mapping.shape[0])) + + return samples_mapping + + +class MaskEncoder(object): + def __init__(self): + tokenizer = get_tokenizer() + self.vocab_size = tokenizer.vocab_size + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + import jieba_fast + self.zh_tokenizer = jieba_fast.lcut + self.random_ratio = 0 + + + def word_starts(self, source): + raw_tokens = [self.vocab_id_to_token_dict[i] for i in source.tolist()] + words = [raw_tokens[0]] + self.zh_tokenizer(''.join(raw_tokens[1:-1]), HMM=True) + [raw_tokens[-1]] + + def _is_chinese_char(c): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if len(c) > 1: + return all([_is_chinese_char(c_i) for c_i in c]) + cp = ord(c) + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def align_linear(atokens, btokens): + a2c = [] + c2b = [] + a2b = [] + length = 0 + for tok in atokens: + a2c.append([length + i for i in range(len(tok))]) + length += len(tok) + for i, tok in enumerate(btokens): + c2b.extend([i for _ in range(len(tok))]) + + for i, amap in enumerate(a2c): + bmap = [c2b[ci] for ci in amap] + a2b.append(list(set(bmap))) + return a2b + + raw_to_word_align = align_linear(raw_tokens, words) + is_word_start = torch.zeros(source.size()) + word_starts = [] + skip_cur_word = True + for i in range(1, len(raw_to_word_align)): + if raw_to_word_align[i-1] == raw_to_word_align[i]: + # not a word start, as they align to the same word + if not skip_cur_word and not _is_chinese_char(raw_tokens[i]): + word_starts.pop(-1) + skip_cur_word = True + continue + else: + is_word_start[i] = 1 + if _is_chinese_char(raw_tokens[i]): + word_starts.append(i) + skip_cur_word = False + is_word_start[0] = 0 + is_word_start[-1] = 0 + word_starts = torch.tensor(word_starts).long().view(-1, 1) + return is_word_start, word_starts + + def add_whole_word_mask(self, source, p, replace_length=1): + is_word_start, word_starts = self.word_starts(source) + num_to_mask_word = int(math.ceil(word_starts.size(0) * p)) + num_to_mask_char = int(math.ceil(word_starts.size(0) * p * 0.1)) + num_to_mask = num_to_mask_word + num_to_mask_char + if num_to_mask > word_starts.size(0): + word_starts = is_word_start.nonzero(as_tuple=False) + num_inserts = 0 + if num_to_mask == 0: + return source + + lengths = torch.ones((num_to_mask,)).long() + assert is_word_start[-1] == 0 + indices = word_starts[ + torch.randperm(word_starts.size(0))[:num_to_mask] + ].squeeze(1) + if len(indices) < num_to_mask: + num_to_mask = len(indices) + + mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio + source_length = source.size(0) + assert source_length - 1 not in indices + to_keep = torch.ones(source_length, dtype=torch.bool) + is_word_start[ + -1 + ] = 255 # acts as a long length, so spans don't go over the end of doc + if replace_length == 0: + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + # print(source.size(), word_starts.size(), indices.size(), mask_random.size()) + # try: + source[indices] = self.mask_id + source[indices[mask_random]] = torch.randint( + 1, self.vocab_size, size=(mask_random.sum(),) + ) + # except: + # print(source) + # print(indices) + # print(mask_random) + # print() + # sorted_indices = torch.sort(indices)[0] + # continue_mask_pos = ((sorted_indices + 1)[:-1] == sorted_indices[1:]) + # continue_mask_indices = sorted_indices[1:][continue_mask_pos] + # to_keep[continue_mask_indices] = 0 + + # for char indices, we already masked, the following loop handles word mask + indices = indices[:num_to_mask_word] + mask_random = mask_random[:num_to_mask_word] + while indices.size(0) > 0: + uncompleted = is_word_start[indices + 1] == 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + if replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_id + source[indices[mask_random]] = torch.randint( + 1, self.vocab_size, size=(mask_random.sum(),) + ) + + assert source_length - 1 not in indices + source = source[to_keep] + + return source + + def shif_chinese_word(self, tokens, tokens_bf_mask): + assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function' + buff_list = [] + buff_list_index = [] + for i in range(len(tokens)): + if tokens[i] == tokens_bf_mask[i]: + if len(buff_list) == 0: + continue + else: + if len(buff_list) != 1: + random.shuffle(buff_list) + tokens[buff_list_index[0] : buff_list_index[-1]+1] = buff_list + buff_list = [] + buff_list_index = [] + else: + buff_list.append(tokens_bf_mask[i]) + buff_list_index.append(i) + + return tokens + + def mass_style_mask(self, tokens): + tokens = tokens[:] + p = random.uniform(0.3, 0.5) + num_to_mask = int(len(tokens) * p) + start_index = int((1 - p) / 2 * len(tokens)) + tokens[start_index : start_index + num_to_mask] = [self.mask_id] * num_to_mask + + return tokens + + def delete_chinese_word(self, tokens, tokens_bf_mask): + return_tokens = [] + assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function' + for i in range(len(tokens)): + if tokens[i] == tokens_bf_mask[i]: + return_tokens.append(tokens[i]) + + return return_tokens \ No newline at end of file diff --git a/megatron/data/glm_dataset.py b/megatron/data/glm_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d2030888d85ebaa64b383ac364d039d5327f8af0 --- /dev/null +++ b/megatron/data/glm_dataset.py @@ -0,0 +1,377 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BERT Style dataset.""" + +import numpy as np +import torch + +from megatron import ( + get_args, + get_tokenizer, + mpu, + print_rank_0 +) +from megatron.data.dataset_utils import ( + get_samples_mapping, + get_a_and_b_segments, + truncate_segments, + create_tokens_and_tokentypes, + create_tokens, + create_masked_lm_predictions, + MaskEncoder +) + +class DummyBertDataset(torch.utils.data.Dataset): + def __init__(self, name, num_samples, max_seq_length): + self.name = name + self.num_samples = num_samples + self.max_seq_length = max_seq_length + self.np_rng = np.random.RandomState(seed=0) + # self.token_nps = np_rng.randint(1000, 2000, (self.num_samples, 512)) + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + tokens = self.np_rng.randint(1000, 2000, self.max_seq_length) + masked_position = np.arange(int(tokens.shape[0] * 0.15)) + tokens = tokens.astype(np.int64) + labels = tokens[masked_position] + label_np = np.full_like(tokens, -1) + label_np[masked_position] = labels + tokens[masked_position] = self.mask_id + loss_mask_np = np.zeros_like(tokens) + loss_mask_np[masked_position] = 1 + train_sample = { + 'text': tokens, + 'types': np.zeros_like(tokens), + 'labels': label_np, + 'is_random': 0, + 'loss_mask': loss_mask_np, + 'padding_mask': np.ones_like(tokens), + 'truncated': 0 + } + return train_sample + +class GlmDataset(torch.utils.data.Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, + num_epochs, max_num_samples, masked_lm_prob, + max_seq_length, short_seq_prob, seed, binary_head): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.binary_head = binary_head + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping(self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 3, # account for added tokens + short_seq_prob, + self.seed, + self.name, + self.binary_head) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + start_idx, end_idx, seq_length = self.samples_mapping[idx] + sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + # We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1 + np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) + return build_training_sample(sample, seq_length, + self.max_seq_length, # needed for padding + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, self.sep_id, + self.mask_id, self.pad_id, + self.masked_lm_prob, np_rng, + self.binary_head) + +def sent_level_task(binary_head, sample, target_seq_length, max_seq_length, np_rng): + if binary_head: + # We assume that we have at least two sentences in the sample + assert len(sample) > 1 + assert target_seq_length <= max_seq_length + # Divide sample into two segments (A and B). + if binary_head: + tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) + else: + tokens_a = [] + for j in range(len(sample)): + tokens_a.extend(sample[j]) + tokens_b = [] + is_next_random = False + # Truncate to `target_sequence_length`. + + max_num_tokens = target_seq_length + truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), + len(tokens_b), max_num_tokens, np_rng) + return is_next_random, truncated, max_num_tokens, tokens_a, tokens_b + +def generate_decoder_input_and_output(tokens, pad_id, sep_id): + """ + decoder input [SEP] [CSL] A B C D + decoder output [CLS] A B C D E + """ + + decoder_output = tokens[:] + decoder_input = [0] * len(decoder_output) + decoder_input[0] = sep_id # match the preprocessing in fairseq + # decoder_input[0] = sep_id # match the preprocessing in fairseq + decoder_input[1:] = decoder_output[:-1] + + """ + decoder input [CSL] A B C D [SEP] + decoder output A B C D [SEP] [PAD] + """ + + # decoder_input = tokens[:] + # decoder_output = [0] * len(decoder_input) + # decoder_output[:-1] = decoder_input[1:] + # decoder_output[-1] = pad_id + + return decoder_input, decoder_output + + + +def build_training_sample(sample, + target_seq_length, max_seq_length, + vocab_id_list, vocab_id_to_token_dict, + cls_id, sep_id, mask_id, pad_id, + masked_lm_prob, np_rng, binary_head): + + """ + sent-level task + """ + is_next_random, truncated, max_num_tokens, tokens_a, tokens_b = sent_level_task( + binary_head, sample, target_seq_length, max_seq_length, np_rng) + tokens_bf_mask = create_tokens(tokens_a, tokens_b, cls_id, sep_id) + if is_next_random: + raw_tokens = create_tokens(tokens_b, tokens_a, cls_id, sep_id) + else: + raw_tokens = tokens_bf_mask[:] + + """ + decoder-input and output + """ + decoder_input, decoder_output = generate_decoder_input_and_output(raw_tokens, pad_id, sep_id) + + # importance part + + encoder_loss_flag = 0 + decoder_loss_flag = 0 + sent_loss_flag = 1 + encoder_rng = torch.rand(1).item() + me = MaskEncoder() + if encoder_rng < 1.1: + # only train with encoder and decoder + # Masking. + if 0: + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, _, _, _, _) = create_masked_lm_predictions( + tokens_bf_mask, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, + cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, masking_style="t5") + if 1 : + tokens = torch.LongTensor(tokens_bf_mask) + tokens = me.add_whole_word_mask(tokens, 0.15, -1) + tokens = tokens.tolist() + shift_rng = torch.rand(1).item() + if shift_rng < 0.0: + tokens = me.shif_chinese_word(tokens, tokens_bf_mask) + encoder_loss_flag = 1 + decoder_loss_flag = 1 + else: + # train only with decoder + tokens = torch.LongTensor(tokens_bf_mask) + decoder_rng = torch.rand(1).item() + if decoder_rng < 0.4: + # WWM mask 30% word + tokens = me.add_whole_word_mask(tokens, 0.3, -1) + tokens = tokens.tolist() + if decoder_rng >= 0.4 and decoder_rng < 0.6: + # MASS mask style + tokens = me.mass_style_mask(tokens_bf_mask) + if decoder_rng > 0.6: + # delete tokens + tokens = me.add_whole_word_mask(tokens, 0.3, -1) + tokens = tokens.tolist() + tokens = me.delete_chinese_word(tokens, tokens_bf_mask) + tmp_tt = get_tokenizer() + # print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask))) + # print("encoder input ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens))) + # print("------\n\n") + + + decoder_loss_flag = 1 + + # tmp_tt = get_tokenizer() + # print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask))) + # print("encoder input ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens))) + # print("decoder input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_input))) + # print("decoder output", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_output))) + + tokentypes = [] + encoder_labels = [] + encoder_labels_mask = [] + padding_mask = [] + apppend_type_id = 0 + + if len(tokens) == len(tokens_bf_mask): + # encoder and decoder can train togather + for index in range(len(tokens)): + padding_mask.append(1) + # generate tokens type + if tokens[index] == sep_id: + apppend_type_id = 1 + tokentypes.append(apppend_type_id) + + if tokens[index] == tokens_bf_mask[index]: + encoder_labels.append(-1) + encoder_labels_mask.append(0) + else: + encoder_labels.append(tokens_bf_mask[index]) + encoder_labels_mask.append(1) + else: + # only train decoder + for index in range(len(tokens)): + padding_mask.append(1) + if tokens[index] == sep_id: + apppend_type_id = 1 + tokentypes.append(apppend_type_id) + encoder_labels.append(-1) + encoder_labels_mask.append(0) + + tokens_np = pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id) + tokentypes_np = pad_and_convert_to_numpy_light(tokentypes, max_seq_length, pad_id) + padding_mask_np = pad_and_convert_to_numpy_light(padding_mask, max_seq_length, pad_id) + encoder_labels_np = pad_and_convert_to_numpy_light(encoder_labels, max_seq_length, -1) + encoder_labels_mask_np = pad_and_convert_to_numpy_light(encoder_labels_mask, max_seq_length, pad_id) + decoder_input_np = pad_and_convert_to_numpy_light(decoder_input, max_seq_length, pad_id) + decoder_output_np = pad_and_convert_to_numpy_light(decoder_output, max_seq_length, pad_id) + + # print(tokens_np) + # print(encoder_labels_np) + # print(padding_mask_np) + # print(encoder_labels_mask_np) + + # generate tokentypes + train_sample = { + 'text': tokens_np, # encoder_input + 'types': tokentypes_np, # token_type + 'is_random': int(is_next_random), #sop_labels + 'truncated': int(truncated), # if truncated + 'labels': encoder_labels_np, #encoder_labels + 'loss_mask': encoder_labels_mask_np, # mlm_mask + 'padding_mask': padding_mask_np, # padding_mask + 'decoder_input': decoder_input_np, # decoder_input + 'decoder_output': decoder_output_np, #decoder_output + 'encoder_loss_flag': int(encoder_loss_flag), + 'decoder_loss_flag': int(decoder_loss_flag), + 'sent_loss_flag': int(sent_loss_flag), + } + + # print(tokens_np.shape) + # print(tokens_np) + + # print(tokentypes_np.shape) + # print(tokentypes_np) + + # print(encoder_labels_np.shape) + # print(encoder_labels_np) + + # print(encoder_labels_mask_np.shape) + # print(encoder_labels_mask_np) + + # print(padding_mask_np.shape) + # print(padding_mask_np) + + # print(decoder_input_np.shape) + # print(decoder_input_np) + + # print(decoder_output_np.shape) + # print(decoder_output_np) + + # print("=====\n\n\n") + # import sys;sys.exit(0) + return train_sample + +def pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id): + padding_length = max_seq_length - len(tokens) + assert padding_length >= 0 + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + return tokens_np + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c64e975d47244f07691bb611599d81cbbff9b2 --- /dev/null +++ b/megatron/data/gpt_dataset.py @@ -0,0 +1,430 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPT style dataset.""" + +import os +import time + +import numpy as np +import torch + +from megatron import mpu, print_rank_0 +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.dataset_utils import get_datasets_weights_and_num_samples +from megatron.data.dataset_utils import get_train_valid_test_split_ +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup): + """Build train, valid, and test datasets.""" + + # Single dataset. + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup) + + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, seed, skip_warmup) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup): + """Build train, valid, and test datasets.""" + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], + step=1, dtype=np.int32) + dataset = GPTDataset(name, data_prefix, + documents, indexed_dataset, + train_valid_test_num_samples[index], + seq_length, seed) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(' > building dataset index ...') + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup) + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + print_rank_0(' number of documents: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class GPTDataset(torch.utils.data.Dataset): + + def __init__(self, name, data_prefix, documents, indexed_dataset, + num_samples, seq_length, seed): + + self.name = name + self.indexed_dataset = indexed_dataset + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, data_prefix, documents, self.indexed_dataset.sizes, + num_samples, seq_length, seed) + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def __getitem__(self, idx): + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1) + else: + # Otherwise, get the rest of the initial document. + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], + offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + sample_list.append(self.indexed_dataset.get( + self.doc_idx[doc_index_l], + length=offset_l + 1)) + sample = np.concatenate(sample_list) + + return {'text': np.array(sample, dtype=np.int64)} + + +def _build_index_mappings(name, data_prefix, documents, sizes, + num_samples, seq_length, seed): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + _filename = data_prefix + _filename += '_{}_indexmap'.format(name) + _filename += '_{}ns'.format(num_samples) + _filename += '_{}sl'.format(seq_length) + _filename += '_{}s'.format(seed) + doc_idx_filename = _filename + '_doc_idx.npy' + sample_idx_filename = _filename + '_sample_idx.npy' + shuffle_idx_filename = _filename + '_shuffle_idx.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0: + if (not os.path.isfile(doc_idx_filename)) or \ + (not os.path.isfile(sample_idx_filename)) or \ + (not os.path.isfile(shuffle_idx_filename)): + + print_rank_0(' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(' > only one epoch required, setting ' + 'separate_last_epoch to False', flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ( + (num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - \ + num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, \ + 'last epoch number of samples should be non-negative.' + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples < (num_samples_per_epoch + 1), \ + 'last epoch number of samples exceeded max value.' + # If we have less than 80% of the samples for the last epoch, + # seperate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. + separate_last_epoch = (last_epoch_num_samples < + int(0.80 * num_samples_per_epoch)) + if separate_last_epoch: + string = ' > last epoch number of samples ({}) is smaller '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to True' + else: + string = ' > last epoch number of samples ({}) is larger '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to False' + print(string.format(last_epoch_num_samples, + num_samples_per_epoch), flush=True) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, + # num_epochs, tokens_per_epoch) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load mappings. + start_time = time.time() + print_rank_0(' > loading doc-idx mapping from {}'.format( + doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading sample-idx mapping from {}'.format( + sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading shuffle-idx mapping from {}'.format( + shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + sample_idx.shape[0])) + print_rank_0(' total number of epochs: {}'.format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence lenght, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-dcuments. + Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +def _build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Begining offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += (remaining_seq_length + doc_length - 1) + remaining_seq_length = 0 + else: + # Otherwise, start from the begining of the next document. + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(num_samples, total_size, np_rng): + """Build the range [0, size) and shuffle.""" + print(' > building shuffle index with split [0, {}) and [{}, {}) ' + '...'.format(num_samples, num_samples, total_size), flush=True) + + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) diff --git a/megatron/data/helpers.cpp b/megatron/data/helpers.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e45926a976961eb5094658ba478cb697a88c8000 --- /dev/null +++ b/megatron/data/helpers.cpp @@ -0,0 +1,717 @@ +/* + coding=utf-8 + Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/megatron/data/helpers.cpython-38-x86_64-linux-gnu.so b/megatron/data/helpers.cpython-38-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..0de5ce1f755a3443a1ea7b67fbd04995abe759f6 Binary files /dev/null and b/megatron/data/helpers.cpython-38-x86_64-linux-gnu.so differ diff --git a/megatron/data/helpers.cpython-39-x86_64-linux-gnu.so b/megatron/data/helpers.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..541cdec2936637c835c72d362ec35c590e91a051 Binary files /dev/null and b/megatron/data/helpers.cpython-39-x86_64-linux-gnu.so differ diff --git a/megatron/data/ict_dataset.py b/megatron/data/ict_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6dac35ff9d413898146df5c1cc8553719e142105 --- /dev/null +++ b/megatron/data/ict_dataset.py @@ -0,0 +1,156 @@ +import itertools +import random + +import numpy as np +from torch.utils.data import Dataset + +from megatron import get_tokenizer +from megatron import get_args +from megatron.data.dataset_utils import get_indexed_dataset_ +from megatron.data.realm_dataset_utils import get_block_samples_mapping + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + +def get_ict_dataset(use_titles=True, query_in_block_prob=1): + """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) + rather than for training, since it is only built with a single epoch sample mapping. + """ + args = get_args() + block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + + kwargs = dict( + name='full', + block_dataset=block_dataset, + title_dataset=titles_dataset, + data_prefix=args.data_path, + num_epochs=1, + max_num_samples=None, + max_seq_length=args.seq_length, + seed=1, + query_in_block_prob=query_in_block_prob, + use_titles=use_titles, + use_one_sent_docs=args.use_one_sent_docs + ) + dataset = ICTDataset(**kwargs) + return dataset + + +class ICTDataset(Dataset): + """Dataset containing sentences and their blocks for an inverse cloze task.""" + def __init__(self, name, block_dataset, title_dataset, data_prefix, + num_epochs, max_num_samples, max_seq_length, query_in_block_prob, + seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + self.name = name + self.seed = seed + self.max_seq_length = max_seq_length + self.query_in_block_prob = query_in_block_prob + self.block_dataset = block_dataset + self.title_dataset = title_dataset + self.rng = random.Random(self.seed) + self.use_titles = use_titles + self.use_one_sent_docs = use_one_sent_docs + + self.samples_mapping = get_block_samples_mapping( + block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + self.tokenizer = get_tokenizer() + self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_list = self.tokenizer.inv_vocab + self.cls_id = self.tokenizer.cls + self.sep_id = self.tokenizer.sep + self.mask_id = self.tokenizer.mask + self.pad_id = self.tokenizer.pad + + def __len__(self): + return len(self.samples_mapping) + + def __getitem__(self, idx): + """Get an ICT example of a pseudo-query and the block of text from which it was extracted""" + sample_data = self.samples_mapping[idx] + start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple() + + if self.use_titles: + title = self.title_dataset[int(doc_idx)] + title_pad_offset = 3 + len(title) + else: + title = None + title_pad_offset = 2 + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1 + + # randint() is inclusive for Python rng + rand_sent_idx = self.rng.randint(0, len(block) - 1) + + # keep the query in the context query_in_block_prob fraction of the time. + if self.rng.random() < self.query_in_block_prob: + query = block[rand_sent_idx].copy() + else: + query = block.pop(rand_sent_idx) + + # still need to truncate because blocks are concluded when + # the sentence lengths have exceeded max_seq_length. + query = query[:self.max_seq_length - 2] + block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + + query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) + context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) + + query_mask = make_attention_mask(query_tokens, query_tokens) + context_mask = make_attention_mask(context_tokens, context_tokens) + + block_data = sample_data.as_array() + + sample = { + 'query_tokens': query_tokens, + 'query_mask': query_mask, + 'query_pad_mask': query_pad_mask, + 'context_tokens': context_tokens, + 'context_mask': context_mask, + 'context_pad_mask': context_pad_mask, + 'block_data': block_data, + } + + return sample + + def get_block(self, start_idx, end_idx, doc_idx): + """Get the IDs for an evidence block plus the title of the corresponding document""" + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + title = self.title_dataset[int(doc_idx)] + + block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def get_null_block(self): + """Get empty block and title - used in REALM pretraining""" + block, title = [], [] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def concat_and_pad_tokens(self, tokens, title=None): + """Concat with special tokens and pad sequence to self.max_seq_length""" + tokens = list(tokens) + if title is None: + tokens = [self.cls_id] + tokens + [self.sep_id] + else: + title = list(title) + tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id] + assert len(tokens) <= self.max_seq_length + + num_pad = self.max_seq_length - len(tokens) + pad_mask = [1] * len(tokens) + [0] * num_pad + tokens += [self.pad_id] * num_pad + + return np.array(tokens), np.array(pad_mask) diff --git a/megatron/data/image_folder.py b/megatron/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..de15b29bf0665562a00bfcab8b106ff2d4ca26f2 --- /dev/null +++ b/megatron/data/image_folder.py @@ -0,0 +1,302 @@ +# BSD 3-Clause License +# +# Copyright (c) Soumith Chintala 2016, +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# code taken from +# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py +# added support for classes_fraction and data_per_class_fraction + +from torchvision.datasets import VisionDataset +from PIL import Image + +import os +import os.path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple +import numpy as np + +def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: + """Checks if a file is an allowed extension. + Args: + filename (string): path to a file + extensions (tuple of strings): extensions to consider (lowercase) + Returns: + bool: True if the filename ends with one of given extensions + """ + return filename.lower().endswith(extensions) + + +def is_image_file(filename: str) -> bool: + """Checks if a file is an allowed image extension. + Args: + filename (string): path to a file + Returns: + bool: True if the filename ends with a known image extension + """ + return has_file_allowed_extension(filename, IMG_EXTENSIONS) + + +def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + data_per_class_fraction: float, + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, +) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + Args: + directory (str): root dataset directory + class_to_idx (Dict[str, int]): dictionary mapping class name to class index + extensions (optional): A list of allowed extensions. + Either extensions or is_valid_file should be passed. Defaults to None. + is_valid_file (optional): A function that takes path of a file + and checks if the file is a valid file + (used to check of corrupt files) both extensions and + is_valid_file should not be passed. Defaults to None. + Raises: + ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. + Returns: + List[Tuple[str, int]]: samples of a form (path_to_sample, class) + """ + instances = [] + directory = os.path.expanduser(directory) + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + def is_valid_file(x: str) -> bool: + return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + is_valid_file = cast(Callable[[str], bool], is_valid_file) + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + local_instances = [] + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = path, class_index + local_instances.append(item) + + instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)]) + + return instances + + +class DatasetFolder(VisionDataset): + """A generic data loader where the samples are arranged in this way: :: + root/class_x/xxx.ext + root/class_x/xxy.ext + root/class_x/[...]/xxz.ext + root/class_y/123.ext + root/class_y/nsdf3.ext + root/class_y/[...]/asd932_.ext + Args: + root (string): Root directory path. + loader (callable): A function to load a sample given its path. + extensions (tuple[string]): A list of allowed extensions. + both extensions and is_valid_file should not be passed. + transform (callable, optional): A function/transform that takes in + a sample and returns a transformed version. + E.g, ``transforms.RandomCrop`` for images. + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. + is_valid_file (callable, optional): A function that takes path of a file + and check if the file is a valid file (used to check of corrupt files) + both extensions and is_valid_file should not be passed. + Attributes: + classes (list): List of the class names sorted alphabetically. + class_to_idx (dict): Dict with items (class_name, class_index). + samples (list): List of (sample path, class_index) tuples + targets (list): The class_index value for each image in the dataset + """ + + def __init__( + self, + root: str, + loader: Callable[[str], Any], + extensions: Optional[Tuple[str, ...]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + classes_fraction=1.0, + data_per_class_fraction=1.0, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> None: + super(DatasetFolder, self).__init__(root, transform=transform, + target_transform=target_transform) + self.classes_fraction = classes_fraction + self.data_per_class_fraction = data_per_class_fraction + classes, class_to_idx = self._find_classes(self.root) + samples = self.make_dataset(self.root, + class_to_idx, + self.data_per_class_fraction, + extensions, + is_valid_file) + if len(samples) == 0: + msg = "Found 0 files in subfolders of: {}\n".format(self.root) + if extensions is not None: + msg += "Supported extensions are: {}".format(",".join(extensions)) + raise RuntimeError(msg) + + self.loader = loader + self.extensions = extensions + self.total = len(samples) + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = samples + self.targets = [s[1] for s in samples] + + @staticmethod + def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + data_per_class_fraction: float, + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> List[Tuple[str, int]]: + return make_dataset(directory, + class_to_idx, + data_per_class_fraction, + extensions=extensions, + is_valid_file=is_valid_file) + + def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: + """ + Finds the class folders in a dataset. + Args: + dir (string): Root directory path. + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + Ensures: + No class is a subdirectory of another. + """ + all_classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes = all_classes[0:int(len(all_classes) * self.classes_fraction)] + classes.sort() + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + curr_index = index + for x in range(self.total): + try: + path, target = self.samples[curr_index] + sample = self.loader(path) + break + except Exception as e: + curr_index = np.random.randint(0, self.total) + + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self) -> int: + return len(self.samples) + + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + + +def pil_loader(path: str) -> Image.Image: + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +# TODO: specify the return type +def accimage_loader(path: str) -> Any: + import accimage + try: + return accimage.Image(path) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + return pil_loader(path) + + +def default_loader(path: str) -> Any: + from torchvision import get_image_backend + if get_image_backend() == 'accimage': + return accimage_loader(path) + else: + return pil_loader(path) + + +class ImageFolder(DatasetFolder): + """A generic data loader where the images are arranged in this way: :: + root/dog/xxx.png + root/dog/xxy.png + root/dog/[...]/xxz.png + root/cat/123.png + root/cat/nsdf3.png + root/cat/[...]/asd932_.png + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + is_valid_file (callable, optional): A function that takes path of an Image file + and check if the file is a valid file (used to check of corrupt files) + Attributes: + classes (list): List of the class names sorted alphabetically. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of (image path, class_index) tuples + """ + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + classes_fraction=1.0, + data_per_class_fraction=1.0, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + ): + super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, + transform=transform, + target_transform=target_transform, + classes_fraction=classes_fraction, + data_per_class_fraction=data_per_class_fraction, + is_valid_file=is_valid_file) + self.imgs = self.samples + diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6e1b845c8edfab063716562355baf25e8f646f --- /dev/null +++ b/megatron/data/indexed_dataset.py @@ -0,0 +1,576 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +from functools import lru_cache +import os +import shutil +import struct +from itertools import accumulate + +import numpy as np +import torch +from megatron import print_rank_0 + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float, + 7: np.double, + 8: np.uint16 +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. ' + 'Make sure that --dataset-impl is configured properly.' + ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return ( + os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx: ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx: ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float: 4, + np.double: 8 + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + doc_offset = len(self.sizes) + + begin = self.data_offsets[-1] + for data_offset in index.data_offsets[1:]: + self.data_offsets.append(begin + data_offset) + self.sizes.extend(index.sizes) + + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + self.doc_idx.extend((doc_offset + index.doc_idx)[1:]) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' max_seq_length - 1: + enc_ids = enc_ids[0: max_seq_length - 1] + tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] + + # [SEP]. + enc_ids.append(sep_id) + tokentypes_enc.append(0) + + num_tokens_enc = len(enc_ids) + # Padding. + padding_length = max_seq_length - len(enc_ids) + if padding_length > 0: + enc_ids.extend([pad_id] * padding_length) + tokentypes_enc.extend([pad_id] * padding_length) + + pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length) + pad_mask = np.array(pad_mask, dtype=np.int64) + + return enc_ids, tokentypes_enc, pad_mask + + +def build_sample(row_id, context_ids, context_types, context_pad_mask): + """Convert to numpy and return a sample consumed by the batch producer.""" + + context_ids = np.array(context_ids, dtype=np.int64) + context_types = np.array(context_types, dtype=np.int64) + context_mask = make_attention_mask(context_ids, context_ids) + + sample = ({ + 'row_id': row_id, + 'context': context_ids, + 'context_mask': context_mask, + 'context_types': context_types, + 'context_pad_mask': context_pad_mask + }) + return sample + + +class OpenRetrievalEvidenceDataset(ABC, Dataset): + """Open Retrieval Evidence dataset class.""" + + def __init__(self, task_name, dataset_name, datapath, tokenizer, + max_seq_length): + # Store inputs. + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + # Process the files. + print_rank_0(datapath) + self.samples, self.id2text = self.process_samples_from_single_path( + datapath) + + args = get_args() + if args.sample_rate < 1: # subsample + k = int(len(self.samples) * args.sample_rate) + self.samples = random.sample(self.samples, k) + + print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + row = self.samples[idx] + + context_ids, context_types, context_pad_mask = \ + build_tokens_types_paddings_from_text(row, self.tokenizer, + self.max_seq_length) + + sample = build_sample(row['doc_id'], + context_ids, + context_types, + context_pad_mask) + return sample + + @staticmethod + def process_samples_from_single_path(filename): + print_rank_0(' > Processing {} ...'.format(filename)) + total = 0 + + rows = [] + id2text = {} + + with open(filename) as tsvfile: + reader = csv.reader(tsvfile, delimiter='\t') + next(reader, None) # skip the headers + for row in reader: + # file format: doc_id, doc_text, title + doc_id = int(row[0]) + text = row[1] + title = row[2] + + rows.append({'doc_id': doc_id, + 'text': text, + 'title': title}) + + assert doc_id not in id2text + id2text[doc_id] = (text, title) + + total += 1 + if total % 100000 == 0: + print_rank_0(' > processed {} rows so far ...'.format( + total)) + + print_rank_0(' >> processed {} samples.'.format(len(rows))) + return rows, id2text diff --git a/megatron/data/realm_dataset_utils.py b/megatron/data/realm_dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aecf5549a73db3875432b8cc4cfebfc2b828b0b5 --- /dev/null +++ b/megatron/data/realm_dataset_utils.py @@ -0,0 +1,198 @@ +import os +import time + +import numpy as np +import torch + +from megatron import mpu, print_rank_0 +from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy +from megatron import get_args, get_tokenizer, print_rank_0, mpu + + +def get_one_epoch_dataloader(dataset, micro_batch_size=None): + """Specifically one epoch to be used in an indexing job.""" + args = get_args() + + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + if micro_batch_size is None: + micro_batch_size = args.micro_batch_size + global_batch_size = micro_batch_size * world_size + num_workers = args.num_workers + + sampler = torch.utils.data.SequentialSampler(dataset) + # importantly, drop_last must be False to get all the data. + assert False, 'DistributedBatchSampler deprecated, change the implementation' + from megatron.data.samplers import DistributedBatchSampler + batch_sampler = DistributedBatchSampler(sampler, + batch_size=global_batch_size, + drop_last=False, + rank=rank, + world_size=world_size) + + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=True) + + +def get_ict_batch(data_iterator): + # Items and their type. + keys = ['query_tokens', 'query_pad_mask', + 'block_tokens', 'block_pad_mask', 'block_data'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is None: + data = None + else: + data = next(data_iterator) + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + query_tokens = data_b['query_tokens'].long() + query_pad_mask = data_b['query_pad_mask'].long() + block_tokens = data_b['block_tokens'].long() + block_pad_mask = data_b['block_pad_mask'].long() + block_indices = data_b['block_data'].long() + + return query_tokens, query_pad_mask,\ + block_tokens, block_pad_mask, block_indices + + +def join_str_list(str_list): + """Join a list of strings, handling spaces appropriately""" + result = "" + for s in str_list: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + return result + + +class BlockSampleData(object): + """A struct for fully describing a fixed-size block of data as used in REALM + + :param start_idx: for first sentence of the block + :param end_idx: for last sentence of the block (may be partially truncated in sample construction) + :param doc_idx: the index of the document from which the block comes in the original indexed dataset + :param block_idx: a unique integer identifier given to every block. + """ + def __init__(self, start_idx, end_idx, doc_idx, block_idx): + self.start_idx = start_idx + self.end_idx = end_idx + self.doc_idx = doc_idx + self.block_idx = block_idx + + def as_array(self): + return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64) + + def as_tuple(self): + return self.start_idx, self.end_idx, self.doc_idx, self.block_idx + + +class BlockSamplesMapping(object): + def __init__(self, mapping_array): + # make sure that the array is compatible with BlockSampleData + assert mapping_array.shape[1] == 4 + self.mapping_array = mapping_array + + def __len__(self): + return self.mapping_array.shape[0] + + def __getitem__(self, idx): + """Get the data associated with an indexed sample.""" + sample_data = BlockSampleData(*self.mapping_array[idx]) + return sample_data + + +def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False): + """Get samples mapping for a dataset over fixed size blocks. This function also requires + a dataset of the titles for the source documents since their lengths must be taken into account. + + :return: samples_mapping (BlockSamplesMapping) + """ + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{}s'.format(seed) + if use_one_sent_docs: + indexmap_filename += '_1sentok' + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if mpu.get_data_parallel_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert block_dataset.doc_idx.dtype == np.int64 + assert block_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(' > building samples index mapping for {} ...'.format( + name)) + + from megatron.data import helpers + mapping_array = helpers.build_blocks_mapping( + block_dataset.doc_idx, + block_dataset.sizes, + title_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length - 3, # account for added tokens + seed, + verbose, + use_one_sent_docs) + + + print_rank_0(' > done building samples index mapping') + np.save(indexmap_filename, mapping_array, allow_pickle=True) + print_rank_0(' > saved the index mapping in {}'.format( + indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0(' > elapsed time to build and save samples mapping ' + '(seconds): {:4f}'.format( + time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + + # Load indexed dataset. + print_rank_0(' > loading indexed mapping from {}'.format( + indexmap_filename)) + start_time = time.time() + + mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + samples_mapping = BlockSamplesMapping(mapping_array) + + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + mapping_array.shape[0])) + + return samples_mapping diff --git a/megatron/data/realm_index.py b/megatron/data/realm_index.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b543c7e040088c7a4bc0e0beb85810d94b2781 --- /dev/null +++ b/megatron/data/realm_index.py @@ -0,0 +1,224 @@ +import itertools +import os +import pickle +import shutil + +import numpy as np +import torch + +from megatron import get_args +from megatron import mpu + + +def detach(tensor): + return tensor.detach().cpu().numpy() + + +class OpenRetreivalDataStore(object): + """ + Serializable data structure for holding data for blocks -- + embeddings and necessary metadata for Retriever + """ + def __init__(self, embedding_path=None, load_from_path=True, rank=None): + self.embed_data = dict() + if embedding_path is None: + args = get_args() + embedding_path = args.embedding_path + rank = args.rank + self.embedding_path = embedding_path + self.rank = rank + + if load_from_path: + self.load_from_file() + + block_data_name = os.path.splitext(self.embedding_path)[0] + self.temp_dir_name = block_data_name + '_tmp' + + def state(self): + return { + 'embed_data': self.embed_data, + } + + def clear(self): + """ + Clear the embedding data structures to save memory. + The metadata ends up getting used, and is also much smaller in + dimensionality so it isn't really worth clearing. + """ + self.embed_data = dict() + + def load_from_file(self): + """Populate members from instance saved to file""" + + if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: + print("\n> Unpickling BlockData", flush=True) + state_dict = pickle.load(open(self.embedding_path, 'rb')) + if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: + print(">> Finished unpickling BlockData\n", flush=True) + + self.embed_data = state_dict['embed_data'] + + def add_block_data(self, row_id, block_embeds, allow_overwrite=False): + """ + Add data for set of blocks + :param row_id: 1D array of unique int ids for the blocks + :param block_embeds: 2D array of embeddings of the blocks + In the case of retriever this will be [start_idx, end_idx, doc_idx] + """ + for idx, embed in zip(row_id, block_embeds): + if not allow_overwrite and idx in self.embed_data: + raise ValueError("Unexpectedly tried to overwrite block data") + + self.embed_data[idx] = np.float16(embed) + + def save_shard(self): + """ + Save the block data that was created this in this process + """ + if not os.path.isdir(self.temp_dir_name): + os.makedirs(self.temp_dir_name, exist_ok=True) + + # save the data for each shard + with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \ + as writer: + pickle.dump(self.state(), writer) + + def merge_shards_and_save(self): + #Combine all the shards made using save_shard + shard_names = os.listdir(self.temp_dir_name) + seen_own_shard = False + + for fname in os.listdir(self.temp_dir_name): + shard_rank = int(os.path.splitext(fname)[0]) + if shard_rank == self.rank: + seen_own_shard = True + continue + + with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f: + data = pickle.load(f) + old_size = len(self.embed_data) + shard_size = len(data['embed_data']) + + # add the shard's data and check to make sure there + # is no overlap + self.embed_data.update(data['embed_data']) + assert len(self.embed_data) == old_size + shard_size + + assert seen_own_shard + + # save the consolidated shards and remove temporary directory + with open(self.embedding_path, 'wb') as final_file: + pickle.dump(self.state(), final_file) + shutil.rmtree(self.temp_dir_name, ignore_errors=True) + + print("Finished merging {} shards for a total of {} embeds".format( + len(shard_names), len(self.embed_data)), flush=True) + + +class FaissMIPSIndex(object): + """ + Wrapper object for a BlockData which similarity search via FAISS under the hood + """ + def __init__(self, embed_size, embed_data=None, use_gpu=False): + self.embed_size = embed_size + self.embed_data = embed_data + self.use_gpu = use_gpu + + self.mips_index = None + self._set_mips_index() + + def _set_mips_index(self): + """ + Create a Faiss Flat index with inner product as the metric + to search against + """ + try: + import faiss + except ImportError: + raise Exception("Error: Please install faiss to use FaissMIPSIndex") + + if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: + print("\n> Building index", flush=True) + + cpu_index = faiss.IndexFlatIP(self.embed_size) + + if self.use_gpu: + # create resources and config for GpuIndex + config = faiss.GpuMultipleClonerOptions() + config.shard = True + config.useFloat16 = True + gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config) + self.mips_index = faiss.IndexIDMap(gpu_index) + if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: + print(">> Initialized index on GPU", flush=True) + else: + # CPU index supports IDs so wrap with IDMap + self.mips_index = faiss.IndexIDMap(cpu_index) + if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: + print(">> Initialized index on CPU", flush=True) + + # if we were constructed with a BlockData, then automatically load it + # when the FAISS structure is built + if self.embed_data is not None: + self.add_embed_data(self.embed_data) + + def reset_index(self): + """Delete existing index and create a new""" + del self.mips_index + + # reset the block data so that _set_block_index will reload it as well + if self.embed_data is not None: + embed_data_path = self.embed_data.embedding_path + del self.embed_data + self.embed_data = OpenRetreivalDataStore(embed_data_path) + + self._set_mips_index() + + def update_index(self): + """Delete existing index and create a new""" + del self.mips_index + + # reset the block data so that _set_mips_index will reload it as well + if self.embed_data is not None: + self.embed_data.load_from_file() + self._set_mips_index() + + def add_embed_data(self, all_embed_data): + """Add the embedding of each block to the underlying FAISS index""" + + # this assumes the embed_data is a dict : {int: np.array} + block_indices, block_embeds = zip(*all_embed_data.embed_data.items()) + + # the embeddings have to be entered in as float32 even though the math + # internally is done with float16. + embeds_arr = np.float32(np.array(block_embeds)) + indices_arr = np.array(block_indices) + + # we no longer need the embedding data since it's in the index now + all_embed_data.clear() + + self.mips_index.add_with_ids(embeds_arr, indices_arr) + + if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: + print(">>> Finished adding block data to index", flush=True) + + def search_mips_index(self, query_embeds, top_k, reconstruct=True): + """ + Get the top-k blocks by the index distance metric. + + :param reconstruct: if True: return a [num_queries x k x embed_dim] + array of blocks + if False: return [num_queries x k] array of + distances, and another for indices + """ + query_embeds = np.float32(detach(query_embeds)) + + if reconstruct: + # get the vectors themselves + top_k_block_embeds = self.mips_index.search_and_reconstruct(\ + query_embeds, top_k) + return top_k_block_embeds + else: + # get distances and indices of closest vectors + distances, block_indices = self.mips_index.search(query_embeds, top_k) + return distances, block_indices diff --git a/megatron/data/t5_dataset.py b/megatron/data/t5_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..42110b9239831f16ac6cc75e2092a37f760ebc77 --- /dev/null +++ b/megatron/data/t5_dataset.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5 Style dataset.""" + +import collections + +import numpy as np +import torch + +from megatron import get_tokenizer +from megatron.data.dataset_utils import ( + create_masked_lm_predictions, + get_samples_mapping +) + +class T5Dataset(torch.utils.data.Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, + num_epochs, max_num_samples, masked_lm_prob, + max_seq_length, max_seq_length_dec, + short_seq_prob, seed): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.max_seq_length_dec = max_seq_length_dec + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping(self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 2, # account for added tokens + short_seq_prob, + self.seed, + self.name, + False) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + self.bos_id = tokenizer.bos_token_id + self.eos_id = tokenizer.eos_token_id + self.sentinel_tokens = tokenizer.additional_special_tokens_ids + assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script" + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + + start_index, end_index, seq_length = self.samples_mapping[idx] + sample = [] + for index in range(start_index, end_index): + sample.append(self.indexed_dataset[index]) + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + np_rng = np.random.RandomState(seed=(self.seed + idx)) + return build_training_sample(sample, seq_length, + self.max_seq_length, # needed for padding + self.max_seq_length_dec, + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, self.sep_id, + self.mask_id, self.pad_id, + self.masked_lm_prob, np_rng, + self.bos_id, self.eos_id, + self.sentinel_tokens) + + +def build_training_sample(sample, target_seq_length, + max_seq_length, max_seq_length_dec, + vocab_id_list, vocab_id_to_token_dict, + cls_id, sep_id, mask_id, pad_id, + masked_lm_prob, np_rng, bos_id=None, + eos_id=None, sentinel_tokens=None): + """Build training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + bos_id: start of decoder example id + eos_id: end of generation id + sentinel_tokens: unique value to be substituted for every replaced span + """ + + assert target_seq_length <= max_seq_length + + # flatten sentences into one list + tokens = [token for sentence in sample for token in sentence] + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = len(tokens) > max_num_tokens + tokens = tokens[:max_num_tokens] + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions( + tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, + cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, + max_ngrams=10, geometric_dist=True, masking_style="t5") + + # Padding. + tokens_enc, tokens_dec_in, labels, enc_mask, \ + dec_mask, enc_dec_mask, loss_mask \ + = pad_and_convert_to_numpy(tokens, masked_positions, + masked_labels, pad_id, max_seq_length, + max_seq_length_dec, masked_spans, + bos_id, eos_id, sentinel_tokens) + + train_sample = { + 'text_enc': tokens_enc, + 'text_dec': tokens_dec_in, + 'labels': labels, + 'loss_mask': loss_mask, + 'truncated': int(truncated), + 'enc_mask': enc_mask, + 'dec_mask': dec_mask, + 'enc_dec_mask': enc_dec_mask, + } + return train_sample + + +def pad_and_convert_to_numpy(tokens, masked_positions, + masked_labels, pad_id, + max_seq_length, max_seq_length_dec, + masked_spans=None, bos_id=None, + eos_id=None, sentinel_tokens=None): + """Pad sequences and convert them to numpy.""" + + sentinel_tokens = collections.deque(sentinel_tokens) + t5_input = [] + (t5_decoder_in, t5_decoder_out) = ([bos_id], []) + (start_index, end_index) = (0, None) + for span in masked_spans: + flag = sentinel_tokens.popleft() + + # Append the same tokens in decoder input and output + t5_decoder_in.append(flag) + t5_decoder_in.extend(span.label) + t5_decoder_out.append(flag) + t5_decoder_out.extend(span.label) + + end_index = span.index[0] + t5_input.extend(tokens[start_index: end_index]) + t5_input.append(flag) + + # the next start index is the token after the last span token + start_index = span.index[-1] + 1 + + # Add token to the t5_decoder_out + t5_decoder_out.append(eos_id) + + # Add the remaining tokens to the t5 input + t5_input.extend(tokens[start_index:]) + + # assert (len(t5_input) - len(masked_spans)) + \ + # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens) + + # Some checks. + + # Encoder-side padding mask. + num_tokens = len(t5_input) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(masked_positions) == len(masked_labels) + + # Tokens.. + filler = [pad_id] * padding_length + tokens_enc = np.array(t5_input + filler, dtype=np.int64) + + # Decoder-side padding mask. + num_tokens_dec = len(t5_decoder_in) + padding_length_dec = max_seq_length_dec - num_tokens_dec + assert padding_length_dec >= 0 + filler_dec = [pad_id] * padding_length_dec + tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64) + + # Create attention masks + enc_mask = make_attention_mask(tokens_enc, tokens_enc) + enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc) + dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in) + dec_mask = dec_mask * make_history_mask(tokens_dec_in) + + # Labels mask. + labels = t5_decoder_out + ([-1] * padding_length_dec) + labels = np.array(labels, dtype=np.int64) + + # Loss mask + loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec) + loss_mask = np.array(loss_mask, dtype=np.int64) + + return tokens_enc, tokens_dec_in, labels, enc_mask, \ + dec_mask, enc_dec_mask, loss_mask + + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + + +def make_attention_mask_3d(source_block, target_block): + """ + Returns a 3-dimensional (3-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1) + # (batch, source_length, target_length) + # mask = mask.astype(np.int64) + return mask + + +def make_history_mask(block): + length = block.shape[0] + arange = np.arange(length) + history_mask = (arange[None, ] <= arange[:, None]) + history_mask = history_mask.astype(np.int64) + return history_mask + + +def make_history_mask_3d(block): + batch, length = block.shape + arange = torch.arange(length, device=block.device) + history_mask = (arange[None, ] <= arange[:, None])[None, ] + history_mask = history_mask.expand(batch, length, length) + return history_mask diff --git a/megatron/data/test/test_indexed_dataset.py b/megatron/data/test/test_indexed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..12fec8d8192d974d7016fa1c49dc799884deec76 --- /dev/null +++ b/megatron/data/test/test_indexed_dataset.py @@ -0,0 +1,125 @@ +# This file isn't really a formal automated test, it's just a place to +# put some code used during development and manual testing of +# indexed_dataset. + +from megatron.data import indexed_dataset +from megatron.tokenizer import build_tokenizer +import argparse +import os +import sys + +import torch + +script_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.join(script_dir, "../../../")) + + +def test_indexed_dataset(args): + ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) + tokenizer = build_tokenizer(args) + print(len(ds.doc_idx)) + print(len(ds)) + print(ds.doc_idx[-1]) + if ds.supports_prefetch: + # just prefetch the whole thing in test (so assume it is small) + ds.prefetch(range(len(ds))) + if args.count > len(ds.doc_idx) - 1: + args.count = len(ds.doc_idx) - 1 + + for i in range(args.count): + start = ds.doc_idx[i] + end = ds.doc_idx[i + 1] + ids = ds[start:end] + print(f"Document {i}:") + print("--------------") + for s in ids: + assert len(s) > 0 + l = s.data.tolist() + text = tokenizer.detokenize(l) + print(text) + print("---") + + +def test_indexed_dataset_get(args): + ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) + tokenizer = build_tokenizer(args) + size = ds.sizes[0] + print(f"size: {size}") + full = ds.get(0) + print(full) + # print(tokenizer.detokenize(full.data.tolist())) + print("---") + end = ds.get(0, offset=size - 10) + print(end) + # print(tokenizer.detokenize(end.data.tolist())) + + start = ds.get(0, length=10) + print(start) + # print(tokenizer.detokenize(start.data.tolist())) + + part = ds.get(0, offset=2, length=8) + print(part) + # print(tokenizer.detokenize(part.data.tolist())) + +# def test_albert_dataset(args): +# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) +# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) +# # ds = AlbertDataset(idataset, tokenizer) +# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, +# args.epochs, args.max_num_samples, +# args.masked_lm_prob, args.seq_length, +# args.short_seq_prob, args.seed) +# truncated = 0 +# total = 0 +# for i, s in enumerate(ds): +# ids = s['text'] +# tokens = ds.tokenizer.convert_ids_to_tokens(ids) +# print(tokens) +# if i >= args.count-1: +# exit() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, help='prefix to data files') + parser.add_argument('--dataset-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer']) + parser.add_argument('--count', type=int, default=10, + help='Number of samples/documents to print') + + group = parser.add_argument_group(title='tokenizer') + group.add_argument('--tokenizer-type', type=str, required=True, + choices=['BertWordPieceLowerCase', + 'GPT2BPETokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file (if necessary).') + + parser.add_argument('--epochs', type=int, default=5, + help='Number of epochs to plan for') + parser.add_argument('--max-num-samples', type=int, default=None, + help='Maximum number of samples to plan for') + parser.add_argument('--masked-lm-prob', type=float, default=0.15, + help='probability of masking tokens') + parser.add_argument('--seq-length', type=int, default=512, + help='maximum sequence length') + parser.add_argument('--short-seq-prob', type=float, default=0.1, + help='probability of creating a short sequence') + parser.add_argument('--seed', type=int, default=1234, + help='random seed') + args = parser.parse_args() + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + + if args.dataset_impl == "infer": + args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) + +# test_albert_dataset(args) + test_indexed_dataset_get(args) + + +if __name__ == "__main__": + main() diff --git a/megatron/data/test/test_preprocess_data.sh b/megatron/data/test/test_preprocess_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..d121c85958ff35e37431befdceabb831c8cd2705 --- /dev/null +++ b/megatron/data/test/test_preprocess_data.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +IMPL=cached +python ../preprocess_data.py \ + --input test_samples.json \ + --vocab vocab.txt \ + --dataset-impl ${IMPL} \ + --output-prefix test_samples_${IMPL} \ + --workers 1 \ + --log-interval 2 diff --git a/megatron/data/vit_dataset.py b/megatron/data/vit_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbd1ef56260b0f106e19c7a774c5ca254d1f59f --- /dev/null +++ b/megatron/data/vit_dataset.py @@ -0,0 +1,262 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random +import numpy as np +import torch +import torchvision.transforms as T +from torchvision import datasets +from megatron import get_args +from megatron.data.image_folder import ImageFolder +from megatron.data.autoaugment import ImageNetPolicy +from megatron.data.data_samplers import RandomSeedDataset +from PIL import Image, ImageFilter, ImageOps + + +class GaussianBlur(object): + """ + Apply Gaussian Blur to the PIL image. + """ + def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + do_it = random.random() <= self.prob + if not do_it: + return img + + return img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ) + + +class Solarization(object): + """ + Apply Solarization to the PIL image. + """ + def __init__(self, p): + self.p = p + + def __call__(self, img): + if random.random() < self.p: + return ImageOps.solarize(img) + else: + return img + + +class ClassificationTransform(): + def __init__(self, image_size, train=True): + args = get_args() + assert args.fp16 or args.bf16 + self.data_type = torch.half if args.fp16 else torch.bfloat16 + if train: + self.transform = T.Compose([ + T.RandomResizedCrop(image_size), + T.RandomHorizontalFlip(), + T.ColorJitter(0.4, 0.4, 0.4, 0.1), + ImageNetPolicy(), + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.ConvertImageDtype(self.data_type) + ]) + else: + self.transform = T.Compose([ + T.Resize(image_size), + T.CenterCrop(image_size), + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.ConvertImageDtype(self.data_type) + ]) + + def __call__(self, input): + output = self.transform(input) + return output + + +class InpaintingTransform(): + def __init__(self, image_size, train=True): + + args = get_args() + self.mask_factor = args.mask_factor + self.mask_type = args.mask_type + self.image_size = image_size + self.patch_size = args.patch_dim + self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size)) + self.train = train + assert args.fp16 or args.bf16 + self.data_type = torch.half if args.fp16 else torch.bfloat16 + + if self.train: + self.transform = T.Compose([ + T.RandomResizedCrop(self.image_size), + T.RandomHorizontalFlip(), + T.ColorJitter(0.4, 0.4, 0.4, 0.1), + ImageNetPolicy(), + T.ToTensor(), + T.ConvertImageDtype(self.data_type) + ]) + else: + self.transform = T.Compose([ + T.Resize(self.image_size, interpolation=2), + T.CenterCrop(self.image_size), + T.ToTensor(), + T.ConvertImageDtype(self.data_type) + ]) + + def gen_mask(self, image_size, mask_size, mask_type, patch_size): + # output: mask as a list with indices for missing patches + action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]] + assert image_size[0] == image_size[1] + img_size_patch = image_size[0] // patch_size + + # drop masked patches + mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float) + + if mask_type == 'random': + x = torch.randint(0, img_size_patch, ()) + y = torch.randint(0, img_size_patch, ()) + for i in range(mask_size): + r = torch.randint(0, len(action_list), ()) + x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1) + y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1) + x_offset = x * patch_size + y_offset = y * patch_size + mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1 + else: + assert mask_type == 'row' + count = 0 + for x in reversed(range(img_size_patch)): + for y in reversed(range(img_size_patch)): + if (count < mask_size): + count += 1 + x_offset = x * patch_size + y_offset = y * patch_size + mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1 + return mask + + def __call__(self, input): + trans_input = self.transform(input) + mask = self.gen_mask(self.image_size, self.mask_size, + self.mask_type, self.patch_size) + mask = mask.unsqueeze(dim=0) + return trans_input, mask + + +class DinoTransform(object): + def __init__(self, image_size, train=True): + args = get_args() + self.data_type = torch.half if args.fp16 else torch.bfloat16 + + flip_and_color_jitter = T.Compose([ + T.RandomHorizontalFlip(p=0.5), + T.RandomApply( + [T.ColorJitter(brightness=0.4, contrast=0.4, + saturation=0.2, hue=0.1)], + p=0.8 + ), + T.RandomGrayscale(p=0.2), + ]) + + if args.fp16 or args.bf16: + normalize = T.Compose([ + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.ConvertImageDtype(self.data_type) + ]) + else: + normalize = T.Compose([ + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + # first global crop + scale_const = 0.4 + self.global_transform1 = T.Compose([ + T.RandomResizedCrop(image_size, + scale=(scale_const, 1), + interpolation=Image.BICUBIC), + flip_and_color_jitter, + GaussianBlur(1.0), + normalize + ]) + # second global crop + self.global_transform2 = T.Compose([ + T.RandomResizedCrop(image_size, + scale=(scale_const, 1), + interpolation=Image.BICUBIC), + flip_and_color_jitter, + GaussianBlur(0.1), + Solarization(0.2), + normalize + ]) + # transformation for the local small crops + self.local_crops_number = args.dino_local_crops_number + self.local_transform = T.Compose([ + T.RandomResizedCrop(args.dino_local_img_size, + scale=(0.05, scale_const), + interpolation=Image.BICUBIC), + flip_and_color_jitter, + GaussianBlur(p=0.5), + normalize + ]) + + def __call__(self, image): + crops = [] + crops.append(self.global_transform1(image)) + crops.append(self.global_transform2(image)) + for _ in range(self.local_crops_number): + crops.append(self.local_transform(image)) + return crops + + +def build_train_valid_datasets(data_path, image_size=224): + args = get_args() + + if args.vision_pretraining_type == 'classify': + train_transform = ClassificationTransform(image_size) + val_transform = ClassificationTransform(image_size, train=False) + elif args.vision_pretraining_type == 'inpaint': + train_transform = InpaintingTransform(image_size, train=False) + val_transform = InpaintingTransform(image_size, train=False) + elif args.vision_pretraining_type == 'dino': + train_transform = DinoTransform(image_size, train=True) + val_transform = ClassificationTransform(image_size, train=False) + else: + raise Exception('{} vit pretraining type is not supported.'.format( + args.vit_pretraining_type)) + + # training dataset + train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2] + train_data = ImageFolder( + root=train_data_path, + transform=train_transform, + classes_fraction=args.classes_fraction, + data_per_class_fraction=args.data_per_class_fraction + ) + train_data = RandomSeedDataset(train_data) + + # validation dataset + val_data_path = data_path[1] + val_data = ImageFolder( + root=val_data_path, + transform=val_transform + ) + val_data = RandomSeedDataset(val_data) + + return train_data, val_data diff --git a/megatron/dist_signal_handler.py b/megatron/dist_signal_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a60204f004a3f149da0bc059cd875b6ec390c0c4 --- /dev/null +++ b/megatron/dist_signal_handler.py @@ -0,0 +1,81 @@ +import signal + +import torch + + +def get_world_size(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + return world_size + + +def get_device(local_rank=None): + backend = torch.distributed.get_backend() + if backend == 'nccl': + if local_rank is None: + device = torch.device('cuda') + else: + device = torch.device(f'cuda:{local_rank}') + elif backend == 'gloo': + device = torch.device('cpu') + else: + raise RuntimeError + return device + + +def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None): + if not torch.distributed.is_available() or \ + not torch.distributed.is_initialized(): + return [item] + + device = get_device(local_rank) + + if group is not None: + group_size = group.size() + else: + group_size = get_world_size() + + tensor = torch.tensor([item], device=device, dtype=dtype) + output_tensors = [ + torch.zeros(1, dtype=tensor.dtype, device=tensor.device) + for _ in range(group_size) + ] + torch.distributed.all_gather(output_tensors, tensor, group, async_op) + output = [elem.item() for elem in output_tensors] + return output + + +class DistributedSignalHandler: + def __init__(self, sig=signal.SIGTERM): + self.sig = sig + + def signals_received(self): + all_received = all_gather_item( + self._signal_received, dtype=torch.int32 + ) + return all_received + + def __enter__(self): + self._signal_received = False + self.released = False + self.original_handler = signal.getsignal(self.sig) + + def handler(signum, frame): + self._signal_received = True + + signal.signal(self.sig, handler) + + return self + + def __exit__(self, type, value, tb): + self.release() + + def release(self): + if self.released: + return False + + signal.signal(self.sig, self.original_handler) + self.released = True + return True diff --git a/megatron/fp16_deprecated/loss_scaler.py b/megatron/fp16_deprecated/loss_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6f5e3af1e565db583c36ba8151a6a212caee46 --- /dev/null +++ b/megatron/fp16_deprecated/loss_scaler.py @@ -0,0 +1,39 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""For backward compatibility, we need the class definitions to deserialize.""" + +class LossScaler: + def __init__(self, scale=1): + self.cur_scale = scale + +class DynamicLossScaler: + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000, + min_scale=1, + delayed_shift=1, + consecutive_hysteresis=False): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + self.min_scale = min_scale + self.delayed_shift = delayed_shift + self.cur_hysteresis = delayed_shift + self.consecutive_hysteresis = consecutive_hysteresis + diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d063e6686aafb90ab83adcd73f5fa25a56820a9 --- /dev/null +++ b/megatron/fused_kernels/__init__.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +import subprocess + +from torch.utils import cpp_extension + +# Setting this param to a list has a problem of generating different +# compilation commands (with diferent order of architectures) and +# leading to recompilation of fused kernels. Set it to empty string +# to avoid recompilation and assign arch flags explicity in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load(args): + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, _ = _get_cuda_bare_metal_version( + cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / 'build' + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3',], + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag, + verbose=(args.rank == 0) + ) + + # ============== + # Fused softmax. + # ============== + + if args.masked_softmax_fusion: + extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] + + # Upper triangular softmax. + sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', + srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] + scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_upper_triang_masked_softmax_cuda", + sources, extra_cuda_flags) + + # Masked softmax. + sources=[srcpath / 'scaled_masked_softmax.cpp', + srcpath / 'scaled_masked_softmax_cuda.cu'] + scaled_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_masked_softmax_cuda", sources, extra_cuda_flags) + + # Softmax + sources=[srcpath / 'scaled_softmax.cpp', + srcpath / 'scaled_softmax_cuda.cu'] + scaled_softmax_cuda = _cpp_extention_load_helper( + "scaled_softmax_cuda", sources, extra_cuda_flags) + + # ================================= + # Mixed precision fused layer norm. + # ================================= + + extra_cuda_flags = ['-maxrregcount=50'] + sources=[srcpath / 'layer_norm_cuda.cpp', + srcpath / 'layer_norm_cuda_kernel.cu'] + fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper( + "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags) + + # ================================= + # Fused gradient accumulation to weight gradient computation of linear layer + # ================================= + + if args.gradient_accumulation_fusion: + sources=[srcpath / 'fused_weight_gradient_dense.cpp', + srcpath / 'fused_weight_gradient_dense.cu'] + fused_dense_cuda = _cpp_extention_load_helper( + "fused_dense_cuda", sources, []) + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/megatron/fused_kernels/build/.ninja_deps b/megatron/fused_kernels/build/.ninja_deps new file mode 100644 index 0000000000000000000000000000000000000000..779cf1abf49115f00ed9a9770fbc30f6c2d5d97c Binary files /dev/null and b/megatron/fused_kernels/build/.ninja_deps differ diff --git a/megatron/fused_kernels/build/.ninja_log b/megatron/fused_kernels/build/.ninja_log new file mode 100644 index 0000000000000000000000000000000000000000..3c84b93455e7b7490147e1297b3ef1d84fdf6e45 --- /dev/null +++ b/megatron/fused_kernels/build/.ninja_log @@ -0,0 +1,99 @@ +# ninja log v5 +0 10410 1666322917541195629 scaled_upper_triang_masked_softmax.o 3f7d6908014e2c4b +0 44996 1666322952117194853 scaled_upper_triang_masked_softmax_cuda.cuda.o 63e21f5b42e40bd0 +44999 45365 1666322952493194844 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5 +1 10116 1666322962741194614 scaled_masked_softmax.o 9dfad674b44501f2 +1 45749 1666322998365193814 scaled_masked_softmax_cuda.cuda.o aec0e3e7e0fe0af5 +45749 46118 1666322998741193805 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5 +0 10316 1666323009189193571 scaled_softmax.o ba78446f9188fae0 +0 44538 1666323043401192802 scaled_softmax_cuda.cuda.o 77ea362b8652c7e9 +44538 44904 1666323043777192794 scaled_softmax_cuda.so d8fa0ebfd78e8bd8 +0 10918 1666323054829192545 layer_norm_cuda.o 7dc5869ac5593422 +0 11891 1666323055797192524 layer_norm_cuda_kernel.cuda.o 13d0d213fbbb62de +11891 12255 1666323056165192515 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401 +0 10072 1666682710301113263 scaled_upper_triang_masked_softmax.o c11d897e7800befb +0 46206 1666682746425112452 scaled_upper_triang_masked_softmax_cuda.cuda.o bc610e36d8dfd435 +46206 46587 1666682746813112443 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5 +0 9858 1666682756829112218 scaled_masked_softmax.o fedebad209ed2d21 +0 46362 1666682793321111399 scaled_masked_softmax_cuda.cuda.o 51814239e7caea9a +46362 46747 1666682793717111390 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5 +0 9870 1666682803741111164 scaled_softmax.o 1d9e3231fe352c0b +0 46512 1666682840373110342 scaled_softmax_cuda.cuda.o f9b5a976cff0a5ef +46513 46900 1666682840769110333 scaled_softmax_cuda.so d8fa0ebfd78e8bd8 +0 10615 1666682851533110091 layer_norm_cuda.o 3cacb26d8faa2b99 +0 11849 1666682852761110063 layer_norm_cuda_kernel.cuda.o 319de99ce0920143 +11849 12230 1666682853145110055 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401 +0 12428 1666750089507718000 scaled_upper_triang_masked_softmax.o 8e61e453c7b77ff5 +0 46226 1666750123300534000 scaled_upper_triang_masked_softmax_cuda.cuda.o 193ac2a539f3f292 +46228 47144 1666750124224556000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4 +0 11486 1666750135832836000 scaled_masked_softmax.o b6378099b518f069 +0 47221 1666750171561699000 scaled_masked_softmax_cuda.cuda.o 4faef3fa30fe1e1d +47223 48124 1666750172469721000 scaled_masked_softmax_cuda.so d6611febaa933d3d +0 11564 1666750184410010000 scaled_softmax.o a90db6c821074406 +0 46461 1666750219302852000 scaled_softmax_cuda.cuda.o bf0ec8bfec64157c +46464 47488 1666750220334877000 scaled_softmax_cuda.so e7199387ed26e64e +0 12007 1666750232439170000 layer_norm_cuda.o 6a55ca87d1a8c0b2 +0 12020 1666750232447170000 layer_norm_cuda_kernel.cuda.o 655c31ba3cbc10c2 +12022 13866 1666750234299214000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da +0 10929 1666856346511840836 scaled_upper_triang_masked_softmax.o 5a6ff0631bbc2735 +0 43626 1666856379199840197 scaled_upper_triang_masked_softmax_cuda.cuda.o c36b22f9d9fc2117 +43626 44026 1666856379607840189 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5 +0 10190 1666856389955839987 scaled_masked_softmax.o e032e1ed28e01e30 +0 44582 1666856424331839315 scaled_masked_softmax_cuda.cuda.o 2c8db7df38489475 +44582 44961 1666856424723839308 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5 +0 10142 1666856435015839107 scaled_softmax.o 446947b66b18fa33 +0 44480 1666856469343838436 scaled_softmax_cuda.cuda.o 4fb07733497ecc29 +44480 44879 1666856469751838428 scaled_softmax_cuda.so d8fa0ebfd78e8bd8 +0 10396 1666856480295838222 layer_norm_cuda.o 73da6101a07a24a7 +1 11899 1666856481791838193 layer_norm_cuda_kernel.cuda.o 9ec8eab79e592ff4 +11899 12298 1666856482195838185 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401 +1 12100 1666925285098117232 scaled_upper_triang_masked_softmax.o f73d4cea858af8b1 +2 45529 1666925318557711559 scaled_upper_triang_masked_softmax_cuda.cuda.o 2fa8c20456ca471c +45531 47926 1666925320961971387 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4 +1 12395 1666925333323307273 scaled_masked_softmax.o ae4878f941317da4 +1 46831 1666925368027057696 scaled_masked_softmax_cuda.cuda.o 94fda8fac85d606d +46833 48525 1666925369731241867 scaled_masked_softmax_cuda.so d6611febaa933d3d +1 12276 1666925382212590722 scaled_softmax.o 7a00c61166684714 +1 47263 1666925417188370543 scaled_softmax_cuda.cuda.o dfe840df0dc3178d +47265 50987 1666925420916773471 scaled_softmax_cuda.so e7199387ed26e64e +1 13036 1666925434150203603 layer_norm_cuda_kernel.cuda.o 128560bba544b6cb +1 14561 1666925435354333732 layer_norm_cuda.o e02c8859a84e70db +14569 15455 1666925436574465592 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da +1 62343 1666962134003529000 scaled_upper_triang_masked_softmax_cuda.cuda.o abba0fca57f22344 +62363 63833 1666962135511529000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4 +0 12215 1667460191194232000 scaled_upper_triang_masked_softmax.o 96d879ae2bf7b993 +1 49211 1667460228190231000 scaled_upper_triang_masked_softmax_cuda.cuda.o bc4b370b3c3d5c9e +49213 50896 1667460229886231000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4 +0 13297 1667460243334231000 scaled_masked_softmax.o a6809b1177c7ca02 +1 50055 1667460280082230000 scaled_masked_softmax_cuda.cuda.o dfbe364852f092fc +50057 65422 1667460295430230000 scaled_masked_softmax_cuda.so d6611febaa933d3d +1 12055 1667460307682230000 scaled_softmax.o cd4f40829964c3cb +1 48489 1667460344126229000 scaled_softmax_cuda.cuda.o a6917d3b3ea80f97 +48526 49856 1667460345502229000 scaled_softmax_cuda.so e7199387ed26e64e +0 13966 1667460359626229000 layer_norm_cuda.o e644ccb47b3615c +1 15506 1667460361158229000 layer_norm_cuda_kernel.cuda.o 2d32cb24bea852c7 +15509 40152 1667460385810228000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da +0 11006 1676876089228927885 scaled_upper_triang_masked_softmax.o 9c8f7b7399ab2d1f +0 42047 1676876120260456898 scaled_upper_triang_masked_softmax_cuda.cuda.o 12331cadf47cb899 +42047 42264 1676876120488482829 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4 +1 10823 1676876131453729835 scaled_masked_softmax.o 6e7c1e4df8bc11a8 +1 47902 1676876168521945360 scaled_masked_softmax_cuda.cuda.o c86371ad5d3e19ff +47902 48087 1676876168713967197 scaled_masked_softmax_cuda.so d6611febaa933d3d +0 13926 1676876182787567696 scaled_softmax.o eeaf8300d7ba52f3 +0 42044 1676876210894764143 scaled_softmax_cuda.cuda.o 379540e7c9ee343a +42044 42242 1676876211098787345 scaled_softmax_cuda.so e7199387ed26e64e +0 11125 1676876222348066652 layer_norm_cuda_kernel.cuda.o 5fa2a5b112be408c +0 11375 1676876222600095314 layer_norm_cuda.o e369f8c0d20bc213 +11375 11572 1676876222800118061 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da +2 12764 1682404121987049000 scaled_upper_triang_masked_softmax.o 20d2ceb970d5c522 +3 48533 1682404157745832000 scaled_upper_triang_masked_softmax_cuda.cuda.o e3bb3890927e826b +48592 49557 1682404158781913000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4 +2 13484 1682404172586987000 scaled_masked_softmax.o 8b2e2a7ca1fd841b +2 48141 1682404207237685000 scaled_masked_softmax_cuda.cuda.o 6215ee08e9bfd383 +48143 49151 1682404208257765000 scaled_masked_softmax_cuda.so d6611febaa933d3d +2 12243 1682404220754738000 scaled_softmax.o 2842472d594d0d1d +2 51180 1682404259661769000 scaled_softmax_cuda.cuda.o 78419d022eea9ad2 +51184 52124 1682404260637845000 scaled_softmax_cuda.so e7199387ed26e64e +2 12762 1682404273698863000 layer_norm_cuda_kernel.cuda.o dec411c038eb6254 +1 13361 1682404274238905000 layer_norm_cuda.o 14db0d087e6f7321 +13415 14196 1682404275138975000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da diff --git a/megatron/fused_kernels/build/build.ninja b/megatron/fused_kernels/build/build.ninja new file mode 100644 index 0000000000000000000000000000000000000000..151e14747301367fa7b31ec93c884aa5a79d763f --- /dev/null +++ b/megatron/fused_kernels/build/build.ninja @@ -0,0 +1,28 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda/bin/nvcc + +cflags = -DTORCH_EXTENSION_NAME=fused_mix_prec_layer_norm_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++14 -O3 +post_cflags = +cuda_cflags = -DTORCH_EXTENSION_NAME=fused_mix_prec_layer_norm_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -maxrregcount=50 -gencode arch=compute_80,code=sm_80 -std=c++14 +cuda_post_cflags = +ldflags = -shared -L/opt/conda/lib/python3.8/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags + +rule link + command = $cxx $in $ldflags -o $out + +build layer_norm_cuda.o: compile /root/ouyangxuan/project/big_model_finetune/Megatrion-LM-clear/megatron/fused_kernels/layer_norm_cuda.cpp +build layer_norm_cuda_kernel.cuda.o: cuda_compile /root/ouyangxuan/project/big_model_finetune/Megatrion-LM-clear/megatron/fused_kernels/layer_norm_cuda_kernel.cu + +build fused_mix_prec_layer_norm_cuda.so: link layer_norm_cuda.o layer_norm_cuda_kernel.cuda.o + +default fused_mix_prec_layer_norm_cuda.so + diff --git a/megatron/fused_kernels/build/fused_mix_prec_layer_norm_cuda.so b/megatron/fused_kernels/build/fused_mix_prec_layer_norm_cuda.so new file mode 100644 index 0000000000000000000000000000000000000000..1f60d28e71d491526983323f2b9f43a8fb95e016 Binary files /dev/null and b/megatron/fused_kernels/build/fused_mix_prec_layer_norm_cuda.so differ diff --git a/megatron/fused_kernels/build/layer_norm_cuda.o b/megatron/fused_kernels/build/layer_norm_cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..1449ace91b68002e17700f8ec3b5c6a54ece69af Binary files /dev/null and b/megatron/fused_kernels/build/layer_norm_cuda.o differ diff --git a/megatron/fused_kernels/build/layer_norm_cuda_kernel.cuda.o b/megatron/fused_kernels/build/layer_norm_cuda_kernel.cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..8eb43c673be17ab3044f8abbf643950558de268c Binary files /dev/null and b/megatron/fused_kernels/build/layer_norm_cuda_kernel.cuda.o differ diff --git a/megatron/fused_kernels/build/scaled_masked_softmax.o b/megatron/fused_kernels/build/scaled_masked_softmax.o new file mode 100644 index 0000000000000000000000000000000000000000..0b4a21081239561958ea7a63c4d367651fdeabca Binary files /dev/null and b/megatron/fused_kernels/build/scaled_masked_softmax.o differ diff --git a/megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o b/megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..1f5d46a051035fc48cf1d86f677bea816b38d1f6 --- /dev/null +++ b/megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23117ca78a4427a192781d5bff9b2ddc34711ec3c57f6bd5c7c4b7d3b634e429 +size 1196624 diff --git a/megatron/fused_kernels/build/scaled_masked_softmax_cuda.so b/megatron/fused_kernels/build/scaled_masked_softmax_cuda.so new file mode 100644 index 0000000000000000000000000000000000000000..72236430997eaa802afc622fa7d35b6703afff7d --- /dev/null +++ b/megatron/fused_kernels/build/scaled_masked_softmax_cuda.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:409dbcba44e37f1e791e8ce1cb82f4a6884c6eb68aa281311a675462e91762ea +size 1283032 diff --git a/megatron/fused_kernels/build/scaled_softmax.o b/megatron/fused_kernels/build/scaled_softmax.o new file mode 100644 index 0000000000000000000000000000000000000000..8d9d45d9c44ff49be542fc34f383c9e006806a55 Binary files /dev/null and b/megatron/fused_kernels/build/scaled_softmax.o differ diff --git a/megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o b/megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..293f54062c446e9bf1fbd78ee5018a693fce1bf4 --- /dev/null +++ b/megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c2e292b61e4060b6a2c263bd95dcc08c4ab9c29ae4cf74df98bbd2ad4b566ee +size 1084024 diff --git a/megatron/fused_kernels/build/scaled_softmax_cuda.so b/megatron/fused_kernels/build/scaled_softmax_cuda.so new file mode 100644 index 0000000000000000000000000000000000000000..6a0b35a8fe5c38b8af004c67ba37e855cf20027e --- /dev/null +++ b/megatron/fused_kernels/build/scaled_softmax_cuda.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c0b73baf4b0a10ccf124c4e20aabefc122e2f8b0b0887dccfb7eafe3cd5e39c +size 1170600 diff --git a/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o b/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o new file mode 100644 index 0000000000000000000000000000000000000000..be73129b2266e1d57cf64fd0164e287bf88ae1f3 Binary files /dev/null and b/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o differ diff --git a/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.cuda.o b/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..2c6a5caf2bdab246f29dace54cf3b228f563b5a8 Binary files /dev/null and b/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.cuda.o differ diff --git a/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so b/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so new file mode 100644 index 0000000000000000000000000000000000000000..a2b1d9dfe626cb98538b4095c3d100a44e81fc55 --- /dev/null +++ b/megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e1b57adfdbb4254303f89ad4c8f786f9f7b4516c8fa2e95339ae5177a69e4a5 +size 1032720 diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/compat.h new file mode 100644 index 0000000000000000000000000000000000000000..92e7eb7723bdc9a7cf5d65bda3ba45775ef8c589 --- /dev/null +++ b/megatron/fused_kernels/compat.h @@ -0,0 +1,31 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + + + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/megatron/fused_kernels/fused_weight_gradient_dense.cpp b/megatron/fused_kernels/fused_weight_gradient_dense.cpp new file mode 100644 index 0000000000000000000000000000000000000000..194ee59353d2a8c9da24e50c592f4e086806d078 --- /dev/null +++ b/megatron/fused_kernels/fused_weight_gradient_dense.cpp @@ -0,0 +1,47 @@ +#include +#include + +#include +#include + +#include "type_shim.h" + + +template +int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); + +void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) { + at::Tensor input_2d, d_output_2d; + // input tensor: collapse to the first dim + auto in_sizes = input.sizes(); + if (input.dim() > 2) { + input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]}); + } else { + input_2d = input; + } + // d_output tensor: collapse to the first dim + auto d_out_sizes = d_output.sizes(); + if (d_output.dim() > 2) { + d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]}); + } else { + d_output_2d = d_output; + } + + int hidden_dim = input_2d.size(0); + int in_dim = input_2d.size(1); + int out_dim = d_weight.size(0); + + DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32", + int result = wgrad_gemm_accum_fp32_cuda( + input_2d.data_ptr(), + d_output_2d.data_ptr(), + d_weight.data_ptr(), + in_dim, + hidden_dim, + out_dim); + ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32"); +} diff --git a/megatron/fused_kernels/fused_weight_gradient_dense.cu b/megatron/fused_kernels/fused_weight_gradient_dense.cu new file mode 100644 index 0000000000000000000000000000000000000000..7dc10e65d37e531b54d2c875011f9c47eeb8ff7f --- /dev/null +++ b/megatron/fused_kernels/fused_weight_gradient_dense.cu @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include +#include +#include + +/* Includes, cuda */ +#include +#include + + +// BF16 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemmex_wrapper( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + at::BFloat16* A, + int lda, + at::BFloat16* B, + int ldb, + const float* beta, + float* C, + int ldc) { + return cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_16BF, + lda, + B, + CUDA_R_16BF, + ldb, + beta, + C, + CUDA_R_32F, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// FP16 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemmex_wrapper( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + at::Half* A, + int lda, + at::Half* B, + int ldb, + const float* beta, + float* C, + int ldc) { + return cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_16F, + lda, + B, + CUDA_R_16F, + ldb, + beta, + C, + CUDA_R_32F, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// FP32 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemmex_wrapper( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + float* A, + int lda, + float* B, + int ldb, + const float* beta, + float* C, + int ldc) { + return cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_32F, + lda, + B, + CUDA_R_32F, + ldb, + beta, + C, + CUDA_R_32F, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +template +int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cudaStream_t stream; + cublasGetStream(handle, &stream); + const float alpha = 1.0; + const float beta = 1.0; + int status = 1; + + status = gemmex_wrapper( + handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_dim, + out_dim, + hidden_dim, + &alpha, + input, + in_dim, + d_output, + out_dim, + &beta, + d_weight, + in_dim); + return status; +} + +template int wgrad_gemm_accum_fp32_cuda(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); +template int wgrad_gemm_accum_fp32_cuda(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); +template int wgrad_gemm_accum_fp32_cuda(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); diff --git a/megatron/fused_kernels/layer_norm_cuda.cpp b/megatron/fused_kernels/layer_norm_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f28e7b4ad361b6f51579b35a77c7c0a42ab2841 --- /dev/null +++ b/megatron/fused_kernels/layer_norm_cuda.cpp @@ -0,0 +1,201 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include +#include +#include "compat.h" + +namespace { + +void compute_n1_n2( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert( input.sizes()[i+idiff] == normalized_shape[i] ); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args( + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta + ) +{ + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2 + ) +{ + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input,normalized_shape,n1,n2); +} + + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + int& n1, + int& n2 + ) +{ + check_args(input,normalized_shape,n1,n2); + check_args(normalized_shape,gamma,beta); +} +} + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = at::empty_like( + input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = at::empty( + {n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon); + + return {output, mean, invvar}; + +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta + ); + +std::vector layer_norm_gradient_affine( + at::Tensor dout, + at::Tensor mean, + at::Tensor invvar, + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; + +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, + "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} diff --git a/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/megatron/fused_kernels/layer_norm_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..91d533191506b939aae0fd304f0bb7e8c89895c2 --- /dev/null +++ b/megatron/fused_kernels/layer_norm_cuda_kernel.cu @@ -0,0 +1,832 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" + +#include +#include + +#include "type_shim.h" + +template __device__ +void cuWelfordOnlineSum( + const U curr, + U& mu, + U& sigma2, + U& count) +{ + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template __device__ +void cuChanOnlineSum( + const U muB, + const U sigma2B, + const U countB, + U& mu, + U& sigma2, + U& count) +{ + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA*mu + nB*muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuWelfordMuSigma2( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu= U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l+k]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2*threadIdx.y]; + U sigma2B = ubuf[2*threadIdx.y+1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/U(n2), 0); + } + } +} + +template<> __device__ +void cuWelfordMuSigma2( + const at::Half* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu= float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1*n2; + int l = 8*thrx; + if ((((size_t)lvals)&3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l+7 < n2; l+=8*numx) { + for (int k = 0; k < 8; k+=2) { + float2 curr = __half22float2(*((__half2*)(lvals+l+k))); + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1< 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2*threadIdx.y]; + float sigma2B = ubuf[2*threadIdx.y+1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/float(n2), 0); + } + } +} + +template U rsqrt(U v) { + return U(1) / sqrt(v); +} +template<> float rsqrt(float v) { + return rsqrtf(v); +} +template<> double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. +// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory +{ + __device__ float *getPointer() + { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); + const T* lvals = vals + i1*n2; + V* ovals = output_vals + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template __global__ +void cuComputePartGradGammaBeta( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta) +{ + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template __global__ +void cuComputeGradInput( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input) +{ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1*n2; + const V* k_dout = dout + i1*n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss * gamma[l+k]; + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x/2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2*wrt_i] = sum_loss1; + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2*read_i]; + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2*threadIdx.x] = sum_loss1; + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + sum_loss1 = buf[2*threadIdx.x]; + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1*n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + + + + +template +void HostApplyLayerNorm( + V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32,4,1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyLayerNorm<<>>( + output, + mean, + invvar, + input, + n1,n2, + U(epsilon), + gamma,beta); +} + + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL); + ) +} + + +template +void HostLayerNormGradient( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + const V* beta, + double epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * + (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size,n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR()); + + const dim3 threads3(32,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + part_size, + n1,n2, + grad_gamma, + grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input); +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL); + ) +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1852aee6fdaacbf2bb9258968ce0405817b1c650 --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax.cpp @@ -0,0 +1,97 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..e57fd04c62ab55e01e00bc2fb83365358effe6c4 --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -0,0 +1,717 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + */ +template +__global__ void scaled_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 12: // 4096 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + + default: + break; + } + } +} diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..2efee39a6d293827f6deb3dc3185ef5fd535c9ce --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -0,0 +1,117 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/scaled_softmax.cpp b/megatron/fused_kernels/scaled_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e89b39f6a842a16489f56bc0ac72378c4d915a66 --- /dev/null +++ b/megatron/fused_kernels/scaled_softmax.cpp @@ -0,0 +1,75 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd( + torch::Tensor const& input, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_softmax::fwd, + "Self Multihead Attention scaled, softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_softmax::bwd, + "Self Multihead Attention scaled, softmax -- Backward."); +} + diff --git a/megatron/fused_kernels/scaled_softmax_cuda.cu b/megatron/fused_kernels/scaled_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..efc0822481250ea4360f57d479bf35a847941a05 --- /dev/null +++ b/megatron/fused_kernels/scaled_softmax_cuda.cu @@ -0,0 +1,104 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_softmax_forward", + dispatch_scaled_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} + diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea283588db2d4aef256c2ab25533630170fe416b --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,72 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..6df83fc1037dded422b60d746afd76a311e4eacf --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,513 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..5efc3d41280cf7b4f2fdf6afbe1271ffbe21037c --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,98 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + return softmax_results; +} + + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/tests/__init__.py b/megatron/fused_kernels/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/megatron/fused_kernels/tests/test_fused_kernels.py b/megatron/fused_kernels/tests/test_fused_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d5027a1f0a33f522b8cd55bd467fd881efe166 --- /dev/null +++ b/megatron/fused_kernels/tests/test_fused_kernels.py @@ -0,0 +1,300 @@ +import math + +import torch +from torch.nn import LayerNorm + +from megatron.model.enums import AttnMaskType +from megatron.model.fused_layer_norm import MixedFusedLayerNorm +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.utils import attention_mask_func + + +def test_load_fused_kernels(): + try: + import fused_mix_prec_layer_norm_cuda + import scaled_masked_softmax_cuda + import scaled_upper_triang_masked_softmax_cuda + import torch + + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + + +def test_fused_softmax(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. " + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + embedding_output = bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + + # (bsz, 1, 1, seq_len) + mask = bert.get_extended_attention_mask( + attention_mask=tokens["attention_mask"].cuda(), + input_shape=tokens["input_ids"].shape, + device=bert.device, + ) + # (bsz, 1, seq_len, seq_len) + mask = mask.repeat(1, 1, mask.size()[-1], 1) + + attention = bert.encoder.layer[0].attention.self + key_layer = attention.transpose_for_scores(attention.key(embedding_output)) + query_layer = attention.transpose_for_scores(attention.query(embedding_output)) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores /= math.sqrt(key_layer.size()[-1]) + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attention_scores, + (mask != 0), + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attention_scores, + (mask != 0), + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_fused_upper_triangle_mask_softmax(): + gpt = GPT2Model.from_pretrained("gpt2").cuda().half() + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. " + "hi hi hi hi hi hi hi" # 24 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + attention_mask = tokens["attention_mask"].cuda() + attention_mask = attention_mask.view(attention_mask.size(0), -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * -10000.0 + attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1) + attn = gpt.h[0] + + hidden_states = gpt.wte(tokens["input_ids"].cuda()) + q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1) + q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim) + k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim) + attn_weights = torch.matmul(q, k.transpose(-1, -2)) + + sq, sk = q.size(-2), k.size(-2) + causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool() + total_mask = ~(causal_mask & (attention_mask == 0)) + """ + tensor([[[[False, True, True, ..., True, True, True], + [False, False, True, ..., True, True, True], + [False, False, False, ..., True, True, True], + ..., + [False, False, False, ..., False, True, True], + [False, False, False, ..., False, False, True], + [False, False, False, ..., False, False, False]]] + """ + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attn_weights, + total_mask, + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attn_weights, + total_mask, + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_layer_norm(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. " + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + # [bsz, seq_len, d_model] + embedding_output = ( + bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + .cuda() + .half() + ) + + fused_layernorm_layer = ( + MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + torch_layernorm_layer = ( + LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + fused_output = fused_layernorm_layer(embedding_output) + torch_output = torch_layernorm_layer(embedding_output) + test_result = (fused_output - torch_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_layer_norm" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}" + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_layer_norm" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + + +if __name__ == "__main__": + try: + from transformers import BertTokenizer, GPT2Tokenizer + from transformers.models.bert.modeling_bert import BertModel + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + import transformers + + transformers.logging.set_verbosity( + transformers.logging.FATAL, + ) + + except: + print("\n[Fail] Please install `transformers` package to test fused kernels\n") + exit(-1) + + test_load_fused_kernels() + test_fused_softmax() + test_fused_upper_triangle_mask_softmax() + test_layer_norm() diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..30e605bd38076e2ead81b24704d64cbf16cea437 --- /dev/null +++ b/megatron/fused_kernels/type_shim.h @@ -0,0 +1,117 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + diff --git a/megatron/global_vars.py b/megatron/global_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b2741444d658648755c3fa1fea9d55616e201f --- /dev/null +++ b/megatron/global_vars.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron global variables.""" + +import os +import sys +import time +from functools import reduce +import operator +import torch + +from megatron import dist_signal_handler +from megatron.tokenizer import build_tokenizer +from .microbatches import build_num_microbatches_calculator + +_GLOBAL_ARGS = None +_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None +_GLOBAL_TOKENIZER = None +_GLOBAL_TENSORBOARD_WRITER = None +_GLOBAL_ADLR_AUTORESUME = None +_GLOBAL_TIMERS = None +_GLOBAL_SIGNAL_HANDLER = None +_GLOBAL_MEMORY_BUFFER = None + +def get_args(): + """Return arguments.""" + _ensure_var_is_initialized(_GLOBAL_ARGS, 'args') + return _GLOBAL_ARGS + + +def get_num_microbatches(): + return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() + + +def get_current_global_batch_size(): + return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() + + +def update_num_microbatches(consumed_samples, consistency_check=True): + _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, + consistency_check) + + +def get_tokenizer(): + """Return tokenizer.""" + _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer') + return _GLOBAL_TOKENIZER + + +def get_tensorboard_writer(): + """Return tensorboard writer. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_TENSORBOARD_WRITER + + +def get_adlr_autoresume(): + """ADLR autoresume object. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_ADLR_AUTORESUME + + +def get_timers(): + """Return timers.""" + _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + return _GLOBAL_TIMERS + + +def get_signal_handler(): + _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler') + return _GLOBAL_SIGNAL_HANDLER + + +def get_global_memory_buffer(): + _ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer') + return _GLOBAL_MEMORY_BUFFER + + +def _set_signal_handler(): + global _GLOBAL_SIGNAL_HANDLER + _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler') + _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__() + + + +def set_global_variables(args): + """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" + + assert args is not None + + _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args') + set_args(args) + + _build_num_microbatches_calculator(args) + if args.vocab_file: + _ = _build_tokenizer(args) + _set_tensorboard_writer(args) + _set_adlr_autoresume(args) + _set_timers() + _set_global_memory_buffer() + + if args.exit_signal_handler: + _set_signal_handler() + + +def set_args(args): + global _GLOBAL_ARGS + _GLOBAL_ARGS = args + + +def _build_num_microbatches_calculator(args): + + global _GLOBAL_NUM_MICROBATCHES_CALCULATOR + _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, + 'num microbatches calculator') + + _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( + args) + + +def _build_tokenizer(args): + """Initialize tokenizer.""" + global _GLOBAL_TOKENIZER + _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer') + _GLOBAL_TOKENIZER = build_tokenizer(args) + return _GLOBAL_TOKENIZER + + +def rebuild_tokenizer(args): + global _GLOBAL_TOKENIZER + _GLOBAL_TOKENIZER = None + return _build_tokenizer(args) + + +def _set_tensorboard_writer(args): + """Set tensorboard writer.""" + global _GLOBAL_TENSORBOARD_WRITER + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, + 'tensorboard writer') + + if hasattr(args, 'tensorboard_dir') and \ + args.tensorboard_dir and args.rank == (args.world_size - 1): + try: + from torch.utils.tensorboard import SummaryWriter + print('> setting tensorboard ...') + _GLOBAL_TENSORBOARD_WRITER = SummaryWriter( + log_dir=args.tensorboard_dir, + max_queue=args.tensorboard_queue_size) + except ModuleNotFoundError: + print('WARNING: TensorBoard writing requested but is not ' + 'available (are you using PyTorch 1.1.0 or later?), ' + 'no TensorBoard logs will be written.', flush=True) + + +def _set_adlr_autoresume(args): + """Initialize ADLR autoresume.""" + global _GLOBAL_ADLR_AUTORESUME + _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume') + + if args.adlr_autoresume: + if args.rank == 0: + print('enabling autoresume ...', flush=True) + sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) + try: + from userlib.auto_resume import AutoResume + except BaseException: + print('ADLR autoresume is not available, exiting ...') + sys.exit() + + _GLOBAL_ADLR_AUTORESUME = AutoResume + + +def _set_timers(): + """Initialize timers.""" + global _GLOBAL_TIMERS + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _GLOBAL_TIMERS = Timers() + +def _set_global_memory_buffer(): + """Initialize global buffer""" + global _GLOBAL_MEMORY_BUFFER + _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer') + _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def _ensure_var_is_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is not None, '{} is not initialized.'.format(name) + + +def _ensure_var_is_not_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is None, '{} is already initialized.'.format(name) + + +class _Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + assert not self.started_, 'timer has already been started' + torch.cuda.synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, 'timer is not started' + torch.cuda.synchronize() + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def write(self, names, writer, iteration, normalizer=1.0, reset=False): + """Write timers to a tensorboard writer""" + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + for name in names: + value = self.timers[name].elapsed(reset=reset) / normalizer + writer.add_scalar(name + '-time', value, iteration) + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = 'time (ms)' + for name in names: + elapsed_time = self.timers[name].elapsed( + reset=reset) * 1000.0 / normalizer + string += ' | {}: {:.2f}'.format(name, elapsed_time) + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == ( + torch.distributed.get_world_size() - 1): + print(string, flush=True) + else: + print(string, flush=True) + + +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if self.buffer.get((name, dtype), None) is None or \ + self.buffer[(name, dtype)].numel() < required_len: + self.buffer[(name, dtype)] = \ + torch.empty(required_len, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) diff --git a/megatron/indexer.py b/megatron/indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ff9e36f85b5149715022526d550291c62440ca --- /dev/null +++ b/megatron/indexer.py @@ -0,0 +1,129 @@ +import sys +import time +import torch +import torch.distributed as dist + +from megatron import get_args, print_rank_0 +from megatron import mpu +from megatron.checkpointing import load_biencoder_checkpoint +from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset +from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch +from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader +from megatron.data.realm_index import detach, OpenRetreivalDataStore +from megatron.model.biencoder_model import get_model_provider +from megatron.training import get_model + + +class IndexBuilder(object): + """ + Object for taking one pass over a dataset and creating a BlockData of its + embeddings + """ + def __init__(self): + args = get_args() + self.model = None + self.dataloader = None + self.evidence_embedder_obj = None + self.biencoder_shared_query_context_model = \ + args.biencoder_shared_query_context_model + + # need to know whether we're using a REALM checkpoint (args.load) + # or ICT checkpoint + assert not (args.load and args.ict_load) + + self.log_interval = args.indexer_log_interval + self.batch_size = args.indexer_batch_size + + self.load_attributes() + self.is_main_builder = mpu.get_data_parallel_rank() == 0 + self.num_total_builders = mpu.get_data_parallel_world_size() + self.iteration = self.total_processed = 0 + + def load_attributes(self): + """ + Load the necessary attributes: model, dataloader and empty BlockData + """ + only_context_model = True + if self.biencoder_shared_query_context_model: + only_context_model = False + + model = get_model(get_model_provider(only_context_model=\ + only_context_model, biencoder_shared_query_context_model=\ + self.biencoder_shared_query_context_model)) + + self.model = load_biencoder_checkpoint(model, + only_context_model=only_context_model) + + assert len(self.model) == 1 + self.model[0].eval() + + self.dataset = get_open_retrieval_wiki_dataset() + self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ + self.batch_size)) + + self.evidence_embedder_obj = OpenRetreivalDataStore( \ + load_from_path=False) + + def track_and_report_progress(self, batch_size): + """ + Utility function for tracking progress + """ + self.iteration += 1 + self.total_processed += batch_size * self.num_total_builders + if self.is_main_builder and self.iteration % self.log_interval == 0: + print('Batch {:10d} | Total {:10d}'.format(self.iteration, + self.total_processed), flush=True) + + def build_and_save_index(self): + """ + Goes through one epoch of the dataloader and adds all data to this + instance's BlockData. + + The copy of BlockData is saved as a shard, which when run in a + distributed setting will be consolidated by the rank 0 process + and saved as a final pickled BlockData. + """ + assert len(self.model) == 1 + unwrapped_model = self.model[0] + + while not hasattr(unwrapped_model, 'embed_text'): + unwrapped_model = unwrapped_model.module + + while True: + try: + # batch also has query_tokens and query_pad_data + row_id, context_tokens, context_mask, context_types, \ + context_pad_mask = get_open_retrieval_batch( \ + self.dataloader) + except (StopIteration, IndexError): + break + + # TODO: can we add with torch.no_grad() to reduce memory usage + # detach, separate fields and add to BlockData + assert context_mask.dtype == torch.bool + context_logits = unwrapped_model.embed_text( + unwrapped_model.context_model, context_tokens, context_mask, + context_types) + + context_logits = detach(context_logits) + row_id = detach(row_id) + + self.evidence_embedder_obj.add_block_data(row_id, context_logits) + self.track_and_report_progress(batch_size=len(row_id)) + + # This process signals to finalize its shard and then synchronize with + # the other processes + self.evidence_embedder_obj.save_shard() + torch.distributed.barrier() + del self.model + + # rank 0 process builds the final copy + if self.is_main_builder: + self.evidence_embedder_obj.merge_shards_and_save() + # make sure that every single piece of data was embedded + assert len(self.evidence_embedder_obj.embed_data) == \ + len(self.dataset) + self.evidence_embedder_obj.clear() + + # complete building the final copy + torch.distributed.barrier() diff --git a/megatron/initialize.py b/megatron/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..5994c0d8e5548f1825771fd0fc3a4ac020269caa --- /dev/null +++ b/megatron/initialize.py @@ -0,0 +1,309 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron initialization.""" + +import random +import os +import time + +import numpy as np +import torch +from datetime import timedelta + +from megatron import fused_kernels +from megatron import get_adlr_autoresume +from megatron import get_args +from megatron import get_tensorboard_writer +from megatron import mpu +from megatron.arguments import (parse_args, validate_args) +from megatron.checkpointing import load_args_from_checkpoint +from megatron.global_vars import set_global_variables +from megatron.mpu import (set_tensor_model_parallel_rank, + set_tensor_model_parallel_world_size) +from megatron.model.transformer import bias_dropout_add_fused_train +from megatron.model.fused_bias_gelu import bias_gelu + + +def initialize_megatron(extra_args_provider=None, args_defaults={}, + ignore_unknown_args=False, allow_no_cuda=False): + """Set global variables, initialize distributed, and + set autoresume and random seeds. + `allow_no_cuda` should not be set unless using megatron for cpu only + data processing. In general this arg should not be set unless you know + what you are doing. + Returns a function to finalize distributed env initialization + (optionally, only when args.lazy_mpu_init == True) + """ + if not allow_no_cuda: + # Make sure cuda is available. + assert torch.cuda.is_available(), 'Megatron requires CUDA.' + + # Parse arguments + args = parse_args(extra_args_provider, ignore_unknown_args) + + if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False): + assert args.load is not None, '--use-checkpoints-args requires --load argument' + load_args_from_checkpoint(args) + + validate_args(args, args_defaults) + + # set global args, build tokenizer, and set adlr-autoresume, + # tensorboard-writer, and timers. + set_global_variables(args) + + # torch.distributed initialization + def finish_mpu_init(): + args = get_args() + # Pytorch distributed. + _initialize_distributed() + + # Random seeds for reproducibility. + if args.rank == 0: + print('> setting random seeds to {} ...'.format(args.seed)) + _set_random_seed(args.seed, args.data_parallel_random_init) + + args = get_args() + if args.lazy_mpu_init: + args.use_cpu_initialization=True + # delayed initialization of DDP-related stuff + # We only set basic DDP globals + set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) + # and return function for external DDP manager + # to call when it has DDP initialized + set_tensor_model_parallel_rank(args.rank) + return finish_mpu_init + else: + # Megatron's MPU is the master. Complete initialization right away. + finish_mpu_init() + + # Autoresume. + _init_autoresume() + + # Compile dependencies. + _compile_dependencies() + + # No continuation function + return None + + +def _compile_dependencies(): + + args = get_args() + + # ========================= + # Compile dataset C++ code. + # ========================= + # TODO: move this to ninja + if torch.distributed.get_rank() == 0: + start_time = time.time() + print('> compiling dataset index builder ...') + from megatron.data.dataset_utils import compile_helper + compile_helper() + print('>>> done with dataset index builder. Compilation time: {:.3f} ' + 'seconds'.format(time.time() - start_time), flush=True) + + # ================== + # Load fused kernels + # ================== + + # Custom kernel constraints check. + seq_len = args.seq_length + attn_batch_size = \ + (args.num_attention_heads / args.tensor_model_parallel_size) * \ + args.micro_batch_size + # Constraints on sequence length and attn_batch_size to enable warp based + # optimization and upper triangular optimization (for causal mask) + custom_kernel_constraint = seq_len > 16 and seq_len <=4096 and \ + seq_len % 4 == 0 and attn_batch_size % 4 == 0 + # Print a warning. + if not ((args.fp16 or args.bf16) and + custom_kernel_constraint and + args.masked_softmax_fusion): + if args.rank == 0: + print('WARNING: constraints for invoking optimized' + ' fused softmax kernel are not met. We default' + ' back to unfused kernel invocations.', flush=True) + + # Always build on rank zero first. + if torch.distributed.get_rank() == 0: + start_time = time.time() + print('> compiling and loading fused kernels ...', flush=True) + fused_kernels.load(args) + torch.distributed.barrier() + else: + torch.distributed.barrier() + fused_kernels.load(args) + # Simple barrier to make sure all ranks have passed the + # compilation phase successfully before moving on to the + # rest of the program. We think this might ensure that + # the lock is released. + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>>> done with compiling and loading fused kernels. ' + 'Compilation time: {:.3f} seconds'.format( + time.time() - start_time), flush=True) + + + +def _initialize_distributed(): + """Initialize torch.distributed and mpu.""" + args = get_args() + + device_count = torch.cuda.device_count() + if torch.distributed.is_initialized(): + + if args.rank == 0: + print('torch distributed is already initialized, ' + 'skipping initialization ...', flush=True) + args.rank = torch.distributed.get_rank() + args.world_size = torch.distributed.get_world_size() + + else: + + if args.rank == 0: + print('> initializing torch distributed ...', flush=True) + # Manually set the device ids. + if device_count > 0: + device = args.rank % device_count + if args.local_rank is not None: + assert args.local_rank == device, \ + 'expected local-rank to be the same as rank % device-count.' + else: + args.local_rank = device + torch.cuda.set_device(device) + # Call the init process + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, rank=args.rank, + timeout=timedelta(minutes=10)) + + # Set the tensor model-parallel, pipeline model-parallel, and + # data-parallel communicators. + if device_count > 0: + if mpu.model_parallel_is_initialized(): + print('model parallel is already initialized') + else: + mpu.initialize_model_parallel(args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank) + + +def _init_autoresume(): + """Set autoresume start time.""" + autoresume = get_adlr_autoresume() + if autoresume: + torch.distributed.barrier() + autoresume.init() + torch.distributed.barrier() + + +def _set_random_seed(seed_, data_parallel_random_init=False): + """Set random seed for reproducability.""" + if seed_ is not None and seed_ > 0: + # Ensure that different pipeline MP stages get different seeds. + seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + # Ensure different data parallel ranks get different seeds + if data_parallel_random_init: + seed = seed + (10 * mpu.get_data_parallel_rank()) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.device_count() > 0: + mpu.model_parallel_cuda_manual_seed(seed) + else: + raise ValueError('Seed ({}) should be a positive integer.'.format(seed)) + + +def write_args_to_tensorboard(): + """Write arguments to tensorboard.""" + args = get_args() + writer = get_tensorboard_writer() + if writer: + for arg in vars(args): + writer.add_text(arg, str(getattr(args, arg)), + global_step=args.iteration) + + +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options.""" + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + + _warmup_jit_function() + + +def _warmup_jit_function(): + """ Compilie JIT functions before the main training steps """ + args = get_args() + if args.bf16: + dtype = torch.bfloat16 + elif args.fp16: + dtype = torch.float16 + else: + dtype = torch.float32 + + # Warmup fused bias+gelu + bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size, + dtype=dtype, device='cuda') + input = torch.rand((args.seq_length, args.micro_batch_size, + args.ffn_hidden_size // args.tensor_model_parallel_size), + dtype=dtype, device='cuda') + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + bias.requires_grad, input.requires_grad = bias_grad, input_grad + for _ in range(5): + output = bias_gelu(bias, input) + del bias, input, output + + # Warmup fused bias+dropout+add + if args.sequence_parallel: + seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() + else: + seq_length = args.seq_length + input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size), + dtype=dtype, device='cuda') + residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size), + dtype=dtype, device='cuda') + bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual) + dropout_rate = 0.1 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): + input.requires_grad = input_grad + bias.requires_grad = bias_grad + residual.requires_grad = residual_grad + for _ in range(5): + output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate) + del bias, input, residual, output + torch.cuda.empty_cache() diff --git a/megatron/memory.py b/megatron/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..be5a117bcd3f68b93950a078ef5b1a224ff744cf --- /dev/null +++ b/megatron/memory.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + + +# A dictionary of all the memory buffers allocated. +_MEM_BUFFS = dict() + + +def allocate_mem_buff(name, numel, dtype, track_usage): + """Allocate a memory buffer.""" + assert name not in _MEM_BUFFS, \ + 'memory buffer {} already allocated.'.format(name) + _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) + return _MEM_BUFFS[name] + + +def get_mem_buff(name): + """Get the memory buffer.""" + return _MEM_BUFFS[name] + + +class MemoryBuffer: + """Contiguous memory buffer. + Allocate a contiguous memory of type `dtype` and size `numel`. It is + used to reduce memory fragmentation. + + Usage: After the allocation, the `_start` index is set tot the first + index of the memory. A memory chunk starting from `_start` index + can be `allocated` for an input tensor, with the elements of the + tensor being coppied. The buffer can be reused by resetting the + `_start` index. + + """ + def __init__(self, name, numel, dtype, track_usage): + if torch.distributed.get_rank() == 0: + element_size = torch.tensor([], dtype=dtype).element_size() + print('> building the {} memory buffer with {} num elements ' + 'and {} dtype ({:.1f} MB)...'.format( + name, numel, dtype, numel*element_size/1024/1024), + flush=True) + self.name = name + self.numel = numel + self.dtype = dtype + self.data = torch.empty(self.numel, + dtype=self.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + # Index tracking the start of the free memory. + self._start = 0 + + # Values used for tracking usage. + self.track_usage = track_usage + if self.track_usage: + self.in_use_value = 0.0 + self.total_value = 0.0 + + + def reset(self): + """Reset the buffer start index to the beginning of the buffer.""" + self._start = 0 + + + def is_in_use(self): + """Whether the current buffer hold on to any memory.""" + return self._start > 0 + + + def numel_in_use(self): + """Return number of elements in use.""" + return self._start + + + def add(self, tensor): + """Allocate a chunk of memory from the buffer to tensor and copy + the values.""" + assert tensor.dtype == self.dtype, \ + 'Input tensor type {} different from buffer type {}'.format( + tensor.dtype, self.dtype) + # Number of elements of the input tensor. + tensor_numel = torch.numel(tensor) + new_start = self._start + tensor_numel + assert new_start <= self.numel, \ + 'Not enough memory left in the buffer ({} > {})'.format( + tensor_numel, self.numel - self._start) + # New tensor is a view into the memory. + new_tensor = self.data[self._start:new_start] + self._start = new_start + new_tensor = new_tensor.view(tensor.shape) + new_tensor.copy_(tensor) + # Return a pointer to the new tensor. + return new_tensor + + + def get_data(self): + """Return the data currently in use.""" + if self.track_usage: + self.in_use_value += float(self._start) + self.total_value += float(self.numel) + return self.data[:self._start] + + + def print_average_usage(self): + """Print memory usage average over time. We would like this value + to be as high as possible.""" + assert self.track_usage, 'You need to enable track usage.' + if torch.distributed.get_rank() == 0: + print(' > usage of {} memory buffer: {:.2f} %'.format( + self.name, self.in_use_value * 100.0 / self.total_value), + flush=True) + + + +class RingMemBuffer: + """A ring of memory buffers.""" + + def __init__(self, name, num_buffers, numel, dtype, track_usage): + self.num_buffers = num_buffers + self.buffers = [ + allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage) + for i in range(num_buffers)] + self._index = -1 + + + def get_next_buffer(self): + self._index += 1 + self._index = self._index % self.num_buffers + buff = self.buffers[self._index] + assert not buff.is_in_use(), 'buffer is already in use.' + return buff diff --git a/megatron/microbatches.py b/megatron/microbatches.py new file mode 100644 index 0000000000000000000000000000000000000000..c2bf2823dc89e35cb491926c69e4bcbc690e84ed --- /dev/null +++ b/megatron/microbatches.py @@ -0,0 +1,157 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron number of micro-batches calculators.""" + +from abc import ABC +from abc import abstractmethod + + +def build_num_microbatches_calculator(args): + + # Constant num micro-batches. + if args.rampup_batch_size is None: + num_microbatches_calculator = ConstantNumMicroBatches( + args.global_batch_size, args.micro_batch_size, + args.data_parallel_size) + if args.rank == 0: + print('setting number of micro-batches to constant {}'.format( + num_microbatches_calculator.get()), flush=True) + + else: + assert len(args.rampup_batch_size) == 3, 'expected the following ' \ + 'format: --rampup-batch-size ' \ + ' ' + start_batch_size = int(args.rampup_batch_size[0]) + batch_size_increment = int(args.rampup_batch_size[1]) + ramup_samples = int(args.rampup_batch_size[2]) + if args.rank == 0: + print('will use batch size rampup starting from global batch ' + 'size {} to global batch size {} with batch size increments ' + '{} over {} samples.'.format(start_batch_size, + args.global_batch_size, + batch_size_increment, + ramup_samples), flush=True) + num_microbatches_calculator = RampupBatchsizeNumMicroBatches( + start_batch_size, batch_size_increment, ramup_samples, + args.global_batch_size, args.micro_batch_size, + args.data_parallel_size) + + return num_microbatches_calculator + + +class NumMicroBatchesCalculator(ABC): + + def __init__(self): + self.num_micro_batches = None + self.current_global_batch_size = None + + def get(self): + return self.num_micro_batches + + def get_current_global_batch_size(self): + return self.current_global_batch_size + + @abstractmethod + def update(self, consumed_samples, consistency_check): + pass + + +class ConstantNumMicroBatches(NumMicroBatchesCalculator): + + def __init__(self, global_batch_size, micro_batch_size, data_parallel_size): + micro_batch_times_data_parallel = micro_batch_size * \ + data_parallel_size + assert global_batch_size % micro_batch_times_data_parallel == 0, \ + 'global batch size ({}) is not divisible by micro batch size ({})' \ + ' times data parallel size ({})'.format(global_batch_size, + micro_batch_size, + data_parallel_size) + self.num_micro_batches = global_batch_size // \ + micro_batch_times_data_parallel + assert self.num_micro_batches >= 1 + self.current_global_batch_size = global_batch_size + + def update(self, consumed_samples, consistency_check): + pass + + +class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator): + + def __init__(self, start_batch_size, batch_size_increment, ramup_samples, + global_batch_size, micro_batch_size, data_parallel_size): + """Batch size ramp up. + Over + steps = (global-batch-size - start-batch-size) / batch_size_increment + increment batch size from start-batch-size to global-batch-size using + rampup-samples / steps + samples. + Arguments: + start_batch_size: global batch size to start with + batch_size_increment: global batch size increments + ramup_samples: number of samples to use ramp up global + batch size from `start_batch_size` to `global_batch_size` + global_batch_size: global batch size post rampup + micro_batch_size: micro batch size + data_parallel_size: data parallel size. + """ + + self.micro_batch_size = micro_batch_size + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * \ + self.data_parallel_size + assert self.micro_batch_times_data_parallel_size > 0 + + assert start_batch_size > 0 + self.start_batch_size = start_batch_size + + assert global_batch_size > 0 + self.global_batch_size = global_batch_size + diff_batch_size = self.global_batch_size - self.start_batch_size + assert diff_batch_size >= 0 + assert batch_size_increment > 0 + self.batch_size_increment = batch_size_increment + assert diff_batch_size % batch_size_increment == 0, 'expected ' \ + 'global batch size interval ({}) to be divisible by global batch ' \ + 'size increment ({})'.format(diff_batch_size, batch_size_increment) + + num_increments = diff_batch_size // self.batch_size_increment + self.ramup_samples = ramup_samples + assert self.ramup_samples >= 0 + self.rampup_samples_per_increment = self.ramup_samples / num_increments + + # Initialize number of microbatches. + self.update(0, False) + + + def update(self, consumed_samples, consistency_check): + + if consumed_samples > self.ramup_samples: + self.current_global_batch_size = self.global_batch_size + else: + steps = int(consumed_samples / self.rampup_samples_per_increment) + self.current_global_batch_size = self.start_batch_size + \ + steps * self.batch_size_increment + assert self.current_global_batch_size <= self.global_batch_size + + if consistency_check: + assert self.current_global_batch_size % \ + self.micro_batch_times_data_parallel_size == 0, 'current global ' \ + 'batch size ({}) is not divisible by micro-batch-size ({}) times' \ + 'data parallel size ({})'.format(self.current_global_batch_size, + self.micro_batch_size, + self.data_parallel_size) + self.num_micro_batches = self.current_global_batch_size // \ + self.micro_batch_times_data_parallel_size diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a672da7b282ec75252a2a6717268f3f26a05004 --- /dev/null +++ b/megatron/model/__init__.py @@ -0,0 +1,25 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm + +from .distributed import DistributedDataParallel +from .bert_model import BertModel +from .glm_model import GlmModel +from .gpt_model import GPTModel +from .t5_model import T5Model +from .language_model import get_language_model +from .module import Float16Module +from .enums import ModelType diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8c5ba693fd31f7de6296dd8369d74b40ffcc51 --- /dev/null +++ b/megatron/model/bert_model.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BERT model.""" + +import torch + +from megatron import get_args +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import parallel_lm_logits +from megatron.model.language_model import get_language_model +from megatron.model import LayerNorm +from megatron.model.utils import openai_gelu, erf_gelu +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + +def bert_extended_attention_mask(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = (extended_attention_mask < 0.5) + + return extended_attention_mask + +def bert_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class BertLMHead(MegatronModule): + """Masked LM head for Bert + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + parallel_output: whether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, hidden_size, init_method, + layernorm_epsilon, parallel_output): + + super(BertLMHead, self).__init__() + + args = get_args() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + self.parallel_output = parallel_output + + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel) + setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel) + + self.layernorm = LayerNorm(hidden_size, + eps=layernorm_epsilon, + sequence_parallel=args.sequence_parallel) + self.gelu = torch.nn.functional.gelu + if args.openai_gelu: + self.gelu = openai_gelu + elif args.onnx_safe: + self.gelu = erf_gelu + + def forward(self, hidden_states, word_embeddings_weight): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + output = parallel_lm_logits(hidden_states, + word_embeddings_weight, + self.parallel_output, + bias=self.bias) + return output + + +def post_language_model_processing(lm_output, pooled_output, + lm_head, binary_head, + lm_labels, + logit_weights, + fp16_lm_cross_entropy): + # Output. + lm_logits = lm_head( + lm_output, logit_weights) + binary_logits = None + if binary_head is not None: + binary_logits = binary_head(pooled_output) + + if lm_labels is None: + # [s b h] => [b s h] + return lm_logits.transpose(0,1).contiguous(), binary_logits + else: + # [b s] => [s b] + lm_labels = lm_labels.transpose(0,1).contiguous() + # lm_logits : [s, b, h] and lm_labels: [s, b] + if fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), + lm_labels) + # [s, b] => [b s] + lm_loss = lm_loss.transpose(0,1).contiguous() + return lm_loss, binary_logits + + +class BertModel(MegatronModule): + """Bert Language model.""" + + def __init__(self, + num_tokentypes=2, + add_binary_head=True, + parallel_output=True, + pre_process=True, + post_process=True): + super(BertModel, self).__init__() + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.add_binary_head = add_binary_head + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=self.add_binary_head, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process) + + self.initialize_word_embeddings(init_method_normal) + if self.post_process: + self.lm_head = BertLMHead( + self.word_embeddings_weight().size(0), + args.hidden_size, init_method, args.layernorm_epsilon, parallel_output) + self._lm_head_key = 'lm_head' + self.binary_head = None + if self.add_binary_head: + self.binary_head = get_linear_layer(args.hidden_size, 2, + init_method) + self._binary_head_key = 'binary_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, bert_model_input, attention_mask, + tokentype_ids=None, lm_labels=None): + + extended_attention_mask = bert_extended_attention_mask(attention_mask) + input_ids = bert_model_input + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + if self.post_process and self.add_binary_head: + lm_output, pooled_output = lm_output + else: + pooled_output = None + + if self.post_process: + return post_language_model_processing(lm_output, pooled_output, + self.lm_head, self.binary_head, + lm_labels, + self.word_embeddings_weight(), + self.fp16_lm_cross_entropy) + else: + return lm_output + + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + state_dict_[self._lm_head_key] \ + = self.lm_head.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process and self.add_binary_head: + state_dict_[self._binary_head_key] \ + = self.binary_head.state_dict(destination, prefix, keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + self.lm_head.load_state_dict( + state_dict[self._lm_head_key], strict=strict) + if self.post_process and self.add_binary_head: + self.binary_head.load_state_dict( + state_dict[self._binary_head_key], strict=strict) + # Load word_embeddings. + if self.post_process and not self.pre_process: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py new file mode 100644 index 0000000000000000000000000000000000000000..752c5752e92006dc7c2e2828a77f58b9542d81b4 --- /dev/null +++ b/megatron/model/biencoder_model.py @@ -0,0 +1,330 @@ +import os +import torch +import sys + +from megatron import get_args, print_rank_0 +from megatron.checkpointing import fix_query_key_value_ordering +from megatron.checkpointing import get_checkpoint_tracker_filename +from megatron.checkpointing import get_checkpoint_name +from megatron import mpu, get_tokenizer +from megatron.model.bert_model import bert_position_ids +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + +def get_model_provider(only_query_model=False, only_context_model=False, + biencoder_shared_query_context_model=False): + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building Bienoder model ...') + model = biencoder_model_provider(only_query_model=only_query_model, + only_context_model = only_context_model, + biencoder_shared_query_context_model = \ + biencoder_shared_query_context_model, + pre_process=pre_process, post_process=post_process) + + return model + + return model_provider + + +def biencoder_model_provider(only_query_model=False, + only_context_model=False, + biencoder_shared_query_context_model=False, + pre_process=True, + post_process=True): + """Build the model.""" + + assert mpu.get_tensor_model_parallel_world_size() == 1 and \ + mpu.get_pipeline_model_parallel_world_size() == 1, \ + "Model parallel size > 1 not supported for ICT" + + print_rank_0('building BiEncoderModel...') + + # simpler to just keep using 2 tokentypes since + # the LM we initialize with has 2 tokentypes + model = BiEncoderModel( + num_tokentypes=2, + parallel_output=False, + only_query_model=only_query_model, + only_context_model=only_context_model, + biencoder_shared_query_context_model=\ + biencoder_shared_query_context_model, + pre_process=pre_process, + post_process=post_process) + + return model + + +class BiEncoderModel(MegatronModule): + """Bert-based module for Biencoder model.""" + + def __init__(self, + num_tokentypes=1, + parallel_output=True, + only_query_model=False, + only_context_model=False, + biencoder_shared_query_context_model=False, + pre_process=True, + post_process=True): + super(BiEncoderModel, self).__init__() + args = get_args() + + bert_kwargs = dict( + num_tokentypes=num_tokentypes, + parallel_output=parallel_output, + pre_process=pre_process, + post_process=post_process) + + self.biencoder_shared_query_context_model = \ + biencoder_shared_query_context_model + assert not (only_context_model and only_query_model) + self.use_context_model = not only_query_model + self.use_query_model = not only_context_model + self.biencoder_projection_dim = args.biencoder_projection_dim + + if self.biencoder_shared_query_context_model: + self.model = PretrainedBertModel(**bert_kwargs) + self._model_key = 'shared_model' + self.query_model, self.context_model = self.model, self.model + else: + if self.use_query_model: + # this model embeds (pseudo-)queries - Embed_input in the paper + self.query_model = PretrainedBertModel(**bert_kwargs) + self._query_key = 'query_model' + + if self.use_context_model: + # this model embeds evidence blocks - Embed_doc in the paper + self.context_model = PretrainedBertModel(**bert_kwargs) + self._context_key = 'context_model' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + # this is just a placeholder and will be needed when model + # parallelism will be used + # self.language_model.set_input_tensor(input_tensor) + return + + def forward(self, query_tokens, query_attention_mask, query_types, + context_tokens, context_attention_mask, context_types): + """Run a forward pass for each of the models and + return the respective embeddings.""" + + if self.use_query_model: + query_logits = self.embed_text(self.query_model, + query_tokens, + query_attention_mask, + query_types) + else: + raise ValueError("Cannot embed query without the query model.") + if self.use_context_model: + context_logits = self.embed_text(self.context_model, + context_tokens, + context_attention_mask, + context_types) + else: + raise ValueError("Cannot embed block without the block model.") + return query_logits, context_logits + + @staticmethod + def embed_text(model, tokens, attention_mask, token_types): + """Embed a batch of tokens using the model""" + logits = model(tokens, + attention_mask, + token_types) + return logits + + def state_dict_for_save_checkpoint(self, destination=None, \ + prefix='', keep_vars=False): + """Save dict with state dicts of each of the models.""" + state_dict_ = {} + if self.biencoder_shared_query_context_model: + state_dict_[self._model_key] = \ + self.model.state_dict_for_save_checkpoint(destination, + prefix, + keep_vars) + else: + if self.use_query_model: + state_dict_[self._query_key] = \ + self.query_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + if self.use_context_model: + state_dict_[self._context_key] = \ + self.context_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Load the state dicts of each of the models""" + if self.biencoder_shared_query_context_model: + print_rank_0("Loading shared query-context model") + self.model.load_state_dict(state_dict[self._model_key], \ + strict=strict) + else: + if self.use_query_model: + print_rank_0("Loading query model") + self.query_model.load_state_dict( \ + state_dict[self._query_key], strict=strict) + + if self.use_context_model: + print_rank_0("Loading context model") + self.context_model.load_state_dict( \ + state_dict[self._context_key], strict=strict) + + def init_state_dict_from_bert(self): + """Initialize the state from a pretrained BERT model + on iteration zero of ICT pretraining""" + args = get_args() + + if args.bert_load is None: + print_rank_0("bert-load argument is None") + return + + tracker_filename = get_checkpoint_tracker_filename(args.bert_load) + if not os.path.isfile(tracker_filename): + raise FileNotFoundError("Could not find BERT checkpoint") + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + assert iteration > 0 + + checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading BERT checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + # Load the checkpoint. + try: + state_dict = torch.load(checkpoint_name, map_location='cpu') + except ModuleNotFoundError: + from megatron.fp16_deprecated import loss_scaler + # For backward compatibility. + print_rank_0(' > deserializing using the old code structure ...') + sys.modules['fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + state_dict = torch.load(checkpoint_name, map_location='cpu') + sys.modules.pop('fp16.loss_scaler', None) + sys.modules.pop('megatron.fp16.loss_scaler', None) + except BaseException: + print_rank_0('could not load the BERT checkpoint') + sys.exit() + + checkpoint_version = state_dict.get('checkpoint_version', 0) + + # load the LM state dict into each model + model_dict = state_dict['model']['language_model'] + + if self.biencoder_shared_query_context_model: + self.model.language_model.load_state_dict(model_dict) + fix_query_key_value_ordering(self.model, checkpoint_version) + else: + if self.use_query_model: + self.query_model.language_model.load_state_dict(model_dict) + # give each model the same ict_head to begin with as well + if self.biencoder_projection_dim > 0: + query_proj_state_dict = \ + self.state_dict_for_save_checkpoint()\ + [self._query_key]['projection_enc'] + fix_query_key_value_ordering(self.query_model, checkpoint_version) + + if self.use_context_model: + self.context_model.language_model.load_state_dict(model_dict) + if self.query_model is not None and \ + self.biencoder_projection_dim > 0: + self.context_model.projection_enc.load_state_dict\ + (query_proj_state_dict) + fix_query_key_value_ordering(self.context_model, checkpoint_version) + + +class PretrainedBertModel(MegatronModule): + """BERT-based encoder for queries or contexts used for + learned information retrieval.""" + + def __init__(self, num_tokentypes=2, + parallel_output=True, pre_process=True, post_process=True): + super(PretrainedBertModel, self).__init__() + + args = get_args() + tokenizer = get_tokenizer() + self.pad_id = tokenizer.pad + self.biencoder_projection_dim = args.biencoder_projection_dim + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal( + args.init_method_std, args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process) + + if args.biencoder_projection_dim > 0: + self.projection_enc = get_linear_layer(args.hidden_size, + args.biencoder_projection_dim, + init_method) + self._projection_enc_key = 'projection_enc' + + def forward(self, input_ids, attention_mask, tokentype_ids=None): + extended_attention_mask = attention_mask.unsqueeze(1) + #extended_attention_mask = bert_extended_attention_mask(attention_mask) + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model(input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids) + # This mask will be used in average-pooling and max-pooling + pool_mask = (input_ids == self.pad_id).unsqueeze(2) + + # Taking the representation of the [CLS] token of BERT + pooled_output = lm_output[0, :, :] + + # Converting to float16 dtype + pooled_output = pooled_output.to(lm_output.dtype) + + # Output. + if self.biencoder_projection_dim: + pooled_output = self.projection_enc(pooled_output) + + return pooled_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + if self.biencoder_projection_dim > 0: + state_dict_[self._projection_enc_key] = \ + self.projection_enc.state_dict(destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + print_rank_0("loading pretrained weights") + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + + if self.biencoder_projection_dim > 0: + print_rank_0("loading projection head weights") + self.projection_enc.load_state_dict( + state_dict[self._projection_enc_key], strict=strict) diff --git a/megatron/model/classification.py b/megatron/model/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..d975072f7733c76db13f52b1f1820aad63830914 --- /dev/null +++ b/megatron/model/classification.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classification model.""" + +import torch + +from megatron import get_args, print_rank_last +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + + +class Classification(MegatronModule): + + def __init__(self, + num_classes, + num_tokentypes=2, + pre_process=True, + post_process=True): + super(Classification, self).__init__(share_word_embeddings=False) + args = get_args() + + self.num_classes = num_classes + self.pre_process = pre_process + self.post_process = post_process + init_method = init_method_normal(args.init_method_std) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + pre_process=self.pre_process, + post_process=self.post_process) + + # Multi-choice head. + if self.post_process: + self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) + self.classification_head = get_linear_layer(args.hidden_size, + self.num_classes, + init_method) + self._classification_head_key = 'classification_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, model_input, attention_mask, tokentype_ids=None): + + extended_attention_mask = bert_extended_attention_mask(attention_mask) + input_ids = model_input + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + + if self.post_process: + _, pooled_output = lm_output + classification_output = self.classification_dropout(pooled_output) + classification_logits = self.classification_head(classification_output) + + # Reshape back to separate choices. + classification_logits = classification_logits.view(-1, self.num_classes) + + return classification_logits + return lm_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + state_dict_[self._classification_head_key] \ + = self.classification_head.state_dict( + destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + if self._classification_head_key in state_dict: + self.classification_head.load_state_dict( + state_dict[self._classification_head_key], strict=strict) + else: + print_rank_last('***WARNING*** could not find {} in the checkpoint, ' + 'initializing to random'.format( + self._classification_head_key)) diff --git a/megatron/model/distributed.py b/megatron/model/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..726ea714627f22f319724b33264da50a98d78637 --- /dev/null +++ b/megatron/model/distributed.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from abc import abstractmethod +import math + +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from megatron import get_args +from megatron import mpu +from .module import MegatronModule + + +class MemoryBuffer: + + def __init__(self, numel, numel_padded, dtype): + self.numel = numel + self.numel_padded = numel_padded + self.dtype = dtype + self.data = torch.zeros(self.numel_padded, + dtype=self.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + def zero(self): + """Reset the buffer to zero.""" + self.data.zero_() + + + def get(self, shape, start_index): + """Return a tensor with the input `shape` as a view into the + 1-D data starting at `start_index`.""" + end_index = start_index + shape.numel() + assert end_index <= self.numel, \ + 'requested tensor is out of the buffer range.' + buffer_tensor = self.data[start_index:end_index] + buffer_tensor = buffer_tensor.view(shape) + return buffer_tensor + + + +class DistributedDataParallelBase(MegatronModule, ABC): + """Abstract class for DDP.""" + + def __init__(self, module): + super(DistributedDataParallelBase, self).__init__() + # Keep a pointer to the model. + self.module = module + + + @abstractmethod + def allreduce_gradients(self): + pass + + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + return self.module.state_dict_for_save_checkpoint(destination, prefix, + keep_vars) + + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + + + +class DistributedDataParallel(DistributedDataParallelBase): + """DDP with contiguous buffers options to storre and accumulate gradients. + This class: + - has the potential to reduce memory fragmentation. + - provides the option to do the gradient accumulation + in a type other than the params type (for example fp32) + + Arguments: + module: input model. + accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation + and the gradient all-reduce all in in float32. If this option is + true, we require `use_contiguous_buffers` to be true too. + use_contiguous_buffers: if true, use a contiguous buffer to store the + gradients. + """ + + def __init__(self, module, + accumulate_allreduce_grads_in_fp32, + use_contiguous_buffers): + + super(DistributedDataParallel, self).__init__(module) + + self.accumulate_allreduce_grads_in_fp32 \ + = accumulate_allreduce_grads_in_fp32 + self.use_contiguous_buffers = use_contiguous_buffers + # If we are using fp32-accumulate-allreduce explicitly + # this means we need main grads in a continous buffer. + if self.accumulate_allreduce_grads_in_fp32: + assert self.use_contiguous_buffers + + # =================================== + # Rest of this part applies only to + # the case we use continuous buffers. + # =================================== + self._grad_buffers = None + self._grad_buffer_param_index_map = None + if self.use_contiguous_buffers: + self._grad_buffers = {} + self._grad_buffer_param_index_map = {} + data_parallel_world_size = mpu.get_data_parallel_world_size() + + # Simple function to define buffer type. + def _get_buffer_type(param): + return torch.float if \ + self.accumulate_allreduce_grads_in_fp32 else param.dtype + + # First calculate total number of elements per type. + type_num_elements = {} + for param in self.module.parameters(): + if param.requires_grad: + dtype = _get_buffer_type(param) + type_num_elements[dtype] = type_num_elements.get(dtype, 0) \ + + param.data.nelement() + + # Allocate the buffer. + for dtype, num_elements in type_num_elements.items(): + + # If using distributed optimizer, pad memory buffer to be + # multiple of data_parallel_world_size. (This padding is done + # due to a constraint with the reduce_scatter op, which requires + # all tensors have equal size. See: optimizer.py.) + num_elements_padded = data_parallel_world_size * \ + int(math.ceil(num_elements / data_parallel_world_size)) + + # Allocate grad buffer. + self._grad_buffers[dtype] = MemoryBuffer(num_elements, + num_elements_padded, + dtype) + + # Assume the back prop order is reverse the params order, + # store the start index for the gradients. + for param in self.module.parameters(): + if param.requires_grad: + dtype = _get_buffer_type(param) + type_num_elements[dtype] -= param.data.nelement() + param.main_grad = self._grad_buffers[dtype].get( + param.data.shape, type_num_elements[dtype]) + if dtype not in self._grad_buffer_param_index_map: + self._grad_buffer_param_index_map[dtype] = {} + self._grad_buffer_param_index_map[dtype][param] = ( + type_num_elements[dtype], + type_num_elements[dtype] + param.data.nelement(), + ) + + # Backward hook. + # Accumalation function for the gradients. We need + # to store them so they don't go out of scope. + self.grad_accs = [] + # Loop over all the parameters in the model. + for param in self.module.parameters(): + if param.requires_grad: + # Expand so we get access to grad_fn. + param_tmp = param.expand_as(param) + # Get the gradient accumulator functtion. + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_param_hook(param)) + self.grad_accs.append(grad_acc) + + + def _make_param_hook(self, param): + """Create the all-reduce hook for backprop.""" + # Hook used for back-prop. + def param_hook(*unused): + # Add the gradient to the buffer. + if param.grad is not None: + # The gradient function of linear layers is fused with GEMMs + param.main_grad.add_(param.grad.data) + # Now we can deallocate grad memory. + param.grad = None + return param_hook + + + def zero_grad_buffer(self): + """Set the grad buffer data to zero. Needs to be called at the + begining of each iteration.""" + assert self._grad_buffers is not None, 'buffers are not initialized.' + for _, buffer_ in self._grad_buffers.items(): + buffer_.zero() + + + def broadcast_params(self): + for param in self.module.parameters(): + torch.distributed.broadcast(param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group()) + + + def allreduce_gradients(self): + """Reduce gradients across data parallel ranks.""" + # If we have buffers, simply reduce the data in the buffer. + if self._grad_buffers is not None: + for _, buffer_ in self._grad_buffers.items(): + buffer_.data /= mpu.get_data_parallel_world_size() + torch.distributed.all_reduce( + buffer_.data, group=mpu.get_data_parallel_group()) + else: + # Otherwise, bucketize and all-reduce + buckets = {} + # Pack the buckets. + for param in self.module.parameters(): + if param.requires_grad and param.grad is not None: + tp = param.data.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + param.main_grad = param.grad + + # For each bucket, all-reduce and copy all-reduced grads. + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + coalesced /= mpu.get_data_parallel_world_size() + torch.distributed.all_reduce( + coalesced, group=mpu.get_data_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors( + coalesced, grads)): + buf.copy_(synced) diff --git a/megatron/model/enums.py b/megatron/model/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..492d2c0c6ca8cc4c6b9e2f8dbdfe5424439d4912 --- /dev/null +++ b/megatron/model/enums.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + +class ModelType(enum.Enum): + encoder_or_decoder = 1 + encoder_and_decoder = 2 + +class LayerType(enum.Enum): + encoder = 1 + decoder = 2 + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 diff --git a/megatron/model/fused_bias_gelu.py b/megatron/model/fused_bias_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..207071d6ebc21b4a183fc0f5834b6c1783516f0e --- /dev/null +++ b/megatron/model/fused_bias_gelu.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + +bias_gelu_impl = GeLUFunction.apply diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..53f3fd516a6d149cd4e54407ce6d404b86c4bfcd --- /dev/null +++ b/megatron/model/fused_layer_norm.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This code is copied fron NVIDIA apex: + https://github.com/NVIDIA/apex + with some changes. """ + +import numbers +import torch +from torch.nn.parameter import Parameter +from torch.nn import init +import importlib + +from megatron.mpu import make_viewless_tensor + +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN + HAVE_PERSIST_LAYER_NORM = True +except: + HAVE_PERSIST_LAYER_NORM = False + +global fused_mix_prec_layer_norm_cuda +fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + + return output + + + @staticmethod + def backward(ctx, grad_output): + + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + + +class MixedFusedLayerNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, + no_persist_layer_norm=True, + sequence_parallel=False): + super(MixedFusedLayerNorm, self).__init__() + + global fused_mix_prec_layer_norm_cuda + fused_mix_prec_layer_norm_cuda = importlib.import_module( + "fused_mix_prec_layer_norm_cuda") + + # List of hiddens sizes supported in the persistent layer norm kernel + # If the hidden size is not supported, fall back to the non-persistent + # kernel. + persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096, + 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536] + if normalized_shape not in persist_ln_hidden_sizes or \ + not HAVE_PERSIST_LAYER_NORM: + no_persist_layer_norm = True + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + self.no_persist_layer_norm = no_persist_layer_norm + self.sequence_parallel = sequence_parallel + + # set sequence parallelism flag on weight and bias parameters + setattr(self.weight, 'sequence_parallel', self.sequence_parallel) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) + + + def reset_parameters(self): + + init.ones_(self.weight) + init.zeros_(self.bias) + + + def forward(self, input): + + if self.no_persist_layer_norm: + return FusedLayerNormAffineFunction.apply( + input, self.weight, self.bias, self.normalized_shape, self.eps) + else: + output = FastLayerNormFN.apply( + input, self.weight, self.bias, self.eps) + + # Apex's fast layer norm function outputs a 'view' tensor (i.e., has + # a populated '_base' field). This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + output = make_viewless_tensor(inp = output, + requires_grad = input.requires_grad, + keep_graph = True) + + return output diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..2409edd59f9501b9ba1b7b093086bb775a35cf06 --- /dev/null +++ b/megatron/model/fused_softmax.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from megatron.model.enums import AttnMaskType + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_upper_triang_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( + inputs, scale_t[0] + ) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_upper_triang_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + import scaled_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None + + +class ScaledSoftmax(torch.autograd.Function): + """ + Fused operation which performs following two operations in sequence + 1. Scale the tensor. + 2. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_softmax_cuda.forward( + inputs, scale_t[0] + ) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + assert ( + self.scale is None or softmax_in_fp32 + ), "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and 16 < sk <= 4096 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 4096: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + if mask is not None: + return ScaledMaskedSoftmax.apply(input, mask, scale) + else: + return ScaledSoftmax.apply(input, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + import scaled_masked_softmax_cuda + + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/megatron/model/glm_model.py b/megatron/model/glm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..db59551e64d33b816c6a4c54d664c0809756221b --- /dev/null +++ b/megatron/model/glm_model.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BERT model.""" + +import torch +import torch.nn.functional as F + +from megatron import get_args +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import parallel_lm_logits +from megatron.model.language_model import get_language_model +from megatron.model import LayerNorm +from megatron.model.utils import openai_gelu, erf_gelu +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + +def bert_extended_attention_mask(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = (extended_attention_mask < 0.5) + + return extended_attention_mask + +def bert_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class BertLMHead(MegatronModule): + """Masked LM head for Bert + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + parallel_output: whether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, hidden_size, init_method, + layernorm_epsilon, parallel_output): + + super(BertLMHead, self).__init__() + + args = get_args() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + self.parallel_output = parallel_output + + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel) + setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel) + + self.layernorm = LayerNorm(hidden_size, + eps=layernorm_epsilon, + sequence_parallel=args.sequence_parallel) + self.gelu = torch.nn.functional.gelu + if args.openai_gelu: + self.gelu = openai_gelu + elif args.onnx_safe: + self.gelu = erf_gelu + + def forward(self, hidden_states, word_embeddings_weight): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + output = parallel_lm_logits(hidden_states, + word_embeddings_weight, + self.parallel_output, + bias=self.bias) + return output + +class GlmModel(MegatronModule): + """Bert Language model.""" + + def __init__(self, + num_tokentypes=2, + add_binary_head=True, + parallel_output=True, + pre_process=True, + post_process=True, + add_encoder=True, + add_decoder=True): + super(GlmModel, self).__init__() + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.add_binary_head = add_binary_head + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.cal_encoder_loss = True + self.cal_decoder_loss = True + self.cal_sent_loss = True + + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=self.add_binary_head, + add_encoder=self.add_encoder, + add_decoder=self.add_decoder, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process) + + self.initialize_word_embeddings(init_method_normal) + if self.post_process: + self.lm_head = BertLMHead( + self.word_embeddings_weight().size(0), + args.hidden_size, init_method, args.layernorm_epsilon, parallel_output) + self._lm_head_key = 'lm_head' + self.binary_head = None + if self.add_binary_head: + self.binary_head = get_linear_layer(args.hidden_size, 2, + init_method) + self._binary_head_key = 'binary_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, + encoder_input_ids, + encoder_labels, + decoder_input_ids, + decoder_labels, + encoder_attn_mask, + decoder_attn_mask, + encoder_decoder_attn_mask, + tokentype_ids=None, + lm_labels=None, + sentence_labels=None, + enc_hidden_states=None, + ): + + encoder_position_ids = bert_position_ids(encoder_input_ids) + decoder_position_ids = bert_position_ids(decoder_input_ids) + extended_encoder_attn_mask = bert_extended_attention_mask(encoder_attn_mask) + extended_decoder_attn_mask = bert_extended_attention_mask(decoder_attn_mask) + extended_encoder_decoder_attn_mask = bert_extended_attention_mask(encoder_decoder_attn_mask) + + lm_output = self.language_model(encoder_input_ids, + encoder_position_ids, + extended_encoder_attn_mask, + decoder_input_ids, + decoder_position_ids, + extended_decoder_attn_mask, + extended_encoder_decoder_attn_mask, + tokentype_ids=tokentype_ids, + enc_hidden_states=enc_hidden_states) + + encoder_outputs, decoder_outputs, pooled_outputs = lm_output + + return self.post_language_model_processing(encoder_input_ids, + encoder_outputs, + encoder_labels, + pooled_outputs, + self.lm_head, + self.binary_head, + self.word_embeddings_weight(), + self.fp16_lm_cross_entropy, + decoder_input_ids, + decoder_outputs, + decoder_labels, + decoder_attn_mask, + sentence_labels) + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + state_dict_[self._lm_head_key] \ + = self.lm_head.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process and self.add_binary_head: + state_dict_[self._binary_head_key] \ + = self.binary_head.state_dict(destination, prefix, keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + self.lm_head.load_state_dict( + state_dict[self._lm_head_key], strict=strict) + if self.post_process and self.add_binary_head: + self.binary_head.load_state_dict( + state_dict[self._binary_head_key], strict=strict) + # Load word_embeddings. + if self.post_process and not self.pre_process: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) + + def post_language_model_processing(self, encoder_inputs, + encoder_outputs, + encoder_labels, + pooled_outputs, + lm_head, + binary_head, + logit_weights, + fp16_lm_cross_entropy, + decoder_inputs, + decoder_outputs, + decoder_labels, + attention_mask, + sentence_labels + ): + # Output. + describe_tensor = { + "attention_mask" : attention_mask[0] + } + + sent_loss = None + if self.cal_sent_loss: + binary_logits = binary_head(pooled_outputs) + sent_loss = F.cross_entropy(binary_logits.view(-1, 2).float(), + sentence_labels.view(-1), + ignore_index=-1) + + encoder_loss = None + if self.cal_encoder_loss: + encoder_outputs_logits_ = lm_head(encoder_outputs, logit_weights) + if fp16_lm_cross_entropy: + assert encoder_outputs_logits_.dtype == torch.half + encoder_outputs_logits = encoder_outputs_logits_ + else: + encoder_outputs_logits = encoder_outputs_logits_.float() + encoder_loss = mpu.vocab_parallel_cross_entropy(encoder_outputs_logits, encoder_labels.transpose(0, 1).contiguous()) + describe_tensor["encoder_inputs"] = encoder_inputs[0] + describe_tensor["encoder_outputs"] = encoder_outputs_logits[0].transpose(0, 1) + describe_tensor["encoder_labels"] = encoder_labels[0] + + decoder_loss = None + if self.cal_decoder_loss: + decoder_outputs_logits_ = decoder_outputs + if fp16_lm_cross_entropy: + assert decoder_outputs_logits_.dtype == torch.half + decoder_outputs_logits = decoder_outputs_logits_ + else: + decoder_outputs_logits = decoder_outputs_logits_.float() + decoder_loss = mpu.vocab_parallel_cross_entropy(decoder_outputs_logits, decoder_labels.transpose(0, 1).contiguous()) + describe_tensor["encoder_inputs"] = encoder_inputs[0] + describe_tensor["decoder_inputs"] = decoder_inputs[0] + describe_tensor["decoder_outputs"] = decoder_outputs_logits[0].transpose(0, 1) + describe_tensor["decoder_labels"] = decoder_labels[0] + + return encoder_loss, decoder_loss, sent_loss, describe_tensor \ No newline at end of file diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py new file mode 100644 index 0000000000000000000000000000000000000000..af6b5bf12edc0fdae00e88801402d31ddae72226 --- /dev/null +++ b/megatron/model/gpt_model.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPT-2 model.""" + +import torch + +from megatron import get_args +from megatron import mpu +from .module import MegatronModule + +from .enums import AttnMaskType +from .language_model import parallel_lm_logits +from .language_model import get_language_model +from .utils import init_method_normal +from .utils import scaled_init_method_normal + + +def post_language_model_processing(lm_output, labels, logit_weights, + parallel_output, + fp16_lm_cross_entropy): + + # Output. Format [s b h] + output = parallel_lm_logits( + lm_output, + logit_weights, + parallel_output) + + if labels is None: + # [s b h] => [b s h] + return output.transpose(0,1).contiguous() + else: + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = mpu.vocab_parallel_cross_entropy(output, labels) + else: + loss = mpu.vocab_parallel_cross_entropy(output.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + return loss + + +class GPTModel(MegatronModule): + """GPT-2 Language model.""" + + def __init__(self, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True): + super(GPTModel, self).__init__() + args = get_args() + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.causal, + init_method=init_method_normal(args.init_method_std), + scaled_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + pre_process=self.pre_process, + post_process=self.post_process) + + self.initialize_word_embeddings(init_method_normal) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, input_ids, position_ids, attention_mask, labels=None, + tokentype_ids=None, inference_params=None): + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + inference_params=inference_params) + + if self.post_process: + return post_language_model_processing( + lm_output, labels, + self.word_embeddings_weight(), + self.parallel_output, + self.fp16_lm_cross_entropy) + else: + return lm_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Load word_embeddings. + if self.post_process and not self.pre_process: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) + if self._language_model_key in state_dict: + state_dict = state_dict[self._language_model_key] + self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac3e4a32c51c584323e84ccfba731b9421a706d --- /dev/null +++ b/megatron/model/language_model.py @@ -0,0 +1,563 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer based language model.""" + +import torch +import torch.nn.functional as F + +from megatron import get_args +from megatron import mpu +from .module import MegatronModule +from megatron.model.enums import LayerType, AttnMaskType +from megatron.model.transformer import ParallelTransformer +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal, scaled_init_method_normal + + +def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, + bias=None): + """LM logits using word embedding weights.""" + args = get_args() + # Parallel logits. + if args.async_tensor_model_parallel_allreduce or\ + args.sequence_parallel: + input_parallel = input_ + model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 + async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \ + model_parallel and not args.sequence_parallel + else: + input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) + async_grad_allreduce = False + + # Matrix multiply. + logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( + input_parallel, word_embeddings_weight, bias, + args.gradient_accumulation_fusion, + async_grad_allreduce, args.sequence_parallel) + # Gather if needed. + + if parallel_output: + return logits_parallel + + return mpu.gather_from_tensor_model_parallel_region(logits_parallel) + + +def get_language_model(num_tokentypes, add_pooler, + encoder_attn_mask_type, init_method=None, + scaled_init_method=None, add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + pre_process=True, post_process=True): + """Build language model and return along with the key to save.""" + args = get_args() + + if init_method is None: + init_method = init_method_normal(args.init_method_std) + + if scaled_init_method is None: + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + # Language model. + language_model = TransformerLanguageModel( + init_method, + scaled_init_method, + encoder_attn_mask_type, + num_tokentypes=num_tokentypes, + add_encoder=add_encoder, + add_decoder=add_decoder, + decoder_attn_mask_type=decoder_attn_mask_type, + add_pooler=add_pooler, + pre_process=pre_process, + post_process=post_process + ) + # key used for checkpoints. + language_model_key = 'language_model' + + return language_model, language_model_key + + +class Pooler(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size, init_method): + super(Pooler, self).__init__() + args = get_args() + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.sequence_parallel = args.sequence_parallel + + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [s, b, h] + # sequence_index: index of the token to pool. + + # gather data along sequence dimensions + # same pooler is run on all tensor parallel nodes + if self.sequence_parallel: + hidden_states = mpu.gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=False) + + pooled = hidden_states[sequence_index, :, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled + + +class Embedding(MegatronModule): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + init_method, + num_tokentypes=0): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.init_method = init_method + self.num_tokentypes = num_tokentypes + + args = get_args() + + # Word embeddings (parallel). + self.word_embeddings = mpu.VocabParallelEmbedding( + vocab_size, self.hidden_size, + init_method=self.init_method) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + if args.perform_initialization: + self.init_method(self.position_embeddings.weight) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + self._tokentype_embeddings_key = 'tokentype_embeddings' + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + if args.perform_initialization: + self.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + + self.fp32_residual_connection = args.fp32_residual_connection + self.sequence_parallel = args.sequence_parallel + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + if self.num_tokentypes > 0: + self.tokentype_embeddings.weight.data.fill_(0) + self.tokentype_embeddings.weight.shared = True + + def add_tokentype_embeddings(self, num_tokentypes): + """Add token-type embedding. This function is provided so we can add + token-type embeddings in case the pretrained model does not have it. + This allows us to load the model normally and then add this embedding. + """ + if self.tokentype_embeddings is not None: + raise Exception('tokentype embeddings is already initialized') + if torch.distributed.get_rank() == 0: + print('adding embedding for {} tokentypes'.format(num_tokentypes), + flush=True) + self.num_tokentypes = num_tokentypes + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + args = get_args() + self.init_method(self.tokentype_embeddings.weight) + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + if tokentype_ids is not None: + assert self.tokentype_embeddings is not None + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + else: + # assert self.tokentype_embeddings is None + # print("self.tokentype_embeddings is None, but tokentype_ids is not None ") + pass + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.sequence_parallel: + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + with mpu.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + + return embeddings + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. + for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] \ + = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, + strict=strict) + else: + print('***WARNING*** expected tokentype embeddings in the ' + 'checkpoint but could not find it', flush=True) + + +class TransformerLanguageModel(MegatronModule): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + init_method, + output_layer_init_method, + encoder_attn_mask_type, + num_tokentypes=0, + add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + add_pooler=False, + pre_process=True, + post_process=True): + super(TransformerLanguageModel, self).__init__() + args = get_args() + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = args.hidden_size + self.num_tokentypes = num_tokentypes + self.init_method = init_method + self.add_encoder = add_encoder + self.encoder_attn_mask_type = encoder_attn_mask_type + self.add_decoder = add_decoder + self.decoder_attn_mask_type = decoder_attn_mask_type + self.add_pooler = add_pooler + self.encoder_hidden_state = None + + # Embeddings. + if self.pre_process: + self.embedding = Embedding(self.hidden_size, + args.padded_vocab_size, + args.max_position_embeddings, + args.hidden_dropout, + self.init_method, + self.num_tokentypes) + self._embedding_key = 'embedding' + + # Transformer. + # Encoder (usually set to True, False if part of an encoder-decoder + # architecture and in encoder-only stage). + if self.add_encoder: + self.encoder = ParallelTransformer( + self.init_method, + output_layer_init_method, + self_attn_mask_type=self.encoder_attn_mask_type, + trans_layer_type="encoder", + pre_process=self.pre_process, + post_process=self.post_process + ) + self._encoder_key = 'encoder' + else: + self.encoder = None + + # Decoder (usually set to False, True if part of an encoder-decoder + # architecture and in decoder-only stage). + if self.add_decoder: + self.decoder = ParallelTransformer( + self.init_method, + output_layer_init_method, + layer_type=LayerType.decoder, + self_attn_mask_type=self.decoder_attn_mask_type, + trans_layer_type="decoder", + pre_process=self.pre_process, + post_process=self.post_process) + self._decoder_key = 'decoder' + else: + self.decoder = None + + if self.post_process: + # Pooler. + if self.add_pooler: + self.pooler = Pooler(self.hidden_size, self.init_method) + self._pooler_key = 'pooler' + + def set_input_tensor(self, input_tensor): + """ See megatron.model.transformer.set_input_tensor()""" + + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + if self.add_encoder and self.add_decoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with both encoder and decoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with only encoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_decoder: + if len(input_tensor) == 2: + self.decoder.set_input_tensor(input_tensor[0]) + self.encoder_hidden_state = input_tensor[1] + elif len(input_tensor) == 1: + self.decoder.set_input_tensor(None) + self.encoder_hidden_state = input_tensor[0] + else: + raise Exception('input_tensor must have either length 1 or 2') + else: + raise Exception('Stage must have at least either encoder or decoder') + + def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, + dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, + enc_dec_attn_mask=None, tokentype_ids=None, + inference_params=None, + pooling_sequence_index=0, + enc_hidden_states=None, output_enc_hidden=False): + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Run encoder. + if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + inference_params=inference_params) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + if self.post_process: + if self.add_pooler: + pooled_output = self.pooler(encoder_output, + pooling_sequence_index) + + # output_enc_hidden refers to when we just need the encoder's + # output. For example, it is helpful to compute + # similarity between two sequences by average pooling + if not self.add_decoder or output_enc_hidden: + if self.add_pooler and self.post_process: + return encoder_output, pooled_output + else: + return encoder_output + + # Decoder embedding. + if self.pre_process: + decoder_input = self.embedding(dec_input_ids, + dec_position_ids) + else: + decoder_input = None + + # Run decoder. + decoder_output = self.decoder( + decoder_input, + dec_attn_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params) + + if self.add_pooler and self.post_process: + return decoder_output, encoder_output, pooled_output + else: + return decoder_output, encoder_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load.""" + + state_dict_ = {} + if self.pre_process: + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.add_encoder: + state_dict_[self._encoder_key] \ + = self.encoder.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + if self.add_pooler: + state_dict_[self._pooler_key] \ + = self.pooler.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.add_decoder: + state_dict_[self._decoder_key] \ + = self.decoder.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Embedding. + if self.pre_process: + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Encoder. + if self.add_encoder: + if self._encoder_key in state_dict: + state_dict_ = state_dict[self._encoder_key] + # For backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # For backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # For backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", + ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.encoder.load_state_dict(state_dict_, strict=strict) + + # Pooler. + if self.post_process: + if self.add_pooler: + assert 'pooler' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.pooler.load_state_dict(state_dict[self._pooler_key], + strict=strict) + # Decoder. + if self.add_decoder: + assert 'decoder' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.decoder.load_state_dict(state_dict[self._decoder_key], + strict=strict) diff --git a/megatron/model/module.py b/megatron/model/module.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a1ef05d2faa350cb1d7bdbd9f2fb820d1f478f --- /dev/null +++ b/megatron/model/module.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Module""" + +import torch +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from megatron import get_args +from megatron import mpu + + +_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) +_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) + + + +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared + + + +class MegatronModule(torch.nn.Module): + """Megatron specific extensions of torch Module with support + for pipelining.""" + + def __init__(self, share_word_embeddings=True): + super(MegatronModule, self).__init__() + self.share_word_embeddings = share_word_embeddings + + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """Use this function to override the state dict for + saving checkpoints.""" + return self.state_dict(destination, prefix, keep_vars) + + + def word_embeddings_weight(self): + if self.pre_process: + return self.language_model.embedding.word_embeddings.weight + else: + if not self.share_word_embeddings: + raise Exception('word_embeddings_weight() called for last ' + 'stage, but share_word_embeddings is false') + return self.word_embeddings.weight + + + def initialize_word_embeddings(self, init_method_normal): + args = get_args() + if not self.share_word_embeddings: + raise Exception('initialize_word_embeddings() was called but ' + 'share_word_embeddings is false') + + # This function just initializes the word embeddings in the final stage + # when we are using pipeline parallelism. Nothing to do if we aren't + # using pipeline parallelism. + if args.pipeline_model_parallel_size == 1: + return + + # Parameters are shared between the word embeddings layers, and the + # heads at the end of the model. In a pipelined setup with more than + # one stage, the initial embedding layer and the head are on different + # workers, so we do the following: + # 1. Create a second copy of word_embeddings on the last stage, with + # initial parameters of 0.0. + # 2. Do an all-reduce between the first and last stage to ensure that + # the two copies of word_embeddings start off with the same + # parameter values. + # 3. In the training loop, before an all-reduce between the grads of + # the two word_embeddings layers to ensure that every applied weight + # update is the same on both stages. + if mpu.is_pipeline_last_stage() and \ + not self.pre_process: + assert not mpu.is_pipeline_first_stage() + self._word_embeddings_for_head_key = 'word_embeddings_for_head' + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.word_embeddings = mpu.VocabParallelEmbedding( + args.padded_vocab_size, args.hidden_size, + init_method=init_method_normal(args.init_method_std)) + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + + # Zero out initial weights for decoder embedding. + # NOTE: We don't currently support T5 with the interleaved schedule. + if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \ + self.pre_process: + self.language_model.embedding.zero_parameters() + + if not torch.distributed.is_initialized(): + if not getattr(MegatronModule, "embedding_warning_printed", False): + print("WARNING! Distributed processes aren't initialized, so " + "word embeddings in the last layer are not initialized. " + "If you are just manipulating a model this is fine, but " + "this needs to be handled manually. If you are training " + "something is definitely wrong.") + MegatronModule.embedding_warning_printed = True + return + + # Ensure that first and last stages have the same initial parameter + # values. + if mpu.is_rank_in_embedding_group(): + torch.distributed.all_reduce(self.word_embeddings_weight().data, + group=mpu.get_embedding_group()) + + # Ensure that encoder(first stage) and decoder(split stage) position + # embeddings have the same initial parameter values + # NOTE: We don't currently support T5 with the interleaved schedule. + if mpu.is_rank_in_position_embedding_group() and \ + args.pipeline_model_parallel_split_rank is not None: + # TODO: Support tokentype embedding. + self.language_model.embedding.cuda() + position_embeddings = self.language_model.embedding.position_embeddings + torch.distributed.all_reduce(position_embeddings.weight.data, + group=mpu.get_position_embedding_group()) + + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` + #is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_float16(val, float16_convertor): + """Convert fp32 `val` to fp16/bf16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, _FLOAT_TYPES): + val = float16_convertor(val) + return val + return conversion_helper(val, half_conversion) + + +def float16_to_fp32(val): + """Convert fp16/bf16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): + val = val.float() + return val + return conversion_helper(val, float_conversion) + + + +class Float16Module(MegatronModule): + + def __init__(self, module, args): + super(Float16Module, self).__init__() + + if args.fp16: + self.add_module('module', module.half()) + def float16_convertor(val): + return val.half() + elif args.bf16: + self.add_module('module', module.bfloat16()) + def float16_convertor(val): + return val.bfloat16() + else: + raise Exception('should not be here') + + self.float16_convertor = float16_convertor + + + def set_input_tensor(self, input_tensor): + return self.module.set_input_tensor(input_tensor) + + + def forward(self, *inputs, **kwargs): + if mpu.is_pipeline_first_stage(): + inputs = fp32_to_float16(inputs, self.float16_convertor) + outputs = self.module(*inputs, **kwargs) + if mpu.is_pipeline_last_stage(): + outputs = float16_to_fp32(outputs) + return outputs + + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + return self.module.state_dict_for_save_checkpoint(destination, prefix, + keep_vars) + + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py new file mode 100644 index 0000000000000000000000000000000000000000..c43bd969c0ded670e3eb794a8e630e36e7409eee --- /dev/null +++ b/megatron/model/multiple_choice.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiple choice model.""" + +import torch + +from megatron import get_args, print_rank_last +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + + +class MultipleChoice(MegatronModule): + + def __init__(self, + num_tokentypes=2, + pre_process=True, + post_process=True): + super(MultipleChoice, self).__init__(share_word_embeddings=False) + args = get_args() + + init_method = init_method_normal(args.init_method_std) + self.pre_process = pre_process + self.post_process = post_process + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + pre_process=self.pre_process, + post_process=self.post_process) + + # Multi-choice head. + if self.post_process: + self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) + self.multichoice_head = get_linear_layer(args.hidden_size, 1, + init_method) + self._multichoice_head_key = 'multichoice_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, model_input, attention_mask, tokentype_ids=None): + + # [batch, choices, sequence] --> [batch * choices, sequence] --> + # transformer --> [batch, choices] --> softmax + + # Ensure the shape is [batch-size, choices, sequence] + assert len(attention_mask.shape) == 3 + num_choices = attention_mask.shape[1] + + # Reshape and treat choice dimension the same as batch. + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + extended_attention_mask = bert_extended_attention_mask(attention_mask) + + input_ids = model_input + # Do the same as attention_mask for input_ids, tokentype_ids + assert len(input_ids.shape) == 3 + assert len(tokentype_ids.shape) == 3 + input_ids = input_ids.view(-1, input_ids.size(-1)) + tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1)) + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + if self.post_process: + _, pooled_output = lm_output + multichoice_output = self.multichoice_dropout(pooled_output) + multichoice_logits = self.multichoice_head(multichoice_output) + + # Reshape back to separate choices. + multichoice_logits = multichoice_logits.view(-1, num_choices) + + return multichoice_logits + return lm_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + state_dict_[self._multichoice_head_key] \ + = self.multichoice_head.state_dict( + destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + if self._multichoice_head_key in state_dict: + self.multichoice_head.load_state_dict( + state_dict[self._multichoice_head_key], strict=strict) + else: + print_rank_last('***WARNING*** could not find {} in the checkpoint, ' + 'initializing to random'.format( + self._multichoice_head_key)) diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5730a85e36b39e2076fd6b8f8f73e26ae06e9f08 --- /dev/null +++ b/megatron/model/realm_model.py @@ -0,0 +1,204 @@ +import os +import torch + +from megatron import get_args, print_rank_0 +from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name +from megatron.model import BertModel +from .module import MegatronModule +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.language_model import get_language_model +from megatron.model.utils import scaled_init_method_normal +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids + + +def general_ict_model_provider(only_query_model=False, only_block_model=False): + """Build the model.""" + args = get_args() + assert args.ict_head_size is not None, \ + "Need to specify --ict-head-size to provide an ICTBertModel" + assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \ + "Model parallel size > 1 not supported for ICT" + + print_rank_0('building ICTBertModel...') + + # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes + model = ICTBertModel( + ict_head_size=args.ict_head_size, + num_tokentypes=2, + parallel_output=True, + only_query_model=only_query_model, + only_block_model=only_block_model) + + return model + + +class ICTBertModel(MegatronModule): + """Bert-based module for Inverse Cloze task.""" + def __init__(self, + ict_head_size, + num_tokentypes=1, + parallel_output=True, + only_query_model=False, + only_block_model=False): + super(ICTBertModel, self).__init__() + bert_kwargs = dict( + ict_head_size=ict_head_size, + num_tokentypes=num_tokentypes, + parallel_output=parallel_output + ) + assert not (only_block_model and only_query_model) + self.use_block_model = not only_query_model + self.use_query_model = not only_block_model + + if self.use_query_model: + # this model embeds (pseudo-)queries - Embed_input in the paper + self.query_model = IREncoderBertModel(**bert_kwargs) + self._query_key = 'question_model' + + if self.use_block_model: + # this model embeds evidence blocks - Embed_doc in the paper + self.block_model = IREncoderBertModel(**bert_kwargs) + self._block_key = 'context_model' + + def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask): + """Run a forward pass for each of the models and return the respective embeddings.""" + query_logits = self.embed_query(query_tokens, query_attention_mask) + block_logits = self.embed_block(block_tokens, block_attention_mask) + return query_logits, block_logits + + def embed_query(self, query_tokens, query_attention_mask): + """Embed a batch of tokens using the query model""" + if self.use_query_model: + query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0) + query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) + return query_ict_logits + else: + raise ValueError("Cannot embed query without query model.") + + def embed_block(self, block_tokens, block_attention_mask): + """Embed a batch of tokens using the block model""" + if self.use_block_model: + block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0) + block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) + return block_ict_logits + else: + raise ValueError("Cannot embed block without block model.") + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """Save dict with state dicts of each of the models.""" + state_dict_ = {} + if self.use_query_model: + state_dict_[self._query_key] \ + = self.query_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + if self.use_block_model: + state_dict_[self._block_key] \ + = self.block_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Load the state dicts of each of the models""" + if self.use_query_model: + print("Loading ICT query model", flush=True) + self.query_model.load_state_dict( + state_dict[self._query_key], strict=strict) + + if self.use_block_model: + print("Loading ICT block model", flush=True) + self.block_model.load_state_dict( + state_dict[self._block_key], strict=strict) + + def init_state_dict_from_bert(self): + """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining""" + args = get_args() + tracker_filename = get_checkpoint_tracker_filename(args.bert_load) + if not os.path.isfile(tracker_filename): + raise FileNotFoundError("Could not find BERT load for ICT") + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + assert iteration > 0 + + checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + try: + state_dict = torch.load(checkpoint_name, map_location='cpu') + except BaseException: + raise ValueError("Could not load checkpoint") + + # load the LM state dict into each model + model_dict = state_dict['model']['language_model'] + self.query_model.language_model.load_state_dict(model_dict) + self.block_model.language_model.load_state_dict(model_dict) + + # give each model the same ict_head to begin with as well + query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head'] + self.block_model.ict_head.load_state_dict(query_ict_head_state_dict) + + +class IREncoderBertModel(MegatronModule): + """BERT-based encoder for queries or blocks used for learned information retrieval.""" + def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True): + super(IREncoderBertModel, self).__init__() + args = get_args() + + self.ict_head_size = ict_head_size + self.parallel_output = parallel_output + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method) + + self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method) + self._ict_head_key = 'ict_head' + + def forward(self, input_ids, attention_mask, tokentype_ids=None): + extended_attention_mask = bert_extended_attention_mask( + attention_mask, next(self.language_model.parameters()).dtype) + position_ids = bert_position_ids(input_ids) + + lm_output, pooled_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids) + + # Output. + ict_logits = self.ict_head(pooled_output) + return ict_logits, None + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + state_dict_[self._ict_head_key] \ + = self.ict_head.state_dict(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + self.ict_head.load_state_dict( + state_dict[self._ict_head_key], strict=strict) + + diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2d30848b9ed01004a3acb7773d44fcc899675a --- /dev/null +++ b/megatron/model/t5_model.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5 model.""" + +import torch + +from megatron import ( + get_args, + mpu +) +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import parallel_lm_logits, get_language_model +from megatron.model.transformer import LayerNorm +from megatron.model.utils import ( + openai_gelu, + get_linear_layer, + init_method_normal, + scaled_init_method_normal +) +from .module import MegatronModule + + +def t5_extended_attention_mask(attention_mask_list): + + def attn_mask_postprocess(attn_mask): + # [b, 1, s, s] + extended_attention_mask = attn_mask.unsqueeze(1) + return extended_attention_mask + + return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list] + + +def t5_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class T5LMHead(MegatronModule): + """Masked LM head for T5 + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + parallel_output: wether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, parallel_output): + super(T5LMHead, self).__init__() + + args = get_args() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + self.bias.model_parallel = True + self.bias.partition_dim = 0 + self.bias.stride = 1 + self.parallel_output = parallel_output + + def forward(self, hidden_states, word_embeddings_weight): + output = parallel_lm_logits(hidden_states, + word_embeddings_weight, + self.parallel_output, + bias=self.bias) + return output + + +class T5Model(MegatronModule): + """T5 Language model.""" + + def __init__(self, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True, + add_encoder=True, + add_decoder=True): + super(T5Model, self).__init__() + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.parallel_output = parallel_output + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + add_encoder=add_encoder, + add_decoder=add_decoder, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process) + + self.initialize_word_embeddings(init_method_normal) + + if self.post_process and self.add_decoder: + self.lm_head = T5LMHead( + self.word_embeddings_weight().size(0), + parallel_output) + self._lm_head_key = 'lm_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, + encoder_input_ids, + decoder_input_ids, + encoder_attn_mask, + decoder_attn_mask, + encoder_decoder_attn_mask, + tokentype_ids=None, + lm_labels=None, + enc_hidden_states=None): + + # Converting the attention masks to proper parameter settings + encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask( + [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]) + + encoder_position_ids = t5_position_ids(encoder_input_ids) + decoder_position_ids = t5_position_ids(decoder_input_ids) + + lm_output = self.language_model(encoder_input_ids, + encoder_position_ids, + encoder_attn_mask, + decoder_input_ids, + decoder_position_ids, + decoder_attn_mask, + encoder_decoder_attn_mask, + tokentype_ids=tokentype_ids, + enc_hidden_states=enc_hidden_states) + import pdb;pdb.set_trace() + if self.post_process and self.add_decoder: + decoder_output, encoder_output = lm_output + # Output. [s, b, h] + lm_logits = self.lm_head(decoder_output, + self.word_embeddings_weight()) + + if lm_labels is None: + # [s b h] => [b s h] + return lm_logits.transpose(0,1).contiguous() + else: + # [b s] => [s b] + lm_labels = lm_labels.transpose(0,1).contiguous() + if self.fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), + lm_labels) + # [s b] => [b s] + lm_loss = lm_loss.transpose(0,1).contiguous() + return lm_loss + elif self.add_decoder and not self.add_encoder: + decoder_output, encoder_output = lm_output + return decoder_output + else: + encoder_output = lm_output + return encoder_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process and self.add_decoder: + state_dict_[self._lm_head_key] \ + = self.lm_head.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process and self.add_decoder: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process and self.add_decoder: + self.lm_head.load_state_dict(state_dict[self._lm_head_key], + strict=strict) + # Load word embeddings. + if self.post_process and not self.pre_process and self.add_decoder: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3039d5a7bc363d03fed2ae4b4fdeafdd8a60c815 --- /dev/null +++ b/megatron/model/transformer.py @@ -0,0 +1,957 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer.""" +import math +from contextlib import nullcontext +import torch +import torch.nn.functional as F + +from megatron import get_timers, get_args, get_global_memory_buffer +from megatron import mpu +from .module import MegatronModule +from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType +from megatron.model import LayerNorm +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.fused_bias_gelu import bias_gelu_impl +from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu + + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + p: number of model parallel partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters +""" + +class DropPath(MegatronModule): + """Drop paths (Stochastic Depth) per sample + (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=0.): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_state): + if self.drop_prob == 0. or not self.training: + return hidden_state + keep_prob = 1 - self.drop_prob + # work with diff dim tensors, not just 2D ConvNets + shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1) + random_tensor = keep_prob + \ + torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) + random_tensor.floor_() # binarize + output = hidden_state.div(keep_prob) * random_tensor + return output + + +class ParallelMLP(MegatronModule): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, init_method, output_layer_init_method): + super(ParallelMLP, self).__init__() + args = get_args() + + # Project to 4h. + self.dense_h_to_4h = mpu.ColumnParallelLinear( + args.hidden_size, + args.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) + + self.bias_gelu_fusion = args.bias_gelu_fusion + self.activation_func = F.gelu + if args.openai_gelu: + self.activation_func = openai_gelu + elif args.onnx_safe: + self.activation_func = erf_gelu + + # Project back to h. + self.dense_4h_to_h = mpu.RowParallelLinear( + args.ffn_hidden_size, + args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + +class SwitchMLP(MegatronModule): + """ + Routes input to one of N MLP "experts" + """ + def __init__(self, init_method, output_layer_init_method): + super(SwitchMLP, self).__init__() + args = get_args() + self.router = torch.nn.Linear(args.hidden_size, args.num_experts) + self.experts = torch.nn.ModuleList() + for i in range(args.num_experts): + self.experts.append(ParallelMLP(init_method, output_layer_init_method)) + + def forward(self, hidden_states): + # hidden_states: [s, b, h] + s = hidden_states.size(0) + b = hidden_states.size(1) + h = hidden_states.size(2) + route = self.router(hidden_states) + route = torch.nn.functional.softmax(route, dim=2) + max_prob, max_ind = torch.max(route, dim=2) + max_prob = torch.unsqueeze(max_prob, 2) # [s b 1] + + # TODO (rprenger) TODO this could be made easier to read + # Converting [s, b, h] to [s*b, h]. + # Each vector could be routed differently + hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h] + max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1] + max_ind = max_ind.view(-1) # [s*b] + + output_total = torch.empty_like(hidden_states) + output_bias_total = torch.empty_like(hidden_states) + #TODO (rprenger) This does each expert in serial, but it could be parallelized + + for expert_num, expert in enumerate(self.experts): + local_indices = (max_ind == expert_num).nonzero() + hidden = hidden_states[local_indices,:] + output, output_bias = expert(hidden) + output_bias = output_bias.expand_as(output) + output_total[local_indices,:] = output + output_bias_total[local_indices,:] = output_bias + + output_total = output_total*max_prob + output_bias_total = output_bias_total*max_prob + output_total = output_total.view(s, b, h) + output_bias_total = output_bias_total.view(s, b, h) + + return output_total, output_bias_total + + +class CoreAttention(MegatronModule): + + def __init__(self, layer_number, + attn_mask_type=AttnMaskType.padding): + super(CoreAttention, self).__init__() + args = get_args() + self.fp16 = args.fp16 + self.bf16 = args.bf16 + + self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = args.sequence_parallel + + projection_size = args.kv_channels * args.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = mpu.divide( + projection_size, args.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + args.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, + self.attn_mask_type, + args.masked_softmax_fusion, + attention_mask_func, + self.attention_softmax_in_fp32, + coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(args.attention_dropout) + + def forward(self, query_layer, key_layer, + value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = get_global_memory_buffer().get_tensor( + (output_size[0]*output_size[1], output_size[2], output_size[3]), + query_layer.dtype, "mpu") + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class ParallelAttention(MegatronModule): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, init_method, + output_layer_init_method, layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding): + super(ParallelAttention, self).__init__() + args = get_args() + self.layer_number = max(1, layer_number) + self.attention_type = attention_type + self.attn_mask_type = attn_mask_type + self.params_dtype = args.params_dtype + + projection_size = args.kv_channels * args.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = mpu.divide( + projection_size, args.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + args.num_attention_heads, world_size) + + # Strided linear layer. + if attention_type == AttnType.self_attn: + self.query_key_value = mpu.ColumnParallelLinear( + args.hidden_size, + 3 * projection_size, + gather_output=False, + init_method=init_method) + else: + assert attention_type == AttnType.cross_attn + self.query = mpu.ColumnParallelLinear( + args.hidden_size, + projection_size, + gather_output=False, + init_method=init_method) + + self.key_value = mpu.ColumnParallelLinear( + args.hidden_size, + 2 * projection_size, + gather_output=False, + init_method=init_method) + + self.core_attention = CoreAttention(self.layer_number, + self.attn_mask_type) + self.checkpoint_core_attention = args.recompute_granularity == 'selective' + + # Output. + self.dense = mpu.RowParallelLinear( + projection_size, + args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def _checkpointed_attention_forward(self, query_layer, key_layer, + value_layer, attention_mask): + """Forward method with activation checkpointing.""" + def custom_forward(*inputs): + query_layer = inputs[0] + key_layer = inputs[1] + value_layer = inputs[2] + attention_mask = inputs[3] + output_ = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + return output_ + + hidden_states = mpu.checkpoint( + custom_forward, + False, query_layer, key_layer, value_layer, attention_mask) + + return hidden_states + + def _allocate_memory(self, inference_max_sequence_len, batch_size): + return torch.empty( + inference_max_sequence_len, + batch_size, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, + encoder_output=None, inference_params=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_len + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + + if self.attention_type == AttnType.self_attn: + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, + key_layer, + value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) + else: + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv_layer, _ = self.key_value(encoder_output) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, + value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query_layer, _ = self.query(hidden_states) + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_tensor_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + if inference_params: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[ + :sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[ + :sequence_end, batch_start:batch_end, ...] + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention: + context_layer = self._checkpointed_attention_forward( + query_layer, key_layer, value_layer, attention_mask) + else: + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class ParallelTransformerLayer(MegatronModule): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, init_method, output_layer_init_method, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0.): + args = get_args() + + super(ParallelTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_layernorm \ + = args.apply_residual_connection_post_layernorm + + self.bf16 = args.bf16 + self.fp32_residual_connection = args.fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm( + args.hidden_size, + eps=args.layernorm_epsilon, + no_persist_layer_norm=args.no_persist_layer_norm, + sequence_parallel=args.sequence_parallel) + + # Self attention. + self.self_attention = ParallelAttention( + init_method, + output_layer_init_method, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) + self.hidden_dropout = args.hidden_dropout + self.bias_dropout_fusion = args.bias_dropout_fusion + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm( + args.hidden_size, + eps=args.layernorm_epsilon, + no_persist_layer_norm=args.no_persist_layer_norm, + sequence_parallel=args.sequence_parallel) + + if self.layer_type == LayerType.decoder: + self.inter_attention = ParallelAttention( + init_method, + output_layer_init_method, + layer_number, + attention_type=AttnType.cross_attn) + # Layernorm on the attention output. + self.post_inter_attention_layernorm = LayerNorm( + args.hidden_size, + eps=args.layernorm_epsilon, + no_persist_layer_norm=args.no_persist_layer_norm, + sequence_parallel=args.sequence_parallel) + + # MLP + if args.num_experts is not None: + self.mlp = SwitchMLP(init_method, output_layer_init_method) + else: + self.mlp = ParallelMLP(init_method, output_layer_init_method) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + inference_params=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + if self.drop_path is None: + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, + attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + else: + out = torch.nn.functional.dropout(attention_output + attention_bias, + p=self.hidden_dropout, + training=self.training) + layernorm_input = residual + self.drop_path(out) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + if self.layer_type == LayerType.decoder: + attention_output, attention_bias = \ + self.inter_attention(layernorm_output, + enc_dec_attn_mask, + encoder_output=encoder_output) + # residual connection + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, + attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + + # Layer norm post the decoder attention + layernorm_output = self.post_inter_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + if self.drop_path is None: + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias.expand_as(residual), + residual, + self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = mpu.make_viewless_tensor(inp = output, + requires_grad = output.requires_grad, + keep_graph = True) + + else: + out = torch.nn.functional.dropout(mlp_output + mlp_bias, + p=self.hidden_dropout, + training=self.training) + output = residual + self.drop_path(out) + + return output + + +class NoopTransformerLayer(MegatronModule): + """A single 'no-op' transformer layer. + + The sole purpose of this layer is for when a standalone embedding layer + is used (i.e., args.standalone_embedding_stage == True). In this case, + zero transformer layers are assigned when pipeline rank == 0. Additionally, + when virtual pipeline rank >= 1, zero total model parameters are created + (virtual rank 0 contains the input embedding). This results in the model's + input and output tensors being the same, which causes an error when + performing certain memory optimiations on the output tensor (e.g., + deallocating it). Thus, this layer disconnects the input from the output + via a clone. Since ranks containing a no-op layer are generally under- + utilized (both compute and memory), there's no worry of any performance + degredation. + """ + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +class ParallelTransformer(MegatronModule): + """Transformer class.""" + + def __init__(self, init_method, output_layer_init_method, + layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + trans_layer_type="encoder", + post_layer_norm=True, + pre_process=True, post_process=True, + drop_path_rate=0.0): + super(ParallelTransformer, self).__init__() + args = get_args() + + self.layer_type = layer_type + self.model_type = args.model_type + self.bf16 = args.bf16 + self.fp32_residual_connection = args.fp32_residual_connection + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + self.drop_path_rate = drop_path_rate + + # Store activation checkpoiting flag. + self.recompute_granularity = args.recompute_granularity + self.recompute_method = args.recompute_method + self.recompute_num_layers = args.recompute_num_layers + self.distribute_saved_activations = \ + args.distribute_saved_activations and not args.sequence_parallel + + self.sequence_parallel = args.sequence_parallel + + # Number of layers. + if trans_layer_type == "encoder": + self.num_layers = mpu.get_num_layers( + args, args.model_type == ModelType.encoder_and_decoder) + elif trans_layer_type == "decoder": + self.num_layers = mpu.get_num_layers_decoder( + args, args.model_type == ModelType.encoder_and_decoder) + else: + print("No support layer type") + import sys;sys.exit(0) + + self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)] + + # Transformer layers. + def build_layer(layer_number): + return ParallelTransformerLayer( + init_method, + output_layer_init_method, + layer_number, + layer_type=layer_type, + self_attn_mask_type=self_attn_mask_type, + drop_path_rate=self.drop_path_rates[layer_number - 1]) + if args.virtual_pipeline_model_parallel_size is not None: + assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \ + 'num_layers_per_stage must be divisible by ' \ + 'virtual_pipeline_model_parallel_size' + assert args.model_type != ModelType.encoder_and_decoder + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( + args.num_layers // args.virtual_pipeline_model_parallel_size) + \ + (mpu.get_pipeline_model_parallel_rank() * self.num_layers) + else: + # Each stage gets a contiguous set of layers. + if args.model_type == ModelType.encoder_and_decoder and \ + mpu.get_pipeline_model_parallel_world_size() > 1: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + if layer_type == LayerType.encoder: + offset = pipeline_rank * self.num_layers + else: + num_ranks_in_enc = args.pipeline_model_parallel_split_rank + offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers + else: + offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers + + if self.num_layers == 0: + # When a standalone embedding stage is used (e.g., + # args.standalone_embedding_stage == True), virtual pipeline ranks + # on pipeline rank 0 will have zero transformer layers assigned to + # them. This results in the model's input and output tensors to be + # the same, which will cause failure for certain output tensor + # optimizations (e.g., pipeline output deallocation). To remedy + # this, we assign a 'no-op' layer on these ranks, which will + # disconnect the input tensor from the output tensor. + self.num_layers = 1 + self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ]) + else: + self.layers = torch.nn.ModuleList( + [build_layer(i + 1 + offset) for i in range(self.num_layers)]) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_layernorm = LayerNorm( + args.hidden_size, + eps=args.layernorm_epsilon, + no_persist_layer_norm=args.no_persist_layer_norm, + sequence_parallel=args.sequence_parallel) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def _checkpointed_forward(self, hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask): + """Forward method with activation checkpointing.""" + def custom(start, end): + def custom_forward(*inputs): + x_ = inputs[0] + attention_mask = inputs[1] + encoder_output = inputs[2] + enc_dec_attn_mask = inputs[3] + for index in range(start, end): + layer = self._get_layer(index) + x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask) + return x_ + return custom_forward + + if self.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + l = 0 + while l < self.num_layers: + hidden_states = mpu.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) + l += self.recompute_num_layers + + elif self.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + for l in range(self.num_layers): + if l < self.recompute_num_layers: + hidden_states = mpu.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) + else: + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + inference_params=None): + # hidden_states: [s, b, h] + + # Checks. + if inference_params: + assert self.recompute_granularity is None, \ + 'inference does not work with activation checkpointing' + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = mpu.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + if self.sequence_parallel: + rng_context = mpu.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + with rng_context: + # Forward pass. + if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask) + else: + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states = layer( + hidden_states, + attention_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params) + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states diff --git a/megatron/model/utils.py b/megatron/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f26b0685340a0cddd28cd70a34e3c8b2321c306f --- /dev/null +++ b/megatron/model/utils.py @@ -0,0 +1,67 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for models.""" + +import math + +import torch + +from megatron import get_args + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +def get_linear_layer(rows, columns, init_method): + """Simple linear layer with weight initialization.""" + layer = torch.nn.Linear(rows, columns) + if get_args().perform_initialization: + init_method(layer.weight) + with torch.no_grad(): + layer.bias.zero_() + return layer + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) +def openai_gelu(x): + return gelu_impl(x) + +#This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter +@torch.jit.script +def erf_gelu(x): + return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) diff --git a/megatron/model/vision/classification.py b/megatron/model/vision/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..41e26d3ecd0d685712a4c9595a53c72a4376c420 --- /dev/null +++ b/megatron/model/vision/classification.py @@ -0,0 +1,97 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Vision Transformer(VIT) model.""" + +import torch +from torch.nn.init import trunc_normal_ +from megatron import get_args +from megatron.model.utils import get_linear_layer +from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead +from megatron.model.vision.mit_backbone import mit_b3_avg +from megatron.model.module import MegatronModule + +class VitClassificationModel(MegatronModule): + """Vision Transformer Model.""" + + def __init__(self, num_classes, finetune=False, + pre_process=True, post_process=True): + super(VitClassificationModel, self).__init__() + args = get_args() + + self.hidden_size = args.hidden_size + self.num_classes = num_classes + self.finetune = finetune + self.pre_process = pre_process + self.post_process = post_process + self.backbone = VitBackbone( + pre_process=self.pre_process, + post_process=self.post_process, + single_token_output=True + ) + + if self.post_process: + if not self.finetune: + self.head = VitMlpHead(self.hidden_size, self.num_classes) + else: + self.head = get_linear_layer( + self.hidden_size, + self.num_classes, + torch.nn.init.zeros_ + ) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.backbone.set_input_tensor(input_tensor) + + def forward(self, input): + hidden_states = self.backbone(input) + + if self.post_process: + hidden_states = self.head(hidden_states) + + return hidden_states + + +class MitClassificationModel(MegatronModule): + """Mix vision Transformer Model.""" + + def __init__(self, num_classes, + pre_process=True, post_process=True): + super(MitClassificationModel, self).__init__() + args = get_args() + + self.hidden_size = args.hidden_size + self.num_classes = num_classes + + self.backbone = mit_b3_avg() + self.head = torch.nn.Linear(512, num_classes) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, torch.nn.Linear) and m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + pass + + def forward(self, input): + hidden_states = self.backbone(input) + hidden_states = self.head(hidden_states) + + return hidden_states diff --git a/megatron/model/vision/dino.py b/megatron/model/vision/dino.py new file mode 100644 index 0000000000000000000000000000000000000000..651271a6fc8f76ffd943e57157e8dbc9293bf573 --- /dev/null +++ b/megatron/model/vision/dino.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + +# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py +# reworked/refactored some parts to make it run in Megatron. +import math +import apex +import einops +import torch +import numpy as np +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +from megatron import get_args, print_rank_0 +from megatron.model.utils import get_linear_layer +from megatron.model.vision.vit_backbone import VitBackbone +from megatron.model.module import MegatronModule +from megatron.model.vision.mit_backbone import mit_b5_avg +from megatron.model.vision.esvit_swin_backbone import get_swin + + +class DINOLoss(torch.nn.Module): + def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, + warmup_teacher_temp_epochs, nepochs, student_temp=0.1, + center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.ncrops = ncrops + self.register_buffer("center", torch.zeros(1, out_dim)) + # we apply a warm up for the teacher temperature because + # a too high temperature makes the training instable at the beginning + self.teacher_temp_schedule = np.concatenate(( + np.linspace(warmup_teacher_temp, + teacher_temp, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp + )) + self.teacher_temp = teacher_temp + + def forward(self, student_output, teacher_output, iteration): + """ + Cross-entropy between softmax outputs of the teacher + and student network. + """ + args = get_args() + student_out = student_output / self.student_temp + student_out = student_out.chunk(self.ncrops) + + epoch = iteration // args.iter_per_epoch + + # teacher centering and sharpening + temp = self.teacher_temp_schedule[epoch] + teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) + + teacher_out = teacher_out.detach().chunk(2) + + total_loss = 0 + n_loss_terms = 0 + for iq, q in enumerate(teacher_out): + for v in range(len(student_out)): + if v == iq: + # we skip cases where student and teacher operate on the same view + continue + loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) + total_loss += loss.mean() + n_loss_terms += 1 + total_loss /= n_loss_terms + self.update_center(teacher_output) + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + """ + Update center used for teacher output. + """ + batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + torch.distributed.all_reduce(batch_center) + batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size()) + self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) + +class DINOHead(torch.nn.Module): + def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3): + super().__init__() + args = get_args() + hidden_dim = args.dino_head_hidden_size + bottleneck_dim = args.dino_bottleneck_size + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = torch.nn.Linear(in_dim, bottleneck_dim) + else: + layers = [torch.nn.Linear(in_dim, hidden_dim)] + layers.append(torch.nn.GELU()) + for _ in range(nlayers - 2): + layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) + layers.append(torch.nn.GELU()) + layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = torch.nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, torch.nn.Linear) and m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = torch.nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x + + +class MultiCropWrapper(MegatronModule): + + """ + Perform forward pass separately on each resolution input. + The inputs corresponding to a single resolution are clubbed and single + forward is run on the same resolution inputs. Hence we do several + forward passes = number of different resolutions used. We then + concatenate all the output features and run the head forward on these + concatenated features. + """ + def __init__(self, backbone, head): + super(MultiCropWrapper, self).__init__() + # disable layers dedicated to ImageNet labels classification + #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity() + self.backbone = backbone + self.head = head + + def forward(self, x): + # convert to list + if not isinstance(x, list): + x = [x] + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in x]), + return_counts=True, + )[1], 0) + + start_idx = 0 + for end_idx in idx_crops: + _out = self.backbone(torch.cat(x[start_idx: end_idx])) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + start_idx = end_idx + # Run the head forward on the concatenated features. + if self.training: + return self.head(output) + else: + return output + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, + warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = \ + np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) \ + * (1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + +def get_student_backbone_and_num_features(pre_process=True, post_process=True): + args = get_args() + + if args.vision_backbone_type == 'vit': + student = VitBackbone(pre_process=pre_process, + post_process=post_process, + drop_path_rate=0.1, + single_token_output=True) + num_features = args.hidden_size + elif args.vision_backbone_type == 'mit': + student = mit_b5_avg(drop_path_rate=0.1) + num_features = 512 + elif args.vision_backbone_type == 'swin': + student = get_swin() + num_features = student.num_features + else: + raise Exception('{} vision backbone is not supported.'.format( + args.vision_backbone_type)) + + return student, num_features + +def get_teacher_backbone_and_num_features(pre_process=True, post_process=True): + args = get_args() + + if args.vision_backbone_type == 'vit': + teacher = VitBackbone(pre_process=pre_process, + post_process=post_process, + single_token_output=True) + num_features = args.hidden_size + elif args.vision_backbone_type == 'mit': + teacher = mit_b5_avg(drop_path_rate=0.0) + num_features = 512 + elif args.vision_backbone_type == 'swin': + teacher = get_swin(is_teacher=True) + num_features = teacher.num_features + else: + raise Exception('{} vision backbone is not supported.'.format( + args.vision_backbone_type)) + return teacher, num_features + + +class DINOPretrainModel(MegatronModule): + def __init__(self, pre_process=True, post_process=True): + super(DINOPretrainModel, self).__init__() + args = get_args() + self.out_dim = 65536 + + self.dino_loss = DINOLoss( + self.out_dim, + args.dino_local_crops_number + 2, + args.dino_warmup_teacher_temp, + args.dino_teacher_temp, + args.dino_warmup_teacher_temp_epochs, + 300, + ) + + self.pre_process = pre_process + self.post_process = post_process + self.momentum_teacher = 0.996 + + student_backbone, num_features = \ + get_student_backbone_and_num_features(pre_process, post_process) + + self.student = MultiCropWrapper( + student_backbone, + DINOHead(num_features, self.out_dim, + norm_last_layer=args.dino_norm_last_layer) + ) + + self.momentum_schedule = cosine_scheduler( + self.momentum_teacher, 1, + args.train_iters // args.iter_per_epoch, + args.iter_per_epoch + ) + + teacher_backbone, num_features = \ + get_teacher_backbone_and_num_features(pre_process, post_process) + self.teacher = MultiCropWrapper( + teacher_backbone, + DINOHead(num_features, self.out_dim) + ) + self.teacher.load_state_dict(self.student.state_dict()) + + for p in self.teacher.parameters(): + if hasattr(p, "requires_grad") and p.requires_grad is not None: + p.requires_grad = False + + def set_input_tensor(self, tensor): + pass + + def forward(self, input): + student_output = None + if self.training: + student_output = self.student(input) + teacher_output = self.teacher(input[:2]) + else: + teacher_output = self.teacher(input) + return student_output, teacher_output + + def cancel_gradients_last_layer(self, iteration): + args = get_args() + epoch = iteration // args.iter_per_epoch + if epoch < args.dino_freeze_last_layer: + for n, p in self.student.named_parameters(): + if "last_layer" in n: + p.grad = None + + def update_momentum(self, iteration): + with torch.no_grad(): + m = self.momentum_schedule[iteration] + for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()): + param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) + diff --git a/megatron/model/vision/esvit_swin_backbone.py b/megatron/model/vision/esvit_swin_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..70aee3db429bf63680b7c19b4d5acfe25ee2edf5 --- /dev/null +++ b/megatron/model/vision/esvit_swin_backbone.py @@ -0,0 +1,849 @@ +# Copyright (c) 2021 Microsoft +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Modified by Chunyuan Li (chunyl@microsoft.com) +# Swin Transformer +# -------------------------------------------------------- + +import os +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +import torch.distributed as dist +from torch.nn.init import trunc_normal_ +from megatron.model.transformer import DropPath +from megatron import get_args +from megatron.model import LayerNorm +import numpy as np +from math import sqrt + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super(Mlp, self).__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super(WindowAttention, self).__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2 Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).type(attn.type()) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn_out = attn + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn_out + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + @staticmethod + def compute_macs(module, input, output): + B, N, C = input[0].shape + + module.__flops__ += module.flops(N) * B + + +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = input_resolution[0] + self.W = input_resolution[1] + + self.attn_mask_dict = {} + + + def create_attn_mask(self, H, W): + # calculate attention mask for SW-MSA + + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + + def forward(self, x): + B, L, C = x.shape + H = int(sqrt(L)) + W = H + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + + if H in self.attn_mask_dict.keys(): + attn_mask = self.attn_mask_dict[H] + else: + self.attn_mask_dict[H] = self.create_attn_mask(self.H, self.W).to(x.device) + attn_mask = self.attn_mask_dict[H] + + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn(x_windows, attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + H = int(sqrt(L)) + W = H + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + x, _ = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def forward_with_features(self, x): + fea = [] + for blk in self.blocks: + x, _ = blk(x) + fea.append(x) + if self.downsample is not None: + x = self.downsample(x) + return x, fea + + def forward_with_attention(self, x): + attns = [] + for blk in self.blocks: + x, attn = blk(x) + attns.append(attn) + if self.downsample is not None: + x = self.downsample(x) + return x, attns + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input channels. + num_classes (int): Number of classes for classification head. + embed_dim (int): Embedding dimension. + depths (tuple(int)): Depth of Swin Transformer layers. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + # todo: to be implemented + return {'relative_position_bias_table'} + + def forward(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x_region = self.norm(x) # B L C + x = self.avgpool(x_region.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + return x + + + def forward_feature_maps(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x_grid = self.norm(x) # B L C + x = self.avgpool(x_grid.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + return x, x_grid + + + def forward_selfattention(self, x, n=1): + # n=1 return the last layer attn map; otherwise return attn maps in all layers + + + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + if n==1: + return self.forward_last_selfattention(x) + else: + return self.forward_all_selfattention(x) + + def forward_last_selfattention(self, x): + + for i, layer in enumerate(self.layers): + if i < len(self.layers) - 1: + x = layer(x) + else: + x, attns = layer.forward_with_attention(x) + return attns[-1] + + def forward_all_selfattention(self, x): + attn_out = [] + + for layer in self.layers: + x, attns = layer.forward_with_attention(x) + attn_out += attns + + return attn_out + + + def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False, depth=[]): + + num_blks = sum(depth) + start_idx = num_blks - n + + sum_cur = 0 + for i, d in enumerate(depth): + sum_cur_new = sum_cur + d + if start_idx >= sum_cur and start_idx < sum_cur_new: + start_stage = i + start_blk = start_idx - sum_cur + sum_cur = sum_cur_new + + + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # we will return the averaged token features from the `n` last blocks + # note: there is no [CLS] token in Swin Transformer + output = [] + s = 0 + for i, layer in enumerate(self.layers): + x, fea = layer.forward_with_features(x) + + if i >= start_stage: + for x_ in fea[start_blk:]: + + if i == len(self.layers)-1: # use the norm in the last stage + x_ = self.norm(x_) + + x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1) # B C + # print(f'Stage {i}, x_avg {x_avg.shape}') + output.append(x_avg) + + start_blk = 0 + + return torch.cat(output, dim=-1) + + + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + if dist.get_rank() == 0: + print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}") + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + def init_weights(self, pretrained='', pretrained_layers=[], verbose=True): + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained, map_location='cpu') + logging.info(f'=> loading pretrained model {pretrained}') + model_dict = self.state_dict() + pretrained_dict = { + k: v for k, v in pretrained_dict.items() + if k in model_dict.keys() + } + need_init_state_dict = {} + for k, v in pretrained_dict.items(): + need_init = ( + k.split('.')[0] in pretrained_layers + or pretrained_layers[0] is '*' + or 'relative_position_index' not in k + or 'attn_mask' not in k + ) + + if need_init: + if verbose: + logging.info(f'=> init {k} from {pretrained}') + + if 'relative_position_bias_table' in k and v.size() != model_dict[k].size(): + relative_position_bias_table_pretrained = v + relative_position_bias_table_current = model_dict[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if nH1 != nH2: + logging.info(f"Error in loading {k}, passing") + else: + if L1 != L2: + logging.info( + '=> load_pretrained: resized variant: {} to {}' + .format((L1, nH1), (L2, nH2)) + ) + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) + + if 'absolute_pos_embed' in k and v.size() != model_dict[k].size(): + absolute_pos_embed_pretrained = v + absolute_pos_embed_current = model_dict[k] + _, L1, C1 = absolute_pos_embed_pretrained.size() + _, L2, C2 = absolute_pos_embed_current.size() + if C1 != C1: + logging.info(f"Error in loading {k}, passing") + else: + if L1 != L2: + logging.info( + '=> load_pretrained: resized variant: {} to {}' + .format((1, L1, C1), (1, L2, C2)) + ) + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) + absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) + absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( + absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') + v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2) + + need_init_state_dict[k] = v + self.load_state_dict(need_init_state_dict, strict=False) + + def freeze_pretrained_layers(self, frozen_layers=[]): + for name, module in self.named_modules(): + if ( + name.split('.')[0] in frozen_layers + or '.'.join(name.split('.')[0:2]) in frozen_layers + or (len(frozen_layers) > 0 and frozen_layers[0] is '*') + ): + for _name, param in module.named_parameters(): + param.requires_grad = False + logging.info( + '=> set param {} requires grad to False' + .format(name) + ) + for name, param in self.named_parameters(): + if ( + name.split('.')[0] in frozen_layers + or (len(frozen_layers) > 0 and frozen_layers[0] is '*') + and param.requires_grad is True + ): + param.requires_grad = False + logging.info( + '=> set param {} requires grad to False' + .format(name) + ) + return self + + +def get_swin(is_teacher=False): + args = get_args() + + if args.swin_backbone_type == "tiny": + embed_dim = 96 + depths = [2, 2, 6, 2] + num_heads = [3, 6, 12, 24] + drop_path_rate = 0.1 + elif args.swin_backbone_type == 'h3': + embed_dim = 384 + depths = [2, 2, 18, 2] + num_heads = [6, 12, 24, 48] + drop_path_rate = 0.2 + else: + embed_dim = 128 + depths = [2, 2, 18, 2] + num_heads = [4, 8, 16, 32] + drop_path_rate = 0.2 + + swin = SwinTransformer( + img_size=224, + in_chans=3, + num_classes=1000, + patch_size=4, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=7, + mlp_ratio=4, + qkv_bias=True, + drop_rate=0, + attn_drop_rate=0, + drop_path_rate=(0.0 if is_teacher else drop_path_rate), + norm_layer=partial(LayerNorm, eps=1e-6), + ape=False, + patch_norm=True, + ) + + return swin + diff --git a/megatron/model/vision/inpainting.py b/megatron/model/vision/inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..e44debe4d633ef35c7312755b9c7fbcd403f119f --- /dev/null +++ b/megatron/model/vision/inpainting.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +i +import math +import apex +import einops +import torch +import torch.nn.functional as F +from megatron import get_args, print_rank_0 +from megatron.model.utils import get_linear_layer +from megatron.model.vision.vit_backbone import VitBackbone +from megatron.model.module import MegatronModule +from megatron.model.vision.mit_backbone import mit_b3 +from megatron.model.vision.utils import resize_ + + +class VitInpaintingModel(MegatronModule): + + def __init__(self, pre_process=True, post_process=True): + super(VitInpaintingModel, self).__init__() + args = get_args() + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = args.hidden_size + self.backbone = VitBackbone( + pre_process=self.pre_process, + post_process=self.post_process, + class_token=False, + ) + self.patch_dim = args.patch_dim + self.img_h = args.img_h + self.img_w = args.img_w + self.seq_length = args.seq_length + # full mask + + if self.post_process: + self.linear_decoder = get_linear_layer( + self.hidden_size, + self.backbone.flatten_dim, + torch.nn.init.zeros_ + ) + + def set_input_tensor(self, input_tensor): + self.backbone.set_input_tensor(input_tensor) + + def forward(self, input): + + hidden_states = self.backbone(input) + + if not self.post_process: + return hidden_states + decoded_output = self.linear_decoder(hidden_states) + output = einops.rearrange( + decoded_output, + "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", + p1=self.patch_dim, + p2=self.patch_dim, + h=self.img_h//self.patch_dim, + w=self.img_w//self.patch_dim, + ) + + return output + + +class MLP(torch.nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = torch.nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class MitInpaintingModel(MegatronModule): + """Mix vision Transformer Model.""" + + def __init__(self, pre_process=True, post_process=True): + super(MitInpaintingModel, self).__init__() + self.pre_process = pre_process + self.post_process = post_process + + args = get_args() + self.patch_dim = args.patch_dim + self.img_h = args.img_h + self.img_w = args.img_w + self.flatten_dim = self.patch_dim * self.patch_dim * 3 + self.backbone = mit_b3() + + self.in_channels = [64, 128, 320, 512] + self.embedding_dim = 768 + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) + + self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False) + self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim) + self.dropout = torch.nn.Dropout2d(0.1) + + self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + pass + + def forward(self, input): + c1, c2, c3, c4 = self.backbone(input) + + n, _, h, w = c4.shape + _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) + _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) + _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) + _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) + + _c = torch.cat([_c4, _c3, _c2, _c1], dim=1) + _c = self.conv_fuse(_c) + + x = self.norm(_c) + x = F.relu(x, inplace=True) + x = self.dropout(x) + + x = self.linear_pred(x) + + output = einops.rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.patch_dim, + p2=self.patch_dim, + h=self.img_h//self.patch_dim, + w=self.img_w//self.patch_dim, + ) + + return output diff --git a/megatron/model/vision/knn_monitor.py b/megatron/model/vision/knn_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a7588008d40733f636b006aa00e3d23531caf7 --- /dev/null +++ b/megatron/model/vision/knn_monitor.py @@ -0,0 +1,128 @@ +import torch.nn.functional as F +import torch +from megatron import print_rank_0, get_args, mpu +from megatron.data.vit_dataset import ClassificationTransform +from megatron.data.image_folder import ImageFolder + +_FEATURE_BANK = None + + +def build_data_loader(dataset, drop_last=True, shuffle=False): + """Data loader. Note that batch-size is the local (per GPU) batch-size.""" + # Sampler. + args = get_args() + micro_batch_size = 16 + num_workers = args.num_workers + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, + drop_last=drop_last, shuffle=shuffle + ) + + # Data loader. Note that batch size is the per GPU batch size. + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=micro_batch_size, + sampler=sampler, + shuffle=False, + num_workers=num_workers, + drop_last=not drop_last, + pin_memory=True, + ) + return data_loader + + +def compute_feature_bank(model): + args = get_args() + global _FEATURE_BANK + feature_bank = [] + feature_label = [] + + train_ds = ImageFolder( + root=args.data_path[0], + transform=ClassificationTransform((args.img_h, args.img_w), train=False), + data_per_class_fraction=1.0 + ) + classes = len(train_ds.classes) + dataloader = build_data_loader(train_ds) + + for m in model: + m.eval() + + with torch.no_grad(): + for i, batch in enumerate(dataloader): + images = batch[0].cuda().contiguous() + labels = batch[1].cuda().contiguous() + student_feature, teacher_feature = model[0](images) + feature = F.normalize(teacher_feature.float(), dim=1) + feature_bank.append(feature) + feature_label.append(labels) + + for m in model: + m.train() + + # [N', D] + feature_bank = torch.cat(feature_bank, dim=0).contiguous() + feature_label = torch.cat(feature_label, dim=0).contiguous() + + feature_banks = [torch.zeros_like(feature_bank) + for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather(feature_banks, + feature_bank, + group=mpu.get_data_parallel_group()) + + assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()], + feature_bank)) + + feature_labels = [torch.zeros_like(feature_label) + for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather(feature_labels, + feature_label, + group=mpu.get_data_parallel_group()) + + # [D, N] + feature_banks = torch.cat(feature_banks, dim=0).t().contiguous() + # [N] + feature_labels = torch.cat(feature_labels, dim=0).contiguous() + print_rank_0("feature_banks size is {}".format(feature_banks.size())) + print_rank_0("feature labels size is {}".format(feature_labels.size())) + + _FEATURE_BANK = (feature_banks, feature_labels, classes) + + +def get_feature_bank(): + global _FEATURE_BANK + assert _FEATURE_BANK is not None + return _FEATURE_BANK + + +# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 +# implementation follows http://github.com/zhirongw/lemniscate.pytorch and +# https://github.com/leftthomas/SimCLR +def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): + # compute cos similarity between each feature vector and feature bank ---> [B, N] + sim_matrix = torch.mm(feature, feature_bank) + # [B, K] + sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) + # [B, K] + sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), + dim=-1, + index=sim_indices) + sim_weight = (sim_weight / knn_t).exp() + + # counts for each class + one_hot_label = torch.zeros(feature.size(0) * knn_k, + classes, + device=sim_labels.device) + # [B*K, C] + one_hot_label = one_hot_label.scatter(dim=-1, + index=sim_labels.view(-1, 1), + value=1.0) + # weighted score ---> [B, C] + pred_scores = torch.sum( + one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), + dim=1) + + pred_labels = pred_scores.argsort(dim=-1, descending=True) + return pred_labels diff --git a/megatron/model/vision/mit_backbone.py b/megatron/model/vision/mit_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..c67ca2c62bbd380751998954be16fdb93fef38f4 --- /dev/null +++ b/megatron/model/vision/mit_backbone.py @@ -0,0 +1,420 @@ +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# +# This work is licensed under the NVIDIA Source Code License +# found in the LICENSE file in the root directory of this +# source tree. +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch.nn.init import trunc_normal_ +from megatron.model.transformer import DropPath +from megatron.model import LayerNorm + + +class Mlp(nn.Module): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.output_avg = output_avg + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + def forward_features(self, x): + B = x.shape[0] + outs = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + x = blk(x, H, W) + x = self.norm1(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + x = blk(x, H, W) + x = self.norm2(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + x = blk(x, H, W) + x = self.norm3(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + x = blk(x, H, W) + x = self.norm4(x) + if not self.output_avg: + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs + + def forward(self, x): + x = self.forward_features(x) + + if self.output_avg: + x = x[3].mean(dim=1) + + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b0, self).__init__( + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b1, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b2, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b3, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +class mit_b3_avg(MixVisionTransformer): + def __init__(self, drop_path_rate=0.1, **kwargs): + super(mit_b3_avg, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True) + +class mit_b4(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b4, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +class mit_b5(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b5, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +class mit_b5_avg(MixVisionTransformer): + def __init__(self, drop_path_rate=0.1, **kwargs): + super(mit_b5_avg, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True) + diff --git a/megatron/model/vision/swin_backbone.py b/megatron/model/vision/swin_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..9a622c7070f5a8335b56ce01398e3e464e3fef6a --- /dev/null +++ b/megatron/model/vision/swin_backbone.py @@ -0,0 +1,625 @@ +# Copyright (c) 2021 Microsoft +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Swin Transformer +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from math import sqrt + +from megatron import get_args +from functools import partial + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = input_resolution[0] + self.W = input_resolution[1] + + self.attn_mask_dict = {} + + def create_attn_mask(self, H, W): + # calculate attention mask for SW-MSA + + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + + def forward(self, x): + B, L, C = x.shape + H = int(sqrt(L)) + W = H + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_b4_ds = x + if self.downsample is not None: + x = self.downsample(x) + return x_b4_ds, x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True, + use_checkpoint=False, output_avg=False, **kwargs): + super().__init__() + + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.output_avg = output_avg + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + outs = [] + + for i, layer in enumerate(self.layers): + px, x = layer(x) + b, n, c = px.shape + + if i != len(self.layers) - 1 or not self.output_avg: + px = px.permute(0, 2, 1).contiguous() + px = px.reshape(b, c, h, w) + # is this a fair assumption ?? i think it's baked into the architecture + h, w = h//2, w//2 + outs.append(px) + + if self.output_avg: + return outs[-1].mean(dim=1) + + return outs + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def get_swin(drop_path_rate=0.3, output_avg=False): + args = get_args() + + window_size = 7 + embed_dim = 128 + depths = [2, 2, 18, 2] + num_heads = [4, 8, 16, 32] + swin = SwinTransformer( + img_size=(args.img_h, args.img_w,), + in_chans=3, + patch_size=args.patch_dim, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + drop_path_rate=drop_path_rate, + output_avg=output_avg, + ) + + return swin + diff --git a/megatron/model/vision/utils.py b/megatron/model/vision/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4068912c8bb234eff54d6b4feae499f7e8ab30c --- /dev/null +++ b/megatron/model/vision/utils.py @@ -0,0 +1,27 @@ +import warnings +import torch +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + if isinstance(size, torch.Size): + size = tuple(int(x) for x in size) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/megatron/model/vision/vit_backbone.py b/megatron/model/vision/vit_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..b29a5e478d0c97753dc98d54666e7f2c54337f90 --- /dev/null +++ b/megatron/model/vision/vit_backbone.py @@ -0,0 +1,260 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Vision Transformer(VIT) model.""" + +import math +import einops +import torch +import apex +import torch.nn.functional as F +from megatron import get_args +from megatron.model.transformer import ParallelTransformer +from megatron.model.utils import ( + get_linear_layer, + init_method_normal, + scaled_init_method_normal, +) +from megatron.model.module import MegatronModule + +CLASS_TOKEN_LENGTH = 8 + +class VitMlpHead(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size, num_classes): + super(VitMlpHead, self).__init__() + self.dense_in = torch.nn.Linear(hidden_size, hidden_size) + self.relu = torch.nn.ReLU() + self.dense_out = torch.nn.Linear(hidden_size, num_classes) + torch.nn.init.constant_(self.dense_out.bias, -10) + + def forward(self, hidden_states): + # hidden_states: [b, 1, h] + # sequence_index: index of the token to pool. + dense_in_result = self.dense_in(hidden_states) + tanh_result = torch.tanh(dense_in_result) + dense_out_result = self.dense_out(tanh_result) + return dense_out_result + + +def isPerfectSquare(x): + if(x >= 0): + sr = math.sqrt(x) + return (int(sr) * int(sr) == x) + return False + + +def twod_interpolate_position_embeddings_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + + args = get_args() + num_patches_per_dim_h = args.img_h // args.patch_dim + num_patches_per_dim_w = args.img_w // args.patch_dim + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + hidden_size = args.hidden_size + + key = prefix + "weight" + + assert key in state_dict + if key in state_dict: + input_param = state_dict[key] + + input_seq_len = input_param.shape[0] + assert(isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH)) + input_has_class_token = not isPerfectSquare(input_seq_len) + num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len + num_tok_output = num_patches + output_has_class_token = args.class_token_present + + # update input_param and load it to state_dict[key] + if input_has_class_token: + input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :] + input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :] + else: + input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size) + input_param_grid = input_param + + assert input_param.shape[1] == hidden_size + + if num_tok_input != num_tok_output: + + gs_input = int(math.sqrt(num_tok_input)) + gs_new = (num_patches_per_dim_h, num_patches_per_dim_w) + + input_param_grid = input_param_grid.transpose(0, 1).contiguous() + input_param_grid = input_param_grid.reshape( + (1, -1, gs_input, gs_input) + ) + input_param_grid = input_param_grid.float() + scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input) + + input_param_grid = F.interpolate( + input_param_grid, scale_factor=scale_factor, mode="bilinear" + ) + + input_param_grid = input_param_grid.half() + input_param_grid = input_param_grid.reshape((-1, num_tok_output)) + input_param_grid = input_param_grid.transpose(0, 1).contiguous() + + assert input_param_grid.shape[1] == hidden_size + + input_param = input_param_grid + assert ( + input_param.shape[0] == num_tok_output + and input_param.shape[1] == hidden_size + ) + + if output_has_class_token: + input_param = torch.cat((input_param_tok, input_param), dim=0) + + state_dict[key] = input_param + + +class VitBackbone(MegatronModule): + """Vision Transformer Model.""" + + def __init__(self, + pre_process=True, + post_process=True, + class_token=True, + single_token_output=False, + post_layer_norm=True, + drop_path_rate=0.0): + super(VitBackbone, self).__init__(share_word_embeddings=False) + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + if args.init_method_xavier_uniform: + self.init_method = torch.nn.init.xavier_uniform_ + self.scaled_init_method = torch.nn.init.xavier_uniform_ + else: + self.init_method = init_method_normal(args.init_method_std) + self.scaled_init_method = scaled_init_method_normal( + args.init_method_std, args.num_layers + ) + + self.pre_process = pre_process + self.post_process = post_process + self.class_token = class_token + self.post_layer_norm = post_layer_norm + self.hidden_size = args.hidden_size + self.patch_dim = args.patch_dim + self.img_h = args.img_h + self.img_w = args.img_w + self.micro_batch_size = args.micro_batch_size + self.single_token_output = single_token_output + self.drop_path_rate = drop_path_rate + + assert self.img_h % self.patch_dim == 0 + assert self.img_w % self.patch_dim == 0 + self.num_patches_per_dim_h = self.img_h // self.patch_dim + self.num_patches_per_dim_w = self.img_w // self.patch_dim + self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w + self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0) + self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels + self.input_tensor = None + self.position_ids = None + + if self.pre_process: + # cls_token + if self.class_token: + self.cls_token = torch.nn.Parameter( + torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size) + ) + torch.nn.init.zeros_(self.cls_token) + self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda() + + # Linear encoder + self.linear_encoder = torch.nn.Linear( + self.flatten_dim, self.hidden_size + ) + + # embedding + self.position_embeddings = torch.nn.Embedding( + self.seq_length, self.hidden_size + ) + init_method_normal(args.init_method_std)( + self.position_embeddings.weight + ) + + args.class_token_present = self.class_token + self.position_embeddings._register_load_state_dict_pre_hook( + twod_interpolate_position_embeddings_hook + ) + + self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout) + + # Transformer + self.transformer = ParallelTransformer( + self.init_method, + self.scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=self.post_layer_norm, + drop_path_rate=self.drop_path_rate + ) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.transformer.set_input_tensor(input_tensor) + + def forward(self, input): + + if self.pre_process: + rearranged_input = einops.rearrange( + input, + "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", + p1=self.patch_dim, + p2=self.patch_dim, + ) + + assert rearranged_input.dtype == torch.half + encoder_output = self.linear_encoder(rearranged_input) + + concatenated_tokens = encoder_output + if self.class_token: + cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1) + concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1) + + token_embeddings = concatenated_tokens + \ + self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]]) + hidden_states = self.embedding_dropout(token_embeddings) + else: + hidden_states = input + + hidden_states = self.transformer(hidden_states, None) + + if self.single_token_output: + hidden_states = hidden_states[:,0,:] + + return hidden_states + diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52d24422549e8dc5d8184ac185b042128a32d5b4 --- /dev/null +++ b/megatron/mpu/__init__.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model parallel utility interface.""" + +from .cross_entropy import vocab_parallel_cross_entropy + +from .data import broadcast_data + +from .initialize import is_unitialized +from .initialize import destroy_model_parallel +from .initialize import get_data_parallel_group +from .initialize import get_data_parallel_rank +from .initialize import get_data_parallel_world_size +from .initialize import get_embedding_group +from .initialize import get_position_embedding_group +from .initialize import get_model_parallel_group +from .initialize import get_tensor_model_parallel_group +from .initialize import get_pipeline_model_parallel_group +from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank +from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank +from .initialize import is_pipeline_first_stage, is_pipeline_last_stage +from .initialize import is_rank_in_embedding_group +from .initialize import is_rank_in_position_embedding_group +from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split +from .initialize import is_pipeline_stage_at_split +from .initialize import get_num_layers +from .initialize import get_num_layers_decoder +from .initialize import get_tensor_model_parallel_src_rank +from .initialize import get_data_parallel_src_rank +from .initialize import get_pipeline_model_parallel_first_rank +from .initialize import get_pipeline_model_parallel_last_rank +from .initialize import get_pipeline_model_parallel_next_rank +from .initialize import get_pipeline_model_parallel_prev_rank +from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size +from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size +from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank +from .initialize import initialize_model_parallel +from .initialize import model_parallel_is_initialized + +from .layers import LinearWithGradAccumulationAndAsyncCommunication +from .layers import ColumnParallelLinear +from .layers import RowParallelLinear +from .layers import VocabParallelEmbedding +from .layers import (set_tensor_model_parallel_attributes, + set_defaults_if_not_set_tensor_model_parallel_attributes, + copy_tensor_model_parallel_attributes) + +from .mappings import copy_to_tensor_model_parallel_region +from .mappings import reduce_from_tensor_model_parallel_region +from .mappings import scatter_to_tensor_model_parallel_region +from .mappings import gather_from_tensor_model_parallel_region +from .mappings import scatter_to_sequence_parallel_region +from .mappings import gather_from_sequence_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region + +from .random import checkpoint +from .random import get_cuda_rng_tracker +from .random import model_parallel_cuda_manual_seed +from .random import gather_split_1d_tensor +from .random import split_tensor_into_1d_equal_chunks +from .random import make_viewless_tensor +from .random import assert_viewless_tensor +from .random import safely_set_viewless_tensor_data + +from .utils import divide +from .utils import split_tensor_along_last_dim diff --git a/megatron/mpu/cross_entropy.py b/megatron/mpu/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..8c790cd3e1a2e7fd6f108e1e719b01e53f651aff --- /dev/null +++ b/megatron/mpu/cross_entropy.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from .initialize import get_tensor_model_parallel_group +from .initialize import get_tensor_model_parallel_rank +from .initialize import get_tensor_model_parallel_world_size +from .utils import VocabUtility + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + + @staticmethod + def forward(ctx, vocab_parallel_logits, target): + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=get_tensor_model_parallel_group()) + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indecies + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start_index, vocab_end_index = get_vocab_range( + partition_vocab_size, rank, world_size) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group()) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group()) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= ( + 1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +def vocab_parallel_cross_entropy(vocab_parallel_logits, target): + """Helper function for the cross entropy.""" + return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) diff --git a/megatron/mpu/data.py b/megatron/mpu/data.py new file mode 100644 index 0000000000000000000000000000000000000000..dd57a8ffc088a9ac4ab808685b3b0f318818bb3b --- /dev/null +++ b/megatron/mpu/data.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .initialize import get_tensor_model_parallel_group +from .initialize import get_tensor_model_parallel_rank +from .initialize import get_tensor_model_parallel_src_rank + + +_MAX_DATA_DIM = 5 + + +def _check_data_types(keys, data, target_dtype): + """Check that all the keys have the same target data type.""" + for key in keys: + assert data[key].dtype == target_dtype, '{} has data type {} which '\ + 'is different than {}'.format(key, data[key].dtype, target_dtype) + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if get_tensor_model_parallel_rank() == 0: + offset = 0 + for key in keys: + assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = torch.cuda.LongTensor(sizes) + torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(), + group=get_tensor_model_parallel_group()) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data disctionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, + data) + + # Pack on rank zero. + if get_tensor_model_parallel_rank() == 0: + # Check that all keys have the same data type. + _check_data_types(keys, data, datatype) + # Flatten the data associated with the keys + flatten_data = torch.cat( + [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty(total_numel, + device=torch.cuda.current_device(), + dtype=datatype) + + # Broadcast + torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(), + group=get_tensor_model_parallel_group()) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..f4526d127a47f5b948add6ade8b613f374e105ee --- /dev/null +++ b/megatron/mpu/initialize.py @@ -0,0 +1,593 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Model and data parallel groups.""" + +import torch + +from .utils import ensure_divisibility + + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None +# Inter-layer model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Embedding group. +_EMBEDDING_GROUP = None +# Position embedding group. +_POSITION_EMBEDDING_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + +_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None +_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +# A list of ranks that have a copy of the embedding. +_EMBEDDING_GLOBAL_RANKS = None + +# A list of ranks that have a copy of the position embedding. +_POSITION_EMBEDDING_GLOBAL_RANKS = None + +# A list of global ranks for each pipeline group to ease calculation of the source +# rank when broadcasting from the first or last pipeline stage. +_PIPELINE_GLOBAL_RANKS = None + +# A list of global ranks for each data parallel group to ease calculation of the source +# rank when broadcasting weights from src to all other data parallel ranks +_DATA_PARALLEL_GLOBAL_RANKS = None + + + +def is_unitialized(): + """Useful for code segments that may be accessed with or without mpu initialization""" + return _DATA_PARALLEL_GROUP is None + + +def initialize_model_parallel(tensor_model_parallel_size_=1, + pipeline_model_parallel_size_=1, + virtual_pipeline_model_parallel_size_=None, + pipeline_model_parallel_split_rank_=None): + """ + Initialize model data parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. + virtual_pipeline_model_parallel_size: number of virtual stages (interleaved + pipeline). + pipeline_model_parallel_split_rank: for models with both encoder and decoder, + rank in pipeline with split point. + + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + if torch.distributed.get_rank() == 0: + print('> initializing tensor model parallel with size {}'.format( + tensor_model_parallel_size_)) + print('> initializing pipeline model parallel with size {}'.format( + pipeline_model_parallel_size_)) + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size) + pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) + ensure_divisibility(world_size, + tensor_model_parallel_size * pipeline_model_parallel_size) + data_parallel_size = world_size // (tensor_model_parallel_size * + pipeline_model_parallel_size) + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + num_data_parallel_groups = world_size // data_parallel_size + + if virtual_pipeline_model_parallel_size_ is not None: + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_ + + if pipeline_model_parallel_split_rank_ is not None: + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_ + + rank = torch.distributed.get_rank() + + # Build the data-parallel groups. + global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GLOBAL_RANKS + assert _DATA_PARALLEL_GROUP is None, \ + 'data parallel group is already initialized' + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, + tensor_model_parallel_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, \ + 'model parallel group is already initialized' + for i in range(data_parallel_size): + ranks = [data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks] + group = torch.distributed.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, \ + 'tensor model parallel group is already initialized' + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \ + 'pipeline model parallel group is already initialized' + global _EMBEDDING_GROUP + global _EMBEDDING_GLOBAL_RANKS + assert _EMBEDDING_GROUP is None, \ + 'embedding group is already initialized' + global _POSITION_EMBEDDING_GROUP + global _POSITION_EMBEDDING_GLOBAL_RANKS + assert _POSITION_EMBEDDING_GROUP is None, \ + 'position embedding group is already initialized' + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, + num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + # Setup embedding group (to exchange gradients between + # first and last stages). + if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + position_embedding_ranks = [ranks[0]] + if pipeline_model_parallel_split_rank_ is not None: + if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks: + embedding_ranks = [ranks[0], + ranks[pipeline_model_parallel_split_rank_], + ranks[-1]] + if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks: + position_embedding_ranks = [ranks[0], + ranks[pipeline_model_parallel_split_rank_]] + else: + embedding_ranks = ranks + position_embedding_ranks = ranks + + group = torch.distributed.new_group(embedding_ranks) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + if rank in ranks: + _EMBEDDING_GLOBAL_RANKS = embedding_ranks + + group = torch.distributed.new_group(position_embedding_ranks) + if rank in position_embedding_ranks: + _POSITION_EMBEDDING_GROUP = group + if rank in ranks: + _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _TENSOR_MODEL_PARALLEL_GROUP is None or \ + _PIPELINE_MODEL_PARALLEL_GROUP is None or \ + _DATA_PARALLEL_GROUP is None: + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, \ + 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ + 'intra_layer_model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \ + 'pipeline_model parallel group is not initialized' + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_embedding_group(): + """Get the embedding group the caller rank belongs to.""" + assert _EMBEDDING_GROUP is not None, \ + 'embedding group is not initialized' + return _EMBEDDING_GROUP + + +def get_position_embedding_group(): + """Get the position embedding group the caller rank belongs to.""" + assert _POSITION_EMBEDDING_GROUP is not None, \ + 'position embedding group is not initialized' + return _POSITION_EMBEDDING_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_pipeline_model_parallel_world_size(world_size): + """Set the pipeline model parallel size""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_rank(rank): + """Set pipeline model parallel rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) + + +def get_num_layers(args, is_encoder_and_decoder_model): + """Compute the number of transformer layers resident on the current rank.""" + if get_pipeline_model_parallel_world_size() > 1: + if is_encoder_and_decoder_model: + assert args.pipeline_model_parallel_split_rank is not None + + # When a standalone embedding stage is used, a rank is taken from + # the encoder's ranks, to be used for the encoder's embedding + # layer. This way, the rank referenced by the 'split rank' remains + # the same whether or not a standalone embedding stage is used. + num_ranks_in_encoder = ( + args.pipeline_model_parallel_split_rank - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_split_rank + ) + num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder + assert args.num_layers % num_ranks_in_encoder == 0, \ + 'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder) + assert args.num_layers % num_ranks_in_decoder == 0, \ + 'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder) + if is_pipeline_stage_before_split(): + num_layers = ( + 0 + if args.standalone_embedding_stage + and get_pipeline_model_parallel_rank() == 0 else + args.num_layers // num_ranks_in_encoder + ) + else: + num_layers = args.num_layers // num_ranks_in_decoder + else: + assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'num_layers must be divisible by transformer_pipeline_model_parallel_size' + + # When a standalone embedding stage is used, all transformer layers + # are divided among pipeline rank >= 1, while on pipeline rank 0, + # ranks either contain the input embedding layer (virtual pp rank 0), + # or no layers at all (virtual pp rank >= 1). + num_layers = ( + 0 + if args.standalone_embedding_stage + and get_pipeline_model_parallel_rank() == 0 else + args.num_layers // args.transformer_pipeline_model_parallel_size + ) + else: + num_layers = args.num_layers + return num_layers + +def get_num_layers_decoder(args, is_encoder_and_decoder_model): + """Compute the number of transformer layers resident on the current rank.""" + if get_pipeline_model_parallel_world_size() > 1: + if is_encoder_and_decoder_model: + assert args.pipeline_model_parallel_split_rank is not None + + # When a standalone embedding stage is used, a rank is taken from + # the encoder's ranks, to be used for the encoder's embedding + # layer. This way, the rank referenced by the 'split rank' remains + # the same whether or not a standalone embedding stage is used. + num_ranks_in_encoder = ( + args.pipeline_model_parallel_split_rank - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_split_rank + ) + num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder + assert args.num_layers_decoder % num_ranks_in_encoder == 0, \ + 'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers_decoder, num_ranks_in_encoder) + assert args.num_layers_decoder % num_ranks_in_decoder == 0, \ + 'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers_decoder, num_ranks_in_decoder) + if is_pipeline_stage_before_split(): + num_layers = ( + 0 + if args.standalone_embedding_stage + and get_pipeline_model_parallel_rank() == 0 else + args.num_layers_decoder // num_ranks_in_encoder + ) + else: + num_layers = args.num_layers_decoder // num_ranks_in_decoder + else: + assert args.num_layers_decoder % args.transformer_pipeline_model_parallel_size == 0, \ + 'num_layers must be divisible by transformer_pipeline_model_parallel_size' + + # When a standalone embedding stage is used, all transformer layers + # are divided among pipeline rank >= 1, while on pipeline rank 0, + # ranks either contain the input embedding layer (virtual pp rank 0), + # or no layers at all (virtual pp rank >= 1). + num_layers = ( + 0 + if args.standalone_embedding_stage + and get_pipeline_model_parallel_rank() == 0 else + args.num_layers_decoder // args.transformer_pipeline_model_parallel_size + ) + else: + num_layers = args.num_layers_decoder + return num_layers + + +def is_pipeline_first_stage(ignore_virtual=False): + """Return True if in the first pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + if get_virtual_pipeline_model_parallel_world_size() is not None and \ + get_virtual_pipeline_model_parallel_rank() != 0: + return False + return get_pipeline_model_parallel_rank() == 0 + + +def is_pipeline_last_stage(ignore_virtual=False): + """Return True if in the last pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + virtual_pipeline_model_parallel_world_size = \ + get_virtual_pipeline_model_parallel_world_size() + if virtual_pipeline_model_parallel_world_size is not None and \ + get_virtual_pipeline_model_parallel_rank() != ( + virtual_pipeline_model_parallel_world_size - 1): + return False + return get_pipeline_model_parallel_rank() == ( + get_pipeline_model_parallel_world_size() - 1) + + +def is_rank_in_embedding_group(ignore_virtual=False): + """Return true if current rank is in embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _EMBEDDING_GLOBAL_RANKS + if ignore_virtual: + return rank in _EMBEDDING_GLOBAL_RANKS + if rank in _EMBEDDING_GLOBAL_RANKS: + if rank == _EMBEDDING_GLOBAL_RANKS[0]: + return is_pipeline_first_stage(ignore_virtual=False) + elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: + return is_pipeline_last_stage(ignore_virtual=False) + else: + return True + return False + + +def is_rank_in_position_embedding_group(): + """Return true if current rank is in position embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _POSITION_EMBEDDING_GLOBAL_RANKS + return rank in _POSITION_EMBEDDING_GLOBAL_RANKS + + +def is_pipeline_stage_before_split(rank=None): + """Return True if pipeline stage executes encoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_after_split(rank=None): + """Return True if pipeline stage executes decoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_at_split(): + """Return true if pipeline stage executes decoder block and next + stage executes encoder block for a model with both encoder and + decoder.""" + rank = get_pipeline_model_parallel_rank() + return is_pipeline_stage_before_split(rank) and \ + is_pipeline_stage_after_split(rank+1) + + +def get_virtual_pipeline_model_parallel_rank(): + """Return the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + + +def set_virtual_pipeline_model_parallel_rank(rank): + """Set the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_virtual_pipeline_model_parallel_world_size(): + """Return the virtual pipeline-parallel world size.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the data parallel group.""" + assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \ + "Data parallel group is not initialized" + return _DATA_PARALLEL_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_first_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + return _PIPELINE_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_last_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + last_rank_local = get_pipeline_model_parallel_world_size() - 1 + return _PIPELINE_GLOBAL_RANKS[last_rank_local] + +def get_pipeline_model_parallel_next_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + +def get_pipeline_model_parallel_prev_rank(): + assert _PIPELINE_GLOBAL_RANKS is not None, \ + "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return torch.distributed.get_rank(group=get_data_parallel_group()) + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _TENSOR_MODEL_PARALLEL_GROUP + _TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + global _EMBEDDING_GROUP + _EMBEDDING_GROUP = None + global _POSITION_EMBEDDING_GROUP + _POSITION_EMBEDDING_GROUP = None diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8c154f9f29c8a624249bf742757130536a1328eb --- /dev/null +++ b/megatron/mpu/layers.py @@ -0,0 +1,532 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + + +import math + +import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from .initialize import get_tensor_model_parallel_rank +from .initialize import get_tensor_model_parallel_world_size +from .initialize import get_tensor_model_parallel_group +from .mappings import copy_to_tensor_model_parallel_region +from .mappings import gather_from_tensor_model_parallel_region +from .mappings import gather_from_sequence_parallel_region +from .mappings import reduce_from_tensor_model_parallel_region +from .mappings import scatter_to_tensor_model_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region + +from .random import get_cuda_rng_tracker +from .utils import divide +from .utils import split_tensor_along_last_dim +from .utils import VocabUtility +from megatron import get_args, get_global_memory_buffer + +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, + 'partition_dim': -1, + 'partition_stride': 1} + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, 'tensor_model_parallel') and + param.tensor_model_parallel) or ( + get_tensor_model_parallel_rank() == 0) + + +def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): + # Make sure the attributes are not set. + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + assert not hasattr(tensor, attribute) + # Set the attributes. + setattr(tensor, 'tensor_model_parallel', is_parallel) + setattr(tensor, 'partition_dim', dim) + setattr(tensor, 'partition_stride', stride) + + +def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): + def maybe_set(attribute, value): + if not hasattr(tensor, attribute): + setattr(tensor, attribute, value) + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) + + +def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): + def maybe_copy(attribute): + if hasattr(source_tensor, attribute): + setattr(destination_tensor, attribute, + getattr(source_tensor, attribute)) + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_copy(attribute) + + +def _initialize_affine_weight_gpu(weight, init_method, + partition_dim, stride=1): + """Initialize affine weight for model parallel on GPU.""" + + set_tensor_model_parallel_attributes(tensor=weight, + is_parallel=True, + dim=partition_dim, + stride=stride) + + with get_cuda_rng_tracker().fork(): + init_method(weight) + + +def _initialize_affine_weight_cpu(weight, output_size, input_size, + per_partition_size, partition_dim, + init_method, stride=1, + return_master_weight=False): + """Initialize affine weight for model parallel. + + Build the master weight on all processes and scatter + the relevant chunk.""" + + set_tensor_model_parallel_attributes(tensor=weight, + is_parallel=True, + dim=partition_dim, + stride=stride) + + # Initialize master weight + master_weight = torch.empty(output_size, input_size, + dtype=torch.float, + requires_grad=False) + init_method(master_weight) + args = get_args() + master_weight = master_weight.to(dtype=args.params_dtype) + + # Split and copy + per_partition_per_stride_size = divide(per_partition_size, stride) + weight_list = torch.split(master_weight, per_partition_per_stride_size, + dim=partition_dim) + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__(self, num_embeddings, embedding_dim, + init_method=init.xavier_normal_): + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set the detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = \ + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), + self.tensor_model_parallel_size) + self.num_embeddings_per_partition = self.vocab_end_index - \ + self.vocab_start_index + + # Allocate weights and initialize. + args = get_args() + if args.use_cpu_initialization: + self.weight = Parameter(torch.empty( + self.num_embeddings_per_partition, self.embedding_dim, + dtype=args.params_dtype)) + if args.perform_initialization: + _initialize_affine_weight_cpu( + self.weight, self.num_embeddings, self.embedding_dim, + self.num_embeddings_per_partition, 0, init_method) + else: + self.weight = Parameter(torch.empty( + self.num_embeddings_per_partition, self.embedding_dim, + device=torch.cuda.current_device(), dtype=args.params_dtype)) + if args.perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, + partition_dim=0, stride=1) + + def forward(self, input_): + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output_parallel = F.embedding(masked_input, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_from_tensor_model_parallel_region(output_parallel) + return output + + +class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication and gradient accumulation + fusion in backprop. + """ + + @staticmethod + def forward(ctx, input, weight, bias, gradient_accumulation_fusion, + async_grad_allreduce, sequence_parallel): + ctx.save_for_backward(input, weight) + ctx.use_bias = bias is not None + ctx.gradient_accumulation_fusion = gradient_accumulation_fusion + ctx.async_grad_allreduce = async_grad_allreduce + ctx.sequence_parallel = sequence_parallel + + if sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + torch.distributed._all_gather_base( + all_gather_buffer, + input, + group=get_tensor_model_parallel_group()) + total_input = all_gather_buffer + else: + total_input = input + + output = torch.matmul(total_input, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + if ctx.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + handle = torch.distributed._all_gather_base( + all_gather_buffer, + input, + group=get_tensor_model_parallel_group(), async_op=True) + + # Delay the start of intput gradient computation shortly (3us) to have + # gather scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + total_input = all_gather_buffer + else: + total_input = input + grad_input = grad_output.matmul(weight) + + if ctx.sequence_parallel: + handle.wait() + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], + grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], + total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = torch.distributed.all_reduce( + grad_input, group=get_tensor_model_parallel_group(), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + if ctx.sequence_parallel: + assert not ctx.async_grad_allreduce + dim_size = list(input.size()) + sub_grad_input = torch.empty(dim_size, dtype=input.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + # reduce_scatter + handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, + group=get_tensor_model_parallel_group(), + async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + + if ctx.gradient_accumulation_fusion: + import fused_dense_cuda + fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.sequence_parallel: + handle.wait() + return sub_grad_input, grad_weight, grad_bias, None, None, None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ + + def __init__(self, input_size, output_size, bias=True, gather_output=True, + init_method=init.xavier_normal_, stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False): + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + args = get_args() + if args.use_cpu_initialization: + self.weight = Parameter(torch.empty(self.output_size_per_partition, + self.input_size, + dtype=args.params_dtype)) + if args.perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, self.output_size, self.input_size, + self.output_size_per_partition, 0, init_method, + stride=stride, return_master_weight=keep_master_weight_for_test) + else: + self.weight = Parameter(torch.empty( + self.output_size_per_partition, self.input_size, + device=torch.cuda.current_device(), dtype=args.params_dtype)) + if args.perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, + partition_dim=0, stride=stride) + + if bias: + if args.use_cpu_initialization: + self.bias = Parameter(torch.empty( + self.output_size_per_partition, dtype=args.params_dtype)) + else: + self.bias = Parameter(torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=args.params_dtype)) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + self.async_tensor_model_parallel_allreduce = ( + args.async_tensor_model_parallel_allreduce and + world_size > 1) + self.sequence_parallel = ( + args.sequence_parallel and + world_size > 1) + assert not self.async_tensor_model_parallel_allreduce or \ + not self.sequence_parallel + self.gradient_accumulation_fusion = args.gradient_accumulation_fusion + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + if self.async_tensor_model_parallel_allreduce or \ + self.sequence_parallel: + input_parallel = input_ + else: + input_parallel = copy_to_tensor_model_parallel_region(input_) + # Matrix multiply. + output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply( + input_parallel, self.weight, bias, self.gradient_accumulation_fusion, + self.async_tensor_model_parallel_allreduce, self.sequence_parallel) + if self.gather_output: + # All-gather across the partitions. + assert not self.sequence_parallel + output = gather_from_tensor_model_parallel_region(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimization where bias + can be fused with other elementwise operations. We skip + adding bias but instead return it. + """ + + def __init__(self, input_size, output_size, bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, world_size) + self.skip_bias_add = skip_bias_add + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + args = get_args() + if args.use_cpu_initialization: + self.weight = Parameter(torch.empty(self.output_size, + self.input_size_per_partition, + dtype=args.params_dtype)) + if args.perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, self.output_size, self.input_size, + self.input_size_per_partition, 1, init_method, + stride=stride, return_master_weight=keep_master_weight_for_test) + else: + self.weight = Parameter(torch.empty( + self.output_size, self.input_size_per_partition, + device=torch.cuda.current_device(), dtype=args.params_dtype)) + if args.perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, + partition_dim=1, stride=stride) + if bias: + if args.use_cpu_initialization: + self.bias = Parameter(torch.empty(self.output_size, + dtype=args.params_dtype)) + else: + self.bias = Parameter(torch.empty( + self.output_size, device=torch.cuda.current_device(), + dtype=args.params_dtype)) + setattr(self.bias, 'sequence_parallel', args.sequence_parallel) + + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + self.sequence_parallel = args.sequence_parallel + self.gradient_accumulation_fusion = args.gradient_accumulation_fusion + + + + def forward(self, input_): + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + assert not self.sequence_parallel + input_parallel = scatter_to_tensor_model_parallel_region(input_) + # Matrix multiply. + output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply( + input_parallel, self.weight, None, + self.gradient_accumulation_fusion, None, None) + # All-reduce across all the partitions. + if self.sequence_parallel: + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + else: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..524994dca309c5c1720b3b60e6bd16919f4f2098 --- /dev/null +++ b/megatron/mpu/mappings.py @@ -0,0 +1,288 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from .utils import split_tensor_along_last_dim + + +def _reduce(input_): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size()==1: + return input_ + + # All-reduce. + torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) + + return input_ + + +def _split_along_last_dim(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along last dimension. + input_list = split_tensor_along_last_dim(input_, world_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_tensor_model_parallel_rank() + output = input_list[rank].contiguous() + + return output + + +def _split_along_first_dim(input_): + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension. + dim_size = input_.size()[0] + assert dim_size % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + local_dim_size = dim_size // world_size + rank = get_tensor_model_parallel_rank() + dim_offset = rank * local_dim_size + + output = input_[dim_offset:dim_offset+local_dim_size].contiguous() + + return output + + +def _gather_along_last_dim(input_): + """Gather tensors and concatinate along the last dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = get_tensor_model_parallel_rank() + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + + +def _gather_along_first_dim(input_): + """Gather tensors and concatinate along the first dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base(output, input_.contiguous(), + group=get_tensor_model_parallel_group()) + + return output + +def _reduce_scatter_along_first_dim(input_): + """Reduce-scatter the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert dim_size[0] % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._reduce_scatter_base(output, input_.contiguous(), + group=get_tensor_model_parallel_group()) + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _ScatterToModelParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_last_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _split_along_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_last_dim(grad_output) + + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gather the input from model parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_): + return _gather_along_last_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _gather_along_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split_along_last_dim(grad_output) + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _split_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather the input from sequence parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_, tensor_parallel_output_grad=True): + return _gather_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_, tensor_parallel_output_grad=True): + ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + tensor_parallel_output_grad = ctx.tensor_parallel_output_grad + + # If the computation graph after the gather operation is + # in the tensor parallel mode, output gradients need to reduce + # scattered and whereas if the computation is duplicated, + # output gradients need to be scattered. + if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + + +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +# ----------------- +# Helper functions. +# ----------------- + +def copy_to_tensor_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + + +def scatter_to_tensor_model_parallel_region(input_): + return _ScatterToModelParallelRegion.apply(input_) + + +def gather_from_tensor_model_parallel_region(input_): + return _GatherFromModelParallelRegion.apply(input_) + + +def scatter_to_sequence_parallel_region(input_): + return _ScatterToSequenceParallelRegion.apply(input_) + + +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + + +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) + diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py new file mode 100644 index 0000000000000000000000000000000000000000..142ebac0c8721c186b4a269fa8204312ecbd38b8 --- /dev/null +++ b/megatron/mpu/random.py @@ -0,0 +1,379 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import contextlib + +import torch +from torch import _C +from torch.cuda import _lazy_call, device as device_ctx_manager +from torch.utils.checkpoint import detach_variable + +from megatron.memory import allocate_mem_buff + +from .initialize import get_data_parallel_rank +from .initialize import get_tensor_model_parallel_group +from .initialize import get_tensor_model_parallel_rank +from .initialize import get_tensor_model_parallel_world_size + + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Argumentss: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): + """Break a tensor into equal 1D chunks.""" + partition_size = torch.numel(tensor) // \ + get_tensor_model_parallel_world_size() + start_index = partition_size * get_tensor_model_parallel_rank() + end_index = start_index + partition_size + if new_buffer: + data = torch.empty(partition_size, dtype=tensor.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor): + """Opposite of above function, gather values from model parallel ranks.""" + numel_gathered = torch.numel(tensor) * \ + get_tensor_model_parallel_world_size() + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + # TODO: This API is experimental in pytorch (as of Feb 2022) and + # this might break in future pytorch releases. We chose this API + # as opposed to torch.distributed.all_gather for efficiency reasons. + # This API calls directly NCCL all-gather versus the former does + # internal copies and can potentially cause slow down. + torch.distributed._all_gather_base(gathered, tensor, + group=get_tensor_model_parallel_group()) + return gathered + + +def _kernel_make_viewless_tensor(inp, requires_grad): + '''Make a viewless tensor. + + View tensors have the undesirable side-affect of retaining a reference + to the originally-viewed tensor, even after manually setting the '.data' + field. This method creates a new tensor that links to the old tensor's + data, without linking the viewed tensor, referenced via the '._base' + field. + ''' + out = torch.empty( + (1,), + dtype = inp.dtype, + device = inp.device, + requires_grad = requires_grad, + ) + out.data = inp.data + return out + +class MakeViewlessTensor(torch.autograd.Function): + ''' + Autograd function to make a viewless tensor. + + This function should be used in cases where the computation graph needs + to be propagated, but we only want a viewless tensor (e.g., + ParallelTransformer's hidden_states). Call this function by passing + 'keep_graph = True' to 'make_viewless_tensor()'. + ''' + @staticmethod + def forward(ctx, inp, requires_grad): + return _kernel_make_viewless_tensor(inp, requires_grad) + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + +def make_viewless_tensor(inp, requires_grad, keep_graph): + ''' + Entry-point for creating viewless tensors. + + This method should be used, rather than calling 'MakeViewlessTensor' + or '_kernel_make_viewless_tensor' directly. This method acts as a + switch for determining if an autograd function or a regular method + should be used to create the tensor. + ''' + + # return tensor as-is, if not a 'view' + if inp._base is None: + return inp + + # create viewless tensor + if keep_graph: + return MakeViewlessTensor.apply(inp, requires_grad) + else: + return _kernel_make_viewless_tensor(inp, requires_grad) + +def assert_viewless_tensor(tensor, extra_msg = None): + '''Assert that a tensor is not a view (i.e., its '._base' field is + not set).''' + if isinstance(tensor, list): + [ assert_viewless_tensor(t) for t in tensor ] + return tensor + if not isinstance(tensor, torch.Tensor): + return tensor + assert tensor._base is None, ( + "Ensure tensor._base is None before setting tensor.data or storing " + "tensor to memory buffer. Otherwise, a memory leak will occur (and " + "likely accumulate over iterations). %s" + ) % extra_msg + return tensor + +def safely_set_viewless_tensor_data(tensor, new_data_tensor): + '''Safely set tensor's '.data' field. + + Check first that the tensor is viewless (i.e., '._base' not set). If not, + raise an exception. + ''' + assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape)) + tensor.data = new_data_tensor + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model paralle groups. This is used for + example for dropout in the non-tensor-model-parallel regions. + tensor-model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() + # Data parallel gets the original seed. + data_parallel_seed = seed + + if torch.distributed.get_rank() == 0: + print('> initializing model parallel cuda seeds on global rank {}, ' + 'model parallel rank {}, and data parallel rank {} with ' + 'model parallel seed: {} and data parallel seed: {}'.format( + torch.distributed.get_rank(), get_tensor_model_parallel_rank(), + get_data_parallel_rank(), tensor_model_parallel_seed, + data_parallel_seed), flush=True) + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, + tensor_model_parallel_seed) + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + @staticmethod + def forward(ctx, run_function, distribute_saved_activations, *args): + ctx.run_function = run_function + ctx.distribute_saved_activations \ + = distribute_saved_activations + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + with torch.no_grad(): + outputs = run_function(*args) + + # Divide hidden states across model parallel group and only keep + # the chunk corresponding to the current rank. + if distribute_saved_activations: + ctx.input_0_shape = args[0].data.shape + safely_set_viewless_tensor_data( + args[0], + split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)) + + # Store everything. + ctx.save_for_backward(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), " + "please use .backward() if possible") + inputs = ctx.saved_tensors + if ctx.distribute_saved_activations: + safely_set_viewless_tensor_data( + inputs[0], + gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # Compute the forward pass. + detached_inputs = detach_variable(inputs) + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + torch.autograd.backward(outputs, args) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp + for inp in detached_inputs) + return (None, None) + grads + + +def checkpoint(function, distribute_saved_activations, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint.""" + return CheckpointFunction.apply(function, + distribute_saved_activations, *args) diff --git a/megatron/mpu/tests/__init__.py b/megatron/mpu/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/megatron/mpu/tests/commons.py b/megatron/mpu/tests/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7a186728726517d1b59e1cc861829cbb86f18a --- /dev/null +++ b/megatron/mpu/tests/commons.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random +import numpy +import torch + +import mpu + + +class IdentityLayer(torch.nn.Module): + def __init__(self, size, scale=1.0): + super(IdentityLayer, self).__init__() + self.weight = torch.nn.Parameter(scale * torch.randn(size)) + + def forward(self): + return self.weight + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +def initialize_distributed(backend='nccl'): + """Initialize torch.distributed.""" + # Get local rank in case it is provided. + parser = argparse.ArgumentParser() + parser.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher') + args = parser.parse_args() + local_rank = args.local_rank + + # Get rank and world size. + rank = int(os.getenv('RANK', '0')) + world_size = int(os.getenv("WORLD_SIZE", '1')) + + print('> initializing torch.distributed with local rank: {}, ' + 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) + + # Set the device id. + device = rank % torch.cuda.device_count() + if local_rank is not None: + device = local_rank + torch.cuda.set_device(device) + + # Call the init process. + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + init_method=init_method) + + +def print_separator(message): + torch.distributed.barrier() + filler_len = (78 - len(message)) // 2 + filler = '-' * filler_len + string = '\n' + filler + ' {} '.format(message) + filler + if torch.distributed.get_rank() == 0: + print(string, flush=True) + torch.distributed.barrier() diff --git a/megatron/mpu/tests/test_cross_entropy.py b/megatron/mpu/tests/test_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..46d7ba981c906f410e93eddadb2f272e97f5cbb0 --- /dev/null +++ b/megatron/mpu/tests/test_cross_entropy.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from commons import set_random_seed +from commons import IdentityLayer +from commons import print_separator +from commons import initialize_distributed +from mpu.cross_entropy import vocab_parallel_cross_entropy +import mpu +import torch.nn.functional as F +import torch +import random +import sys +sys.path.append("../..") + + +def torch_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).cuda() + logits = identity() + target = torch.cuda.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size) + loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), + target.view(-1), + reduction='none').view_as(target).mean() + loss.backward() + return loss, identity.weight.grad + + +def mpu_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).cuda() + logits = identity() + logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits) + target = torch.cuda.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size) + loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() + loss.backward() + return loss, identity.weight.grad + + +def test_cross_entropy(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cross entropy with model parallel size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + batch_size = 13 + seq_length = 17 + vocab_size_per_partition = 11 + logits_scale = 1000.0 + vocab_size = vocab_size_per_partition * tensor_model_parallel_size + seed = 1234 + + loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, + vocab_size, logits_scale, + seed) + loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, + vocab_size, logits_scale, + seed) + + error = loss_torch.sub_(loss_mpu).abs().max() + print(' max error in loss on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = grad_torch.sub_(grad_mpu).abs().max() + print(' max error in grad on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_tensor_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test cross entropy') + test_cross_entropy(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_data.py b/megatron/mpu/tests/test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..ae36277036496c3acb9b632aa4e4f9935fe61d18 --- /dev/null +++ b/megatron/mpu/tests/test_data.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from commons import print_separator +from commons import initialize_distributed +from mpu import data as data_utils +import mpu +import torch +import functools +import operator +import sys +sys.path.append("../..") + + +def test_broadcast_data(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing broadcast_data with model parallel size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + torch.manual_seed(1234 + mpu.get_data_parallel_rank()) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + key_size_t = {'key1': [7, 11], + 'key2': [8, 2, 1], + 'key3': [13], + 'key4': [5, 1, 2], + 'key5': [5, 12]} + keys = list(key_size_t.keys()) + + data = {} + data_t = {} + for key in key_size_t: + data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) + data_t[key] = data[key].clone() + data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) + data_t['keyX'] = data['keyX'].clone() + if mpu.get_tensor_model_parallel_rank() != 0: + data = None + + data_utils._check_data_types(keys, data_t, torch.int64) + key_size, key_numel, \ + total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) + for key in keys: + assert key_size[key] == key_size_t[key] + total_numel_t = 0 + for key in keys: + target_size = functools.reduce(operator.mul, key_size_t[key], 1) + assert key_numel[key] == target_size + total_numel_t += target_size + assert total_numel == total_numel_t + + data_b = data_utils.broadcast_data(keys, data, torch.int64) + for key in keys: + tensor = data_t[key].cuda() + assert data_b[key].sub(tensor).abs().max() == 0 + + # Reset groups + mpu.destroy_tensor_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test test broadcast data') + test_broadcast_data(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_initialize.py b/megatron/mpu/tests/test_initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..ba505b8d5c39060a5687654f1a023f2f4527287a --- /dev/null +++ b/megatron/mpu/tests/test_initialize.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from commons import print_separator +from commons import initialize_distributed +import mpu +import torch +import sys +sys.path.append("../..") + + +def test_initialize_model_parallel(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing initialize_model_parallel with size {} ...'.format( + tensor_model_parallel_size)) + tensor_model_parallel_size_ = min(tensor_model_parallel_size, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(tensor_model_parallel_size_) + assert mpu.model_parallel_is_initialized() + + # Checks. + def check(group, world_size, rank): + assert world_size == torch.distributed.get_world_size(group=group) + assert rank == torch.distributed.get_rank(group=group) + + # Model parallel. + world_size = tensor_model_parallel_size_ + rank = torch.distributed.get_rank() % tensor_model_parallel_size_ + assert world_size == mpu.get_tensor_model_parallel_world_size() + assert rank == mpu.get_tensor_model_parallel_rank() + check(mpu.get_tensor_model_parallel_group(), world_size, rank) + + # Data parallel. + world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ + rank = torch.distributed.get_rank() // tensor_model_parallel_size + assert world_size == mpu.get_data_parallel_world_size() + assert rank == mpu.get_data_parallel_rank() + check(mpu.get_data_parallel_group(), world_size, rank) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): + + if torch.distributed.get_rank() == 0: + print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( + tensor_model_parallel_size_)) + tensor_model_parallel_size = min(tensor_model_parallel_size_, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(tensor_model_parallel_size) + assert mpu.model_parallel_is_initialized() + + # Checks + src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() + assert mpu.get_tensor_model_parallel_src_rank() == src_rank + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test initialize model parallel') + test_initialize_model_parallel(tensor_model_parallel_size) + print_separator('test model parallel source rank') + test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_layers.py b/megatron/mpu/tests/test_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b12f48509bb8f0df75a3b2e02b1230ba1163ef84 --- /dev/null +++ b/megatron/mpu/tests/test_layers.py @@ -0,0 +1,530 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mpu import layers +from commons import set_random_seed +from commons import print_separator +from commons import initialize_distributed +import mpu +from torch.nn.parameter import Parameter +import torch.nn.init as init +import torch +import random +import sys +sys.path.append("../..") + + +def test_parallel_embedding(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing parallel embedding with model parallel size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + batch_size = 17 + seq_length = 23 + vocab_size = 48 + hidden_size = 16 + seed = 1236 + + set_random_seed(123) + input_data = torch.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size).cuda() + loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() + + set_random_seed(seed) + embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() + + output = embedding_original(input_data) + loss_original = torch.mul(output, loss_weight).sum() + loss_original.backward() + + set_random_seed(seed) + embedding_parallel = layers.ParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_parallel(input_data) + loss_parallel = torch.mul(output, loss_weight).sum() + loss_parallel.backward() + + set_random_seed(seed) + embedding_vocab_parallel = layers.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_vocab_parallel(input_data) + loss_vocab_parallel = torch.mul(output, loss_weight).sum() + loss_vocab_parallel.backward() + + torch.distributed.barrier() + error = loss_parallel.sub(loss_original).abs() + print(' error in loss (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + torch.distributed.barrier() + error = loss_vocab_parallel.sub(loss_original).abs() + print(' error in loss (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + hidden_size // tensor_model_parallel_size, + 1)[mpu.get_tensor_model_parallel_rank()] + error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() + print(' error in grad (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + vocab_size // tensor_model_parallel_size, + 0)[mpu.get_tensor_model_parallel_rank()] + error = embedding_vocab_parallel.weight.grad.sub( + weight_grad_orig).abs().max() + print(' error in grad (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_initialize_affine_weight(tensor_model_parallel_size): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing initialize_affine_weight with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + + # --------------- + # Column parallel + # --------------- + weight = torch.empty(output_size_coeff, input_size) + set_random_seed(seed) + layers._initialize_affine_weight(weight, output_size, input_size, + + output_size_coeff, 0, + torch.nn.init.normal_) + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_tensor_model_parallel_rank() + my_weight = torch.split(master_weight, output_size_coeff, + dim=0)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' column parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # ------------ + # Row parallel + # ------------ + weight = torch.empty(output_size, input_size_coeff) + set_random_seed(seed) + mpu.layers._initialize_affine_weight(weight, output_size, input_size, + input_size_coeff, 1, + torch.nn.init.normal_) + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_tensor_model_parallel_rank() + my_weight = torch.split(master_weight, input_size_coeff, + dim=1)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' row parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer2D(torch.nn.Module): + def __init__(self, m, n): + super(IdentityLayer2D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + + +def test_column_parallel_linear(tensor_model_parallel_size): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing ColumnParallelLinear with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = mpu.ColumnParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. + dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_tensor_model_parallel_rank() + my_dLdA = torch.split(dLdA, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + my_dLdb = torch.split(dLdb, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def test_row_parallel_linear(tensor_model_parallel_size): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing RowParallelLinear with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = mpu.RowParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. + dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_tensor_model_parallel_rank() + my_dLdA = torch.split(dLdA, input_size_coeff, + dim=1)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer3D(torch.nn.Module): + def __init__(self, m, n, k): + super(IdentityLayer3D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n, k)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + + +def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, + sequence_length): + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, + dropout_prob).cuda() + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = attention_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_tensor_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, tensor_model_parallel_size, loss, \ + attention_layer, identity_layer + + +def test_parallel_self_attention(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelSelfAttention with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + dropout_prob = 0.0 # has to be zero + batch_size = 5 + sequence_length = 13 + + rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \ + attention_layer_1, identity_layer_1 = parallel_self_attention( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + + rank, hidden_size, tensor_model_parallel_size, loss, \ + attention_layer, identity_layer = parallel_self_attention( + tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + assert hideen_size_1 == hidden_size + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + my_lin_grad_list = torch.split( + attention_layer_1.query_key_value.weight.grad, + hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size] + my_lin_grad = torch.cat(my_lin_grad_list, dim=0) + error = my_lin_grad.sub( + attention_layer.query_key_value.weight.grad).abs().max() + torch.distributed.barrier() + print(' weight gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + intermediate_size = 4 * hidden_size + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + transformer_layer = mpu.BertParallelTransformerLayer( + hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, + torch.nn.functional.relu, 1.0e-5).cuda() + + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = transformer_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_tensor_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, tensor_model_parallel_size, loss, \ + transformer_layer, identity_layer + + +def test_parallel_transformer_layer(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelTransformerLayer with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + batch_size = 5 + sequence_length = 13 + + rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \ + transformer_layer_1, identity_layer_1 = parallel_transformer( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + rank, hidden_size, tensor_model_parallel_size, loss, \ + transformer_layer, identity_layer = parallel_transformer( + tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-5, 'error: {}'.format(error) + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-5, 'error: {}'.format(error) + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +if __name__ == '__main__': + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + print_separator('test initialize affine weight') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_initialize_affine_weight(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test parallel embedding') + test_parallel_embedding(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + print_separator('test column-parallel linear') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_column_parallel_linear(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + print_separator('test row-parallel linear') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_row_parallel_linear(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + print_separator('test parallel self-attention') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_parallel_self_attention(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + print_separator('test parallel transformer') + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + test_parallel_transformer_layer(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_random.py b/megatron/mpu/tests/test_random.py new file mode 100644 index 0000000000000000000000000000000000000000..9c9c503410f7f3bd790b0b55a8cca696fc55097d --- /dev/null +++ b/megatron/mpu/tests/test_random.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from commons import print_separator +from commons import initialize_distributed +import mpu +import torch +import sys +sys.path.append("../..") + + +def test_set_cuda_rng_state(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing set_rng_state with size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + size = 123 + seed = 1234 + torch.cuda.manual_seed(1234) + tensor = torch.cuda.FloatTensor(size) + + # Get the state + rng_state = torch.cuda.get_rng_state() + rng_state_copy = rng_state.clone() + + # Do some stuff. + for _ in range(5): + torch.randn(size, out=tensor) + result_1 = tensor.clone() + + assert rng_state.sub(rng_state_copy).max() == 0 + assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 + + # State should be different. + new_rng_state = torch.cuda.get_rng_state() + max_diff = new_rng_state.sub(rng_state).max() + print(' max diff in rng state (should be non-zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), max_diff)) + assert max_diff > 0 + + # Reset the rng state and do the same stuff. + mpu.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + mpu.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + result_2 = tensor.clone() + + # Results should be the same + error = result_2.sub(result_1).abs().max() + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Input state should have remained intact. + error = rng_state.sub(rng_state_copy).max() + print(' max error in rng state (should be zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), error)) + assert error == 0 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_cuda_rng_tracker(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cuda rng tracker with size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed_1 = 1234 + seed_2 = 4321 + size = [12, 21] + tensor = torch.cuda.FloatTensor(size) + + # Set to seed_1 and generate two tensors. + torch.cuda.manual_seed(seed_1) + torch.randn(size, out=tensor) + target_11 = tensor.clone() + torch.randn(size, out=tensor) + target_12 = tensor.clone() + + # Set to seed_2 and generate two tensors. + torch.cuda.manual_seed(seed_2) + torch.randn(size, out=tensor) + target_21 = tensor.clone() + torch.randn(size, out=tensor) + target_22 = tensor.clone() + + # Now if we interleave seed_1 and seed_2, + # we should still get the same tensors + torch.cuda.manual_seed(seed_1) + mpu.get_cuda_rng_tracker().add('test', seed_2) + + torch.randn(size, out=tensor) + result_11 = tensor.clone() + + with mpu.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_21 = tensor.clone() + + torch.randn(size, out=tensor) + result_12 = tensor.clone() + + with mpu.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_22 = tensor.clone() + + diff = result_11.sub(result_21).abs().max() + diff = min(diff, result_12.sub(result_22).abs().max()) + print(' max diff in generated tensors (should be non-zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) + assert diff > 1.0e-6 + error = max(result_11.sub(target_11).abs().max(), + result_12.sub(target_12).abs().max()) + error = max(error, result_21.sub(target_21).abs().max()) + error = max(error, result_22.sub(target_22).abs().max()) + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing model parallel cuda manual seed with size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + mpu.model_parallel_cuda_manual_seed(12345) + assert torch.cuda.initial_seed() == 12345 + with mpu.get_cuda_rng_tracker().fork(): + assert torch.cuda.initial_seed() == (12345 + 2718 + + mpu.get_tensor_model_parallel_rank()) + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test set rng state') + test_set_cuda_rng_state(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test cuda rng tracker') + test_cuda_rng_tracker(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test model parallel cuda manual seed') + test_model_parallel_cuda_manual_seed(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56ed1c76e1404389f18ab3be01e13dfdc678d942 --- /dev/null +++ b/megatron/mpu/utils.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim(tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks amd return the + first and last index of the vocabulary belonging to the `rank` + partition: Note that indecies in [fist, last)""" + + @staticmethod + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, world_size): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size) diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b95514a0fd694008cb340a074b8aa05b802673f --- /dev/null +++ b/megatron/optimizer/__init__.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from apex.optimizers import FusedAdam as Adam +from apex.optimizers import FusedSGD as SGD + +from megatron import get_args + +from .distrib_optimizer import DistributedOptimizer +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer + + +def get_param_groups(modules, + no_weight_decay_cond, + scale_lr_cond, + lr_mult): + """creates param groups based on weight decay condition (regularized vs non regularized) + and learning rate scale condition (args.lr vs lr_mult * args.lr) + scale_lr_cond is used during finetuning where head of the network requires a scaled + version of the base learning rate. + """ + wd_no_scale_lr = [] + wd_scale_lr = [] + no_wd_no_scale_lr = [] + no_wd_scale_lr = [] + for module in modules: + for name, param in module.named_parameters(): + if not param.requires_grad: + continue + + if no_weight_decay_cond is not None: + no_wd = no_weight_decay_cond(name, param) + else: + # do not regularize biases nor Norm parameters + no_wd = name.endswith(".bias") or len(param.shape) == 1 + + if scale_lr_cond is not None: + scale_lr = scale_lr_cond(name, param) + else: + scale_lr = False + + if not no_wd and not scale_lr: + wd_no_scale_lr.append(param) + elif not no_wd and scale_lr: + wd_scale_lr.append(param) + elif no_wd and not scale_lr: + no_wd_no_scale_lr.append(param) + else: + no_wd_scale_lr.append(param) + + param_groups = [] + if len(wd_no_scale_lr): + param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0}) + if len(wd_scale_lr): + param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult}) + if len(no_wd_no_scale_lr): + param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0}) + if len(no_wd_scale_lr): + param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult}) + + return param_groups + +def get_megatron_optimizer(model, + no_weight_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0): + args = get_args() + + # Base optimizer. + param_groups = get_param_groups(model, + no_weight_decay_cond, + scale_lr_cond, + lr_mult) + + if args.optimizer == 'adam': + optimizer = Adam(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps) + elif args.optimizer == 'sgd': + optimizer = SGD(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + momentum=args.sgd_momentum) + else: + raise Exception('{} optimizer is not supported.'.format( + args.optimizer)) + + # Determine whether the params have main-grad field. + params_have_main_grad = False + if args.DDP_impl == 'local': + params_have_main_grad = True + + # Mixed precision optimizer. + # - Note: both the Float16Optimizer and the DistributedOptimizer inherit + # from the MixedPrecisionOptimizer, which manages any optimizer where + # the model params and main params are distinct. + if args.fp16 or args.bf16 or args.use_distributed_optimizer: + + # Grad scaler: + # if loss-scale is provided, instantiate the constant scaler. + # if we are using fp16 and loss-scale is not present, use a + # dynamic scaler. + # otherwise we are running in bf16 with no loss-scale so + # leave it as None. + grad_scaler = None + + # Constant loss scale. + if args.loss_scale: + grad_scaler = ConstantGradScaler(args.loss_scale) + + # Dynamic loss scale. + else: + if args.fp16: + grad_scaler = DynamicGradScaler( + initial_scale=args.initial_loss_scale, + min_scale=args.min_loss_scale, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=args.loss_scale_window, + hysteresis=args.hysteresis) + + # Megatron optimizer. + opt_ty = DistributedOptimizer \ + if args.use_distributed_optimizer else \ + Float16OptimizerWithFloat16Params + return opt_ty(optimizer, + args.clip_grad, + args.log_num_zeros_in_grad, + params_have_main_grad, + args.use_contiguous_buffers_in_local_ddp, + args.fp16, + args.bf16, + grad_scaler, + model) + + # FP32. + return FP32Optimizer(optimizer, args.clip_grad, + args.log_num_zeros_in_grad, + params_have_main_grad, + args.use_contiguous_buffers_in_local_ddp, + model) diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py new file mode 100644 index 0000000000000000000000000000000000000000..ad249bd5d6e503cd2488a76c8e41996d84c86d50 --- /dev/null +++ b/megatron/optimizer/clip_grads.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gradient clipping.""" + +import torch +from torch._six import inf + +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C + +from megatron.model.module import param_is_not_shared +from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate + + +def clip_grad_norm_fp32(parameters, grads_for_norm, + max_norm, norm_type=2, + model_parallel_group=None): + """Clips gradient norm of an iterable of parameters whose gradients + are in fp32. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single + Tensor that will be used for calculating the grad norm. + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + model_parallel_group (group): given the nature of the distributed + optimizer, this is passed as an argument. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + if isinstance(grads_for_norm, torch.Tensor): + grads_for_norm = [grads_for_norm] + + # Grads. + grads = [] + for param in parameters: + if param.grad is not None: + assert param.grad.type() == 'torch.cuda.FloatTensor' + grads.append(param.grad.detach()) + + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + total_norm = 0.0 + + # Calculate norm. + if norm_type == inf: + total_norm = max(grad.abs().max() for grad in grads_for_norm) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=model_parallel_group) + total_norm = total_norm_cuda[0].item() + + else: + if norm_type == 2.0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + # Use apex's multi-tensor applier for efficiency reasons. + # Multi-tensor applier takes a function and a list of list + # and performs the operation on that list all in one kernel. + if grads_for_norm: + grad_norm, _ = multi_tensor_applier( + amp_C.multi_tensor_l2norm, + dummy_overflow_buf, + [grads_for_norm], + False # no per-parameter norm + ) + else: + grad_norm = torch.cuda.FloatTensor([0]) + # Since we will be summing across data parallel groups, + # we need the pow(norm-type). + total_norm = grad_norm ** norm_type + + else: + for grad in grads_for_norm: + grad_norm = torch.norm(grad, norm_type) + total_norm += grad_norm ** norm_type + + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(total_norm, + op=torch.distributed.ReduceOp.SUM, + group=model_parallel_group) + total_norm = total_norm.item() ** (1.0 / norm_type) + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(amp_C.multi_tensor_scale, + dummy_overflow_buf, + [grads, grads], + clip_coeff) + + return total_norm + + +def count_zeros_fp32(parameters, model_parallel_group): + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + total_num_zeros = torch.cuda.FloatTensor([0.0]) + for param in parameters: + grad_not_none = param.grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grad = param.grad.detach() + num_zeros = grad.numel() - torch.count_nonzero(grad) + total_num_zeros = num_zeros + total_num_zeros + + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(total_num_zeros, + op=torch.distributed.ReduceOp.SUM, + group=model_parallel_group) + + total_num_zeros = total_num_zeros.item() + + return total_num_zeros diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..331f7846cd03949f5aa5a67a261f1f8f1fce0567 --- /dev/null +++ b/megatron/optimizer/distrib_optimizer.py @@ -0,0 +1,696 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron distributed optimizer.""" + + +import math +import torch + +from megatron import get_args +from megatron import get_timers +from megatron import mpu +from megatron import print_rank_0 +from megatron.model.module import param_is_not_shared +from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate + +from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper + + +class Range: + """ + A range represents a start and end points for indexing a shard + from a full tensor. + """ + def __init__(self, start, end): + self.start = start + self.end = end + self.size = end - start + def normalize(self, start = 0): + return Range(start, start + self.size) + def __str__(self): + return "%d,%d [%d]" % (self.start, self.end, self.size) + + +class DistributedOptimizer(MixedPrecisionOptimizer): + """Distributed optimizer, for all data types (fp16, bf16, and fp32). + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + @classmethod + def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range): + """ + Build mapping from param reference to grad buffer shard ranges. + + This method builds a mapping from parameter references to grad + buffer shard ranges, specific to each data-parallel (DP) rank's + set of 'owned' parameters. Each grad buffer (padded to be an even + multiple of DP-world-size) is conceptually divided into DP-world-size + contiguous regions, where each DP rank 'owns' a contiguous regions. + Ownership in this sense means DP rank is responsible for reducing + the relevant subset of grads, and updating the relevant subset of + params. + + This conceptual partitioning of the grad buffer does NOT respect + parameter boundaries, and as such it is assumed that each created + range references a shard (or subset) of the full parameter. It is + easiest to think of each DP rank as operating (i.e., reducing, + gathering) purely on views into the grad buffer, for all model-to- + main & main-to-model operations. + + This method creates three ranges: + - The param's range within the entire grad buffer (i.e., world index). + - The param's range within the DP rank's local view of the grad buffer. + - The param's range within itself (i.e., its shard). + """ + + # Param range map. + param_world_index_map = model._grad_buffer_param_index_map[dtype] + param_range_map = {} + for param, param_world_indexes in param_world_index_map.items(): + + # Param range. + param_world_start, param_world_end = param_world_indexes + param_local_start = max( + 0, + param_world_start - gbuf_world_range.start) + param_local_end = min( + gbuf_world_range.size, + param_world_end - gbuf_world_range.start) + + # Add param, if within local gbuf range. + if param_local_end > param_local_start: + param_local_range = Range(param_local_start, param_local_end) + param_world_range = param_local_range.normalize( + param_local_start + gbuf_world_range.start) + sub_param_start = max(0, gbuf_world_range.start-param_world_start) + sub_param_range = param_local_range.normalize(sub_param_start) + param_range_map[param] = { + "gbuf_world" : param_world_range, + "gbuf_local" : param_local_range, + "param" : sub_param_range, + } + + return param_range_map + + + @classmethod + def build_model_gbuf_range(cls, model, dtype): + """ + Build mapping between params and their grad buffers. + + This method does the initial setup for the method above. This setup + includes determining the shard ranges into the DDP's grad buffer for + each data-parallel (DP) rank. Each DP rank keeps range info for + all other DP ranks, for the purpose of creating args for + reduce-scatter and all-gather. + """ + + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_world_size = mpu.get_data_parallel_world_size() + + # Grad buffer range. + grad_buffer = model._grad_buffers[dtype] + gbuf_size = grad_buffer.numel + max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size)) + + # All world ranges. (i.e., across all data parallel ranks) + gbuf_world_all_ranges = [] + for r in range(data_parallel_world_size): + gbuf_world_start = r * max_gbuf_range_size + gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size) + gbuf_world_range = Range(gbuf_world_start, gbuf_world_end) + gbuf_world_all_ranges.append(gbuf_world_range) + + # Local DP's ranges. + gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank] + gbuf_local_range = gbuf_world_range.normalize() + + # Get each param's ranges. + param_range_map = cls.build_model_gbuf_param_range_map(model, + dtype, + gbuf_world_range) + + # Group into dict. + data = { + "local" : gbuf_local_range, + "world" : gbuf_world_range, + "world_all" : gbuf_world_all_ranges, + "param_map" : param_range_map, + "max_range_size" : max_gbuf_range_size, + } + + return data + + + @classmethod + def build_model_gbuf_range_map(cls, model): + """ + Create param-to-grad-buffer mappings, for grad buffer data types + within a specific virtual model. + """ + return { + dtype : cls.build_model_gbuf_range(model, dtype) + for dtype in model._grad_buffers + } + + + @classmethod + def build_model_param_gbuf_map(cls, model_gbuf_ranges): + """ + Create a reverse of the model_gbuf_ranges, for referencing in + opposite direction. + """ + param_gbuf_map = {} + for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges): + for dtype, gbuf_range_map in model_gbuf_range_map.items(): + for param, param_range_map in gbuf_range_map["param_map"].items(): + param_gbuf_map[param] = (model_index, dtype) + return param_gbuf_map + + + @classmethod + def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges): + """ + Create optimizer groups. + + Given the set of parameter shard ranges that are owned by the current + data-parallel (DP) rank, gather the set of parameters that will be + used (in the method below) to create the current DP's optimizer + groups. + """ + + num_groups = len(param_groups) + + # Param group map. + param_group_map = {} + for group_index, group in enumerate(param_groups): + for param in group["params"]: + assert param.requires_grad + param_group_map[param] = group_index + + # Optimizer group ranges. + group_ranges = [ {"params": []} for _ in param_groups ] + for model_gbuf_range_map in model_gbuf_ranges: + for dtype, gbuf_range_map in model_gbuf_range_map.items(): + for param in gbuf_range_map["param_map"]: + group_index = param_group_map[param] + group_range = group_ranges[group_index] + group_range["params"].append(param) + + # Squeeze zero-size group ranges. + for group_index, group_range in enumerate(group_ranges): + group_range["orig_group"] = param_groups[group_index] + group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ] + + return group_ranges + + + @classmethod + def build_model_and_main_param_groups(cls, + model_gbuf_ranges, + param_gbuf_map, + opt_group_ranges): + """ + Create main parameter groups needed for the optimizer step. + + These groups encompass both: 1) groups used by this class, for + reducing/gather, and 2) groups used by the inner optimizer for the + parameter update. Given that the conceptual grad buffer partitioning + (created in earlier method) doesn't respect parameter boundaries, + the optimizer operates on shards of the model parameters, rather than + the full parameters. + """ + + # Parameter groups: + # model_float16_groups: original float16 parameters + # model_fp32_groups: original fp32 parameters + # shard_float16_groups: shards of original float16 parameters + # shard_fp32_groups: shards of original fp32 parameters + # shard_fp32_from_float16_groups: fp32 copy of float16 parameters + model_float16_groups = [] + model_fp32_groups = [] + shard_float16_groups = [] + shard_fp32_groups = [] + shard_fp32_from_float16_groups = [] + + # Allocate (or slice) each group's param shard. + for group_index, group_range in enumerate(opt_group_ranges): + + # Params of this group. + model_float16_params_this_group = [] + model_fp32_params_this_group = [] + shard_float16_params_this_group = [] + shard_fp32_params_this_group = [] + shard_fp32_from_float16_params_this_group = [] + model_float16_groups.append(model_float16_params_this_group) + model_fp32_groups.append(model_fp32_params_this_group) + shard_float16_groups.append(shard_float16_params_this_group) + shard_fp32_groups.append(shard_fp32_params_this_group) + shard_fp32_from_float16_groups.append( + shard_fp32_from_float16_params_this_group) + + for model_param in group_range["params"]: + + assert model_param.requires_grad + + model_index, dtype = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[model_index][dtype] + param_range = gbuf_range["param_map"][model_param]["param"] + + # fp16, bf16 params. + if model_param.type() in ['torch.cuda.HalfTensor', + 'torch.cuda.BFloat16Tensor']: + + # Clone model -> main. + shard_model_param = model_param.detach().view(-1) \ + [param_range.start:param_range.end] + shard_main_param = shard_model_param.clone().float() + mpu.copy_tensor_model_parallel_attributes( + shard_model_param, model_param) + mpu.copy_tensor_model_parallel_attributes( + shard_main_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + shard_main_param.shared = model_param.shared + + # Add to group. + model_float16_params_this_group.append(model_param) + shard_float16_params_this_group.append(shard_model_param) + shard_fp32_from_float16_params_this_group.append(shard_main_param) + + # fp32 params. + elif model_param.type() == 'torch.cuda.FloatTensor': + shard_model_param = model_param.view(-1) \ + [param_range.start:param_range.end] + model_fp32_params_this_group.append(model_param) + shard_fp32_params_this_group.append(shard_model_param) + mpu.copy_tensor_model_parallel_attributes( + shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + + else: + raise TypeError('Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(param.type())) + + # Update optimizer's params. + group_range["orig_group"]["params"] = [ + *shard_fp32_params_this_group, + *shard_fp32_from_float16_params_this_group, + ] + + return ( + model_float16_groups, + model_fp32_groups, + shard_float16_groups, + shard_fp32_groups, + shard_fp32_from_float16_groups, + ) + + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, grad_scaler, models): + """ + See top of class definition for argument descriptions. + + The steps in this method create the core mapping between DDP grad + buffers, parameters, and parameter shard ranges, that is needed for + converting between model param indexes and main parameter shard + indexes. This method also updates the optimizer parameter groups + with the newly created shards. + """ + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, grad_scaler, models) + + # Verify that contiguous buffers are being used. + # - Note: this should already be checked in arguments.py. + assert use_contiguous_buffers_in_local_ddp + + # Model grad buffer ranges. + self.model_gbuf_ranges = [] + for model_index, model in enumerate(self.models): + self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model)) + self.model_param_gbuf_map = \ + self.build_model_param_gbuf_map(self.model_gbuf_ranges) + + # Optimizer ranges. + self.opt_group_ranges = self.build_optimizer_group_ranges( + self.optimizer.param_groups, + self.model_gbuf_ranges) + + # Allocate main param shards. + ( + self.model_float16_groups, + self.model_fp32_groups, + self.shard_float16_groups, + self.shard_fp32_groups, + self.shard_fp32_from_float16_groups, + ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges, + self.model_param_gbuf_map, + self.opt_group_ranges) + + # Update optimizer groups. + # - Also, leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors. + self.optimizer.param_groups = \ + [ g["orig_group"] for g in self.opt_group_ranges ] + self.optimizer.load_state_dict(self.optimizer.state_dict()) + + + def get_model_param_range_map(self, param): + """ + Given a model param, get the index sub-range of the param that this + data-parallel rank owns. + """ + model_index, dtype = self.model_param_gbuf_map[param] + gbuf_range_map = self.model_gbuf_ranges[model_index][dtype] + param_range_map = gbuf_range_map["param_map"][param] + return param_range_map + + + def get_model_parallel_group(self): + """ + With the distributed optimizer, the model parallel group is the + entire world. + """ + return None + + + def state_dict(self): + """ + The state dict must contain the fp32-from-float16 shards. + """ + state_dict = {} + state_dict['optimizer'] = self.optimizer.state_dict() + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + state_dict['shard_fp32_from_float16_groups'] = \ + self.shard_fp32_from_float16_groups + return state_dict + + + def load_state_dict(self, state_dict): + """ + Load the state dict. + """ + + # Optimizer. + optimizer_key = 'optimizer' + if optimizer_key not in state_dict: + optimizer_key = 'optimizer_state_dict' + print_rank_0('***WARNING*** loading optimizer from ' + 'an old checkpoint ...') + self.optimizer.load_state_dict(state_dict[optimizer_key]) + + # Grad scaler. + if 'grad_scaler' not in state_dict: + print_rank_0('***WARNING*** found an old checkpoint, will not ' + 'load grad scaler ...') + else: + if self.grad_scaler: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + else: + print_rank_0('***WARNING*** fould the grad scaler in the ' + 'checkpoint but it is None in the class. ' + 'Skipping loading grad scaler ...') + + # Copy data for the main params. + for current_group, saved_group in zip( + self.shard_fp32_from_float16_groups, + state_dict["shard_fp32_from_float16_groups"]): + for current_param, saved_param in zip(current_group, saved_group): + current_param.data.copy_(saved_param.data) + + + def zero_grad(self, set_to_none=True): + """ + Zero grads. + + We only need to zero the model related parameters, i.e., + model_float16_groups & model_fp32_groups. We additionally zero + the remaining groups as a memory optimization to reduce + fragmentation; in the case of set_to_none==True, the space + used by this field can be safely deallocated at this point. + """ + for groups in ( + self.model_float16_groups, + self.model_fp32_groups, + self.shard_float16_groups, # grad empty/unused here? + self.shard_fp32_groups, # throws grad-access warning + self.shard_fp32_from_float16_groups): + for group in groups: + _zero_grad_group_helper(group, set_to_none) + + + def get_model_grad_buffer_dp_views(self): + """ + Get shard views of each of the DDP's grad buffers. + + In this nested list, the top level is grouped by the virtual model + index and the grad buffer's data type. The sub-level is a list of + shards of that grad buffer, where each shard in the list represents + a contiguous view of the grad buffer, that is owned by a data-parallel + rank. The shard boundary does not respect parameter boundaries, and + so the elements of some parameters are split across data parallel + ranks. + + Additionally, return references to the entire grad buffers, for use + in _reduce_scatter_base and _all_gather_base. + """ + + data_parallel_world_size = mpu.get_data_parallel_world_size() + + # Grad buffer views. + gbuf_view_items = [] + for model_index, model in enumerate(self.models): + for dtype, gbuf in model._grad_buffers.items(): + + assert gbuf.numel_padded % data_parallel_world_size == 0 + shard_size = int(gbuf.numel_padded / data_parallel_world_size) + gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)] + for r in range(data_parallel_world_size)] + gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views)) + + return gbuf_view_items + + + def reduce_model_grads(self, args, timers): + """ + Reduce-scatter model grads. + + The DDP's grad buffer is used for the reduce-scatter, and thus no + tensors are dynamically allocated. + + Note: this is a different order of reduction, versus the non- + distributed optimizer, which reduces: 1) layernorm grads, 2) all + grads, 3) embedding grads. + """ + + # All-reduce layer-norm grads (for sequence parallelism). + timers('backward-layernorm-all-reduce').start() + self.allreduce_layernorm_grads(args) + timers('backward-layernorm-all-reduce').stop() + + # All-reduce embedding grads. + timers('backward-embedding-all-reduce').start() + self.allreduce_embedding_grads(args) + timers('backward-embedding-all-reduce').stop() + + # Reduce-scatter setup. + timers('backward-params-all-reduce').start() + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_world_size = mpu.get_data_parallel_world_size() + data_parallel_group = mpu.get_data_parallel_group() + + # Scale grad buffers by '1 / data_parallel_world_size'. + for model in self.models: + for dtype, gbuf in model._grad_buffers.items(): + gbuf.data /= data_parallel_world_size + + # Reduce-scatter all grads. + gbuf_view_items = self.get_model_grad_buffer_dp_views() + for index, (model_index, dtype, gbuf, gbuf_views) \ + in enumerate(gbuf_view_items): + + torch.distributed._reduce_scatter_base( + gbuf_views[data_parallel_rank], + gbuf, + group = data_parallel_group, + ) + + timers('backward-params-all-reduce').stop() + + + def gather_model_params(self, args, timers): + """ + All-gather updated model params. + + The DDP's grad buffer is used for the all-gather, and thus no + tensors are dynamically allocated. After the all-gather, the params + can be copied from param.main_grad to param. + """ + + timers('backward-params-all-gather').start() + + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_group = mpu.get_data_parallel_group() + + # All-gather updated main params. + # - All grad buffer views are guaranteed to have the same num elements + # across all data parallel ranks, with grad buffer padding that is done + # in distributed.py. Thus, all sub-views will have consistent start/end + # indexes across data parallel ranks. + gbuf_view_items = self.get_model_grad_buffer_dp_views() + for index, (model_index, dtype, gbuf, gbuf_views) \ + in enumerate(gbuf_view_items): + + torch.distributed._all_gather_base( + gbuf, + gbuf_views[data_parallel_rank], + group = data_parallel_group, + ) + + # Each model param now contains its updated values in its + # '.main_grad' field. + for model in self.models: + for dtype, param_map in model._grad_buffer_param_index_map.items(): + for param in param_map: + param.detach().copy_(param.main_grad) + + timers('backward-params-all-gather').stop() + + + def _collect_main_grad_data_for_unscaling(self): + """ + Note: this should be equivalent to the float-16 optimizer's method, + but writtent differently, so the two should be combined. + """ + return [ + param.grad.data + for group in self.optimizer.param_groups + for param in group["params"] + ] + + + def _get_model_and_main_params_data_float16(self): + """ + Get aligned list of model and main params. + """ + model_data = [] + main_data = [] + for model_group, main_group in zip(self.shard_float16_groups, + self.shard_fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + model_data.append(model_param.data) + main_data.append(main_param.data) + return model_data, main_data + + + def _copy_model_grads_to_main_grads(self): + """ + Copy model grads to main grads. + + Since this step follows a reduce-scatter through the DDP's grad + buffer, this method is responsible for copying the updated grads + from the grad buffer to the main shard's grad field. + """ + + # Utility method for copying group grads. + def copy_group_grads(model_groups, shard_main_groups): + for model_group, shard_main_group in zip(model_groups, + shard_main_groups): + for model_param, shard_main_param in zip(model_group, + shard_main_group): + + param_range_map = self.get_model_param_range_map(model_param) + param_range = param_range_map["param"] + assert param_range.size == shard_main_param.nelement() + + model_grad = model_param.main_grad + shard_model_grad = model_grad.view(-1) \ + [param_range.start:param_range.end] + shard_main_param.grad = shard_model_grad.float() + + # Copy model groups to shard groups. + copy_group_grads(self.model_float16_groups, + self.shard_fp32_from_float16_groups) + copy_group_grads(self.model_fp32_groups, + self.shard_fp32_groups) + + + def _copy_main_params_to_model_params(self): + """ + Copy main params to model params. + + Since this step is followed by an all-gather through the DDP's grad + buffer, this method is responsible for copying the updated params + from the main shards into the correct position in the grad buffer. + """ + + # Utility method for copying group params. + def copy_group_params(shard_main_groups, model_groups): + for shard_main_group, model_group in zip(shard_main_groups, + model_groups): + for shard_main_param, model_param in zip(shard_main_group, + model_group): + + param_range_map = self.get_model_param_range_map(model_param) + param_range = param_range_map["param"] + assert param_range.size == shard_main_param.nelement() + + model_grad = model_param.main_grad + shard_model_grad = model_grad.view(-1) \ + [param_range.start:param_range.end] + + shard_model_grad.data.copy_(shard_main_param) + + # Copy shard groups to model groups. + copy_group_params(self.shard_fp32_from_float16_groups, + self.model_float16_groups) + copy_group_params(self.shard_fp32_groups, + self.model_fp32_groups) diff --git a/megatron/optimizer/grad_scaler.py b/megatron/optimizer/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..6b25588db7a7366ec0e7f9930d8fda773b0f91a9 --- /dev/null +++ b/megatron/optimizer/grad_scaler.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron grad scaler.""" + +from abc import ABC +from abc import abstractmethod + +import torch + + +class MegatronGradScaler(ABC): + + def __init__(self, initial_scale): + """Initialize scale value with the input initial scale.""" + assert initial_scale > 0.0 + self._scale = torch.cuda.FloatTensor([initial_scale]) + + @property + def scale(self): + return self._scale + + @property + def inv_scale(self): + return self._scale.double().reciprocal().float() + + @abstractmethod + def update(self, found_inf): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def load_state_dict(self, state_dict): + pass + + + +class ConstantGradScaler(MegatronGradScaler): + + def update(self, found_inf): + pass + + def state_dict(self): + return dict() + + def load_state_dict(self, state_dict): + pass + + + +class DynamicGradScaler(MegatronGradScaler): + + def __init__(self, initial_scale, min_scale, + growth_factor, backoff_factor, + growth_interval, hysteresis): + """"Grad scaler with dynamic scale that gets adjusted + during training.""" + super(DynamicGradScaler, self).__init__(initial_scale) + + # Lower bound on the scale. + assert min_scale > 0.0 + assert min_scale <= initial_scale + self.min_scale = torch.cuda.FloatTensor([min_scale]) + # Growth and backoff factors for the scale. + assert growth_factor > 1.0 + self.growth_factor = torch.cuda.FloatTensor([growth_factor]) + assert backoff_factor < 1.0 + assert backoff_factor > 0.0 + self.backoff_factor = torch.cuda.FloatTensor([backoff_factor]) + # Interval over which if we don't see any inf/nan, + # we will scale the grad scale by the growth factor. + assert growth_interval > 0 + self.growth_interval = growth_interval + # Number of inf/nans we should see before scaling down + # the grad scale by the backoff factor. + assert hysteresis > 0 + self.hysteresis = hysteresis + + # Trackers. + self._growth_tracker = 0 + self._hysteresis_tracker = self.hysteresis + + + def update(self, found_inf): + + # If we have an inf/nan, growth tracker is set to 0 + # and hysterisis tracker is reduced by 1. + if found_inf: + self._growth_tracker = 0 + self._hysteresis_tracker -= 1 + # Now if we are out of hysteresis count, scale down the loss. + if self._hysteresis_tracker <= 0: + self._scale = torch.max(self._scale * self.backoff_factor, + self.min_scale) + else: + # If there is no nan/inf, increment the growth tracker. + self._growth_tracker += 1 + # If we have had enough consequitive intervals with no nan/inf: + if self._growth_tracker == self.growth_interval: + # Reset the tracker and hysteresis trackers, + self._growth_tracker = 0 + self._hysteresis_tracker = self.hysteresis + # and scale up the loss scale. + self._scale = self._scale * self.growth_factor + + + def state_dict(self): + state_dict = {} + state_dict['scale'] = self._scale + state_dict['growth_tracker'] = self._growth_tracker + state_dict['hysteresis_tracker'] = self._hysteresis_tracker + return state_dict + + + def load_state_dict(self, state_dict): + self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) + self._growth_tracker = state_dict['growth_tracker'] + self._hysteresis_tracker = state_dict['hysteresis_tracker'] diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b265145a3dcbb07caa5d0565595c1065aef77cdc --- /dev/null +++ b/megatron/optimizer/optimizer.py @@ -0,0 +1,772 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron optimizer.""" + +from abc import ABC +from abc import abstractmethod +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C +import torch +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from megatron import get_timers +from megatron import mpu +from megatron import print_rank_0 +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.model import Float16Module +from megatron.model.module import param_is_not_shared +from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate +from megatron.utils import unwrap_model + +from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 + + +def _zero_grad_group_helper(group, set_to_none): + """Zero out the gradient for a group of parameters. + Note: copied from torch.optim.optimizer.""" + for param in group: + if param.grad is not None: + if set_to_none: + param.grad = None + else: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + +def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): + """Use multi-tensor-applier to copy values from one list to another. + We don't have a blfoat16 implementation so for now if the overflow_buf + is not provided, we default back to simple loop copy to be compatible + with bfloat16.""" + if overflow_buf: + overflow_buf.fill_(0) + # Scaling with factor `1.0` is equivalent to copy. + multi_tensor_applier(amp_C.multi_tensor_scale, + overflow_buf, + [this, that], + 1.0) + else: + for this_, that_ in zip(this, that): + that_.copy_(this_) + + + +class MegatronOptimizer(ABC): + + + def __init__(self, optimizer, clip_grad, + log_num_zeros_in_grad, + params_have_main_grad, + use_contiguous_buffers_in_local_ddp, + models): + + """Input optimizer is the base optimizer for example Adam.""" + self.optimizer = optimizer + assert self.optimizer, 'no optimizer is provided.' + # Set gradient clipping and logging params. + self.clip_grad = clip_grad + self.log_num_zeros_in_grad = log_num_zeros_in_grad + self.params_have_main_grad = params_have_main_grad + self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp + + # 'models' are retained for access to the contiguous grad buffers. + # (see distributed optimizer) + self.models = models + + if self.use_contiguous_buffers_in_local_ddp: + assert self.params_have_main_grad, \ + "use of contiguous buffer requires that params have main grad" + + + def get_parameters(self): + params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + params.append(param) + return params + + + def get_main_grads_for_grad_norm(self): + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params = self.get_parameters() + grads_for_norm = [] + for param in params: + grad = param.grad + grad_not_none = grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grads_for_norm.append(grad) + + return grads_for_norm + + + def get_model_parallel_group(self): + """Default returned here, but the distributed optimizer overrides this.""" + return mpu.get_model_parallel_group() + + + def clip_grad_norm(self, clip_grad): + params = self.get_parameters() + grads_for_norm = self.get_main_grads_for_grad_norm() + return clip_grad_norm_fp32( + params, grads_for_norm, clip_grad, + model_parallel_group=self.get_model_parallel_group()) + + + def count_zeros(self): + params = self.get_parameters() + return count_zeros_fp32(params, + model_parallel_group=self.get_model_parallel_group()) + + + @abstractmethod + def zero_grad(self, set_to_none=True): + pass + + + @abstractmethod + def get_loss_scale(self): + """The output should be a cuda tensor of size 1.""" + pass + + + def scale_loss(self, loss): + """Simple scaling.""" + return self.get_loss_scale() * loss + + + @abstractmethod + def reload_model_params(self): + """Refreshes any internal state from the current model parameters. + Call whenever the parameters are changed outside of the optimizer. + For example, when we load a model from a checkpoint without loading + the optimizer, the model parameters are updated but for fp16 optimizer + with main parameters, the main parameters need to also be updated.""" + pass + + + @abstractmethod + def state_dict(self): + pass + + + @abstractmethod + def load_state_dict(self, state_dict): + pass + + + # Promote state so it can be retrieved or set via + # "optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + + # Promote param_groups so it can be retrieved or set via + # "optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + + + @abstractmethod + def step(self, args, timers): + pass + + + def gather_model_params(self, args, timers): + """ + For the case of a non-distributed-optimizer, there is nothing to + do here. + """ + pass + + + def allreduce_word_embedding_grads(self, args): + """ + All-reduce word embedding grads. + + Reduce grads across first and last stages to ensure that word_embeddings + parameters stay in sync. This should only run for models that support + pipelined model parallelism (BERT and GPT-2). + """ + + if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \ + mpu.get_pipeline_model_parallel_world_size() > 1: + if mpu.is_pipeline_first_stage(ignore_virtual=True): + unwrapped_model = self.models[0] + elif mpu.is_pipeline_last_stage(ignore_virtual=True): + unwrapped_model = self.models[-1] + else: # We do not support the interleaved schedule for T5 yet. + unwrapped_model = self.models[0] + unwrapped_model = unwrap_model( + unwrapped_model, (torchDDP, LocalDDP, Float16Module)) + + if unwrapped_model.share_word_embeddings: + word_embeddings_weight = unwrapped_model.word_embeddings_weight() + if args.DDP_impl == 'local': + grad = word_embeddings_weight.main_grad + else: + grad = word_embeddings_weight.grad + torch.distributed.all_reduce(grad, group=mpu.get_embedding_group()) + + + def allreduce_position_embedding_grads(self, args): + """ + All-reduce position_embeddings grad across first (encoder) and + split (decoder) stages to ensure that position embeddings parameters + stay in sync. This should only run for T5 models with pipeline + parallelism. + """ + if mpu.is_rank_in_position_embedding_group() and \ + mpu.get_pipeline_model_parallel_world_size() > 1 and \ + args.pipeline_model_parallel_split_rank is not None: + unwrapped_model = self.models[0] + unwrapped_model = unwrap_model( + unwrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert args.DDP_impl == 'local', \ + 'T5 model is only supported with local DDP mode' + grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad + torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group()) + + + def allreduce_embedding_grads(self, args): + """All-reduce both word and position embeddings.""" + self.allreduce_word_embedding_grads(args) + self.allreduce_position_embedding_grads(args) + + + def allreduce_layernorm_grads(self, args): + """All-reduce layernorm grads (for sequence parallelism).""" + + # All-reduce layernorm parameters across model parallel nodes + # when sequence parallelism is used + if mpu.get_tensor_model_parallel_world_size() > 1 and \ + args.sequence_parallel: + grads = [] + for model_module in self.models: + unwrapped_model = unwrap_model( + model_module, (torchDDP, LocalDDP, Float16Module)) + for param in unwrapped_model.parameters(): + if getattr(param, 'sequence_parallel', False): + grad = param.main_grad if args.DDP_impl == 'local' else param.grad + grads.append(grad.data) + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=mpu.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors( + coalesced, grads)): + buf.copy_(synced) + + + def reduce_model_grads(self, args, timers): + """All-reduce all grads, and all-reduce embeddings.""" + + # All-reduce layer-norm grads (for sequence parallelism). + timers('backward-layernorm-all-reduce').start() + self.allreduce_layernorm_grads(args) + timers('backward-layernorm-all-reduce').stop() + + # All-reduce if needed. + if args.DDP_impl == 'local': + timers('backward-params-all-reduce').start() + for model in self.models: + model.allreduce_gradients() + timers('backward-params-all-reduce').stop() + + # All-reduce embedding grads. + timers('backward-embedding-all-reduce').start() + self.allreduce_embedding_grads(args) + timers('backward-embedding-all-reduce').stop() + + +class MixedPrecisionOptimizer(MegatronOptimizer): + """Base class for both the float-16 and the distributed optimizer. + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, grad_scaler, + models): + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + models) + + self.fp16 = fp16 + self.bf16 = bf16 + self.grad_scaler = grad_scaler + + # None grad scaler is only supported for bf16. + if self.grad_scaler is None: + assert not self.fp16, 'fp16 expects a grad scaler.' + + # Tensor used to determine if a nan/if has happend. + # Any non-zero value indicates inf/nan. + # Note that we keep this for the cases that grad scaler is none. + # We still record nan/inf if we have a bfloat16 with a grad scaler. + if self.grad_scaler: + self.found_inf = torch.cuda.FloatTensor([0.0]) + + # Dummy tensor needed for apex multi-apply tensor. + # For bfloat, we don't have multi-tensor apply and for now + # we set it to none so the multi-tensor apply gets ignored. + if bf16: + self._dummy_overflow_buf = None + else: + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + + # In case grad scaler is not passed, define the unity scale. + if self.grad_scaler is None: + self._scale_one = torch.cuda.FloatTensor([1.0]) + + + def get_loss_scale(self): + if self.grad_scaler is None: + return self._scale_one + return self.grad_scaler.scale + + + def reload_model_params(self): + self._copy_model_params_to_main_params() + + + def _unscale_main_grads_and_check_for_nan(self): + + # Collect main grads. + main_grads = self._collect_main_grad_data_for_unscaling() + + # Reset found inf. + self.found_inf.fill_(0.0) + + # Unscale and set found inf/nan + torch._amp_foreach_non_finite_check_and_unscale_( + main_grads, self.found_inf, self.grad_scaler.inv_scale) + + # Update across all model parallel instances. + torch.distributed.all_reduce(self.found_inf, + op=torch.distributed.ReduceOp.MAX, + group=self.get_model_parallel_group()) + + # Check for nan. + found_inf_flag = (self.found_inf.item() > 0) + + return found_inf_flag + + + @torch.no_grad() + def step(self, args, timers): + + # Copy gradients from model params to main params. + timers('optimizer-copy-to-main-grad').start() + self._copy_model_grads_to_main_grads() + timers('optimizer-copy-to-main-grad').stop() + + # Do unscale, check for inf, and update grad scaler only for + # the case that grad scaler is provided. + if self.grad_scaler: + + # Unscale and check for inf/nan. + timers('optimizer-unscale-and-check-inf').start() + found_inf_flag = self._unscale_main_grads_and_check_for_nan() + timers('optimizer-unscale-and-check-inf').stop() + + # We are done with scaling gradients + # so we can update the loss scale. + self.grad_scaler.update(found_inf_flag) + + # If we found inf/nan, skip the update. + if found_inf_flag: + return False, None, None + + # Clip the main gradients. + timers('optimizer-clip-main-grad').start() + grad_norm = None + if self.clip_grad > 0.0: + grad_norm = self.clip_grad_norm(self.clip_grad) + timers('optimizer-clip-main-grad').stop() + + # Count the zeros in the grads. + timers('optimizer-count-zeros').start() + num_zeros_in_grad = self.count_zeros() if \ + self.log_num_zeros_in_grad else None + timers('optimizer-count-zeros').stop() + + # Step the optimizer. + timers('optimizer-inner-step').start() + self.optimizer.step() + timers('optimizer-inner-step').stop() + + # Update params from main params. + timers('optimizer-copy-main-to-model-params').start() + self._copy_main_params_to_model_params() + timers('optimizer-copy-main-to-model-params').stop() + + # Successful update. + return True, grad_norm, num_zeros_in_grad + + +class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer): + """Float16 optimizer for fp16 and bf16 data types. + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, grad_scaler, models): + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, grad_scaler, models) + + # ====================== + # main parameter stuff + # ====================== + + # Three groups of parameters: + # float16_groups: original float16 parameters + # fp32_from_float16_groups: fp32 copy of float16 parameters + # fp32_from_fp32_groups: original fp32 parameters + self.float16_groups = [] + self.fp32_from_float16_groups = [] + self.fp32_from_fp32_groups = [] + + # For all the groups in the original optimizer: + for param_group in self.optimizer.param_groups: + float16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_float16_params_this_group = [] + # For all the parameters in this group: + for i, param in enumerate(param_group['params']): + if param.requires_grad: + + # float16 params: + if param.type() in ['torch.cuda.HalfTensor', + 'torch.cuda.BFloat16Tensor']: + float16_params_this_group.append(param) + # Create a copy + main_param = param.detach().clone().float() + # Copy tensor model parallel attributes. + mpu.copy_tensor_model_parallel_attributes(main_param, + param) + if hasattr(param, 'shared'): + main_param.shared = param.shared + # Replace the optimizer params with the new fp32 copy. + param_group['params'][i] = main_param + + fp32_from_float16_params_this_group.append(main_param) + # Reset existing state dict key to the new main param. + if param in self.optimizer.state: + self.optimizer.state[main_param] \ + = self.optimizer.state.pop(param) + # fp32 params. + elif param.type() == 'torch.cuda.FloatTensor': + fp32_params_this_group.append(param) + param_group['params'][i] = param + + else: + raise TypeError('Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(param.type())) + + self.float16_groups.append(float16_params_this_group) + self.fp32_from_float16_groups.append( + fp32_from_float16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + + def zero_grad(self, set_to_none=True): + """We only need to zero the model related parameters, i.e., + float16_groups & fp32_from_fp32_groups. We additionally zero + fp32_from_float16_groups as a memory optimization to reduce + fragmentation; in the case of set_to_none==True, the space + used by this field can be safely deallocated at this point.""" + for group in self.float16_groups: + _zero_grad_group_helper(group, set_to_none) + for group in self.fp32_from_float16_groups: + _zero_grad_group_helper(group, set_to_none) + for group in self.fp32_from_fp32_groups: + _zero_grad_group_helper(group, set_to_none) + + + def _collect_main_grad_data_for_unscaling(self): + + main_grads = [] + + # fp32 params from float16 ones. + for main_group in self.fp32_from_float16_groups: + for main_param in main_group: + if main_param.grad is not None: + main_grads.append(main_param.grad.data) + + # Append fp32 parameters. + for main_group in self.fp32_from_fp32_groups: + for main_param in main_group: + if main_param.grad is not None: + main_grads.append(main_param.grad.data) + + return main_grads + + + def _get_model_and_main_params_data_float16(self): + model_data = [] + main_data = [] + for model_group, main_group in zip(self.float16_groups, + self.fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + model_data.append(model_param.data) + main_data.append(main_param.data) + return model_data, main_data + + + def _copy_model_grads_to_main_grads(self): + # This only needs to be done for the float16 group. + for model_group, main_group in zip(self.float16_groups, + self.fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + if self.params_have_main_grad and hasattr(model_param, 'main_grad'): + main_param.grad = model_param.main_grad.float() + else: + if model_param.grad is not None: + main_param.grad = model_param.grad.float() + + # Safe to deallocate model's grad/main_grad after copying. + # (If using contiguous buffers, main_grad's memory should + # persist and therefore should not be deallocated.) + model_param.grad = None + if self.params_have_main_grad and \ + not self.use_contiguous_buffers_in_local_ddp: + model_param.main_grad = None + + # For fp32 grads, we need to reset the grads to main grad. + if self.params_have_main_grad: + for model_group in self.fp32_from_fp32_groups: + for model_param in model_group: + model_param.grad = model_param.main_grad + + # Safe to de-reference model's main_grad after copying. + # (If using contiguous buffers, main_grad's memory should + # persist and therefore should not be deallocated.) + if not self.use_contiguous_buffers_in_local_ddp: + model_param.main_grad = None + + + def _copy_main_params_to_model_params(self): + # Only needed for the float16 params. + model_data, main_data = self._get_model_and_main_params_data_float16() + _multi_tensor_copy_this_to_that(this=main_data, that=model_data, + overflow_buf=self._dummy_overflow_buf) + + + def _copy_model_params_to_main_params(self): + # Only needed for the float16 params. + model_data, main_data = self._get_model_and_main_params_data_float16() + _multi_tensor_copy_this_to_that(this=model_data, that=main_data, + overflow_buf=self._dummy_overflow_buf) + + + def state_dict(self): + state_dict = {} + state_dict['optimizer'] = self.optimizer.state_dict() + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups + return state_dict + + + def load_state_dict(self, state_dict): + # Optimizer. + optimizer_key = 'optimizer' + if optimizer_key not in state_dict: + optimizer_key = 'optimizer_state_dict' + print_rank_0('***WARNING*** loading optimizer from ' + 'an old checkpoint ...') + self.optimizer.load_state_dict(state_dict[optimizer_key]) + + # Grad scaler. + if 'grad_scaler' not in state_dict: + print_rank_0('***WARNING*** found an old checkpoint, will not ' + 'load grad scaler ...') + else: + if self.grad_scaler: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + else: + print_rank_0('***WARNING*** fould the grad scaler in the ' + 'checkpoint but it is None in the class. ' + 'Skipping loading grad scaler ...') + + # Copy data for the main params. + fp32_from_float16_params_key = 'fp32_from_fp16_params' + if fp32_from_float16_params_key not in state_dict: + fp32_from_float16_params_key = 'fp32_from_fp16' + for current_group, saved_group in zip( + self.fp32_from_float16_groups, + state_dict[fp32_from_float16_params_key]): + for current_param, saved_param in zip(current_group, saved_group): + current_param.data.copy_(saved_param.data) + + +class FP32Optimizer(MegatronOptimizer): + + def __init__(self, optimizer, clip_grad, + log_num_zeros_in_grad, + params_have_main_grad, + use_contiguous_buffers_in_local_ddp, + models): + + super(FP32Optimizer, self).__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + models) + + self._scale = torch.cuda.FloatTensor([1.0]) + + + def zero_grad(self, set_to_none=True): + """Copied from torch.optim.optimizer""" + for group in self.optimizer.param_groups: + _zero_grad_group_helper(group['params'], set_to_none) + + + def get_loss_scale(self): + """FP32 optimizer does not do any scaling.""" + return self._scale + + + @torch.no_grad() + def step(self, args, timers): + """Clip gradients (if needed) and step the base optimizer. + Always return successful since there is no overflow.""" + + # Copy main_grads to grads. + timers('optimizer-copy-to-main-grad').start() + if self.params_have_main_grad: + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + param.grad = param.main_grad + + # Safe to de-reference model's main_grad after copying. + # (If using contiguous buffers, main_grad's memory should + # persist and therefore should not be deallocated.) + if not self.use_contiguous_buffers_in_local_ddp: + param.main_grad = None + timers('optimizer-copy-to-main-grad').stop() + + # Clip gradients. + timers('optimizer-clip-main-grad').start() + grad_norm = None + if self.clip_grad > 0.0: + grad_norm = self.clip_grad_norm(self.clip_grad) + timers('optimizer-clip-main-grad').stop() + + # count the zeros in the grads + timers('optimizer-count-zeros').start() + num_zeros_in_grad = self.count_zeros() if \ + self.log_num_zeros_in_grad else None + timers('optimizer-count-zeros').stop() + + # Update parameters. + timers('optimizer-inner-step').start() + self.optimizer.step() + timers('optimizer-inner-step').stop() + + # No overflow for FP32 optimizer. + return True, grad_norm, num_zeros_in_grad + + + def reload_model_params(self): + pass + + + def state_dict(self): + return self.optimizer.state_dict() + + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) diff --git a/megatron/optimizer_param_scheduler.py b/megatron/optimizer_param_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..30951c46475a5cf6a210c4aca014cd9c8b615cc2 --- /dev/null +++ b/megatron/optimizer_param_scheduler.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Learning rate decay and weight decay incr functions.""" + +import math + +from megatron import print_rank_0 + +class OptimizerParamScheduler(object): + """Anneals learning rate and weight decay""" + + def __init__(self, optimizer, max_lr, min_lr, + lr_warmup_steps, lr_decay_steps, lr_decay_style, + start_wd, end_wd, wd_incr_steps, wd_incr_style, + use_checkpoint_opt_param_scheduler=True, + override_opt_param_scheduler=False): + + # Class values. + self.optimizer = optimizer + + self.max_lr = float(max_lr) + self.min_lr = min_lr + assert self.min_lr >= 0.0 + assert self.max_lr >= self.min_lr + + self.lr_warmup_steps = lr_warmup_steps + self.num_steps = 0 + self.lr_decay_steps = lr_decay_steps + assert self.lr_decay_steps > 0 + assert self.lr_warmup_steps < self.lr_decay_steps + + self.lr_decay_style = lr_decay_style + + self.start_wd = start_wd + self.end_wd = end_wd + assert self.start_wd >= 0.0 + assert self.end_wd >= self.start_wd + self.wd_incr_steps = wd_incr_steps + self.wd_incr_style = wd_incr_style + + self.override_opt_param_scheduler = override_opt_param_scheduler + self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler + if self.override_opt_param_scheduler: + assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\ + 'use-checkpoint are set.' + + # Set the learning rate + self.step(0) + print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style)) + + + def get_wd(self): + """ Weight decay incr functions""" + if self.num_steps > self.wd_incr_steps: + return self.end_wd + + if self.wd_incr_style == 'constant': + assert self.start_wd == self.end_wd + return self.end_wd + + incr_ratio = float(self.num_steps) / float(self.wd_incr_steps) + assert incr_ratio >= 0.0 + assert incr_ratio <= 1.0 + delta_wd = self.end_wd - self.start_wd + + if self.wd_incr_style == 'linear': + coeff = incr_ratio + elif self.wd_incr_style == 'cosine': + coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0) + else: + raise Exception('{} weight decay increment style is not supported.'.format( + self.wd_incr_style)) + + return self.start_wd + coeff * delta_wd + + + def get_lr(self): + """Learning rate decay functions from: + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + + # Use linear warmup for the initial part. + if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps: + return self.max_lr * float(self.num_steps) / \ + float(self.lr_warmup_steps) + + # If the learning rate is constant, just return the initial value. + if self.lr_decay_style == 'constant': + return self.max_lr + + # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`. + if self.num_steps > self.lr_decay_steps: + return self.min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = self.num_steps - self.lr_warmup_steps + decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = self.max_lr - self.min_lr + + if self.lr_decay_style == 'linear': + coeff = (1.0 - decay_ratio) + elif self.lr_decay_style == 'cosine': + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + else: + raise Exception('{} decay style is not supported.'.format( + self.lr_decay_style)) + + return self.min_lr + coeff * delta_lr + + + def step(self, increment): + """Set lr for all parameters groups.""" + self.num_steps += increment + new_lr = self.get_lr() + new_wd = self.get_wd() + for group in self.optimizer.param_groups: + group['lr'] = new_lr * group.get('lr_mult', 1.0) + group['weight_decay'] = new_wd * group.get('wd_mult', 1.0) + + + def state_dict(self): + state_dict = { + 'max_lr': self.max_lr, + 'lr_warmup_steps': self.lr_warmup_steps, + 'num_steps': self.num_steps, + 'lr_decay_style': self.lr_decay_style, + 'lr_decay_steps': self.lr_decay_steps, + 'min_lr': self.min_lr, + 'start_wd': self.start_wd, + 'end_wd': self.end_wd, + 'wd_incr_style': self.wd_incr_style, + 'wd_incr_steps': self.wd_incr_steps + } + return state_dict + + + def _check_and_set(self, cls_value, sd_value, name): + """Auxiliary function for checking the values in the checkpoint and + setting them.""" + if self.override_opt_param_scheduler: + print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) + return cls_value + + if not self.use_checkpoint_opt_param_scheduler: + assert cls_value == sd_value, \ + f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \ + f'value {sd_value} for {name} do not match' + print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, + name)) + return sd_value + + + def load_state_dict(self, sd): + + if 'start_lr' in sd: + max_lr_ = sd['start_lr'] + else: + max_lr_ = sd['max_lr'] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, + 'learning rate') + + self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], + 'minimum learning rate') + + if 'warmup_iter' in sd: + lr_warmup_steps_ = sd['warmup_iter'] + elif 'warmup_steps' in sd: + lr_warmup_steps_ = sd['warmup_steps'] + else: + lr_warmup_steps_ = sd['lr_warmup_steps'] + self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps, + lr_warmup_steps_, + 'warmup iterations') + + if 'end_iter' in sd: + lr_decay_steps_ = sd['end_iter'] + elif 'decay_steps' in sd: + lr_decay_steps_ = sd['decay_steps'] + else: + lr_decay_steps_ = sd['lr_decay_steps'] + self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_, + 'total number of iterations') + + if 'decay_style' in sd: + lr_decay_style_ = sd['decay_style'] + else: + lr_decay_style_ = sd['lr_decay_style'] + self.lr_decay_style = self._check_and_set(self.lr_decay_style, + lr_decay_style_, + 'learning rate decay style') + + if 'num_iters' in sd: + num_steps = sd['num_iters'] + else: + num_steps = sd['num_steps'] + self.step(increment=num_steps) + + + if 'start_wd' in sd: + self.start_wd = self._check_and_set(self.start_wd, + sd['start_wd'], + "start weight decay") + self.end_wd = self._check_and_set(self.end_wd, + sd['end_wd'], + "end weight decay") + self.wd_incr_steps = self._check_and_set(self.wd_incr_steps, + sd['wd_incr_steps'], + "total number of weight decay iterations") + self.wd_incr_style = self._check_and_set(self.wd_incr_style, + sd['wd_incr_style'], + "weight decay incr style") + + + + + + + + diff --git a/megatron/p2p_communication.py b/megatron/p2p_communication.py new file mode 100644 index 0000000000000000000000000000000000000000..219b9688c1a5284be677ed5d151660078c66a49d --- /dev/null +++ b/megatron/p2p_communication.py @@ -0,0 +1,312 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import reduce +import operator +import torch + +from megatron import get_args +from megatron import mpu + + +def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, + tensor_shape, + use_ring_exchange=False, + dtype_=None): + """Communicate tensors between stages. Used as helper method in other + communication methods that are used in megatron/schedules.py. + + Takes the following arguments: + tensor_send_next: tensor to send to next rank (no tensor sent if + set to None). + tensor_send_prev: tensor to send to prev rank (no tensor sent if + set to None). + recv_prev: boolean for whether tensor should be received from + previous rank. + recv_next: boolean for whether tensor should be received from + next rank. + tensor_shape: shape of tensor to receive (this method assumes that all + tensors sent and received in a single function call are + the same shape). + use_ring_exchange: boolean for whether torch.distributed.ring_exchange() + API should be used. + dtype_: optional, this is used when the tensor that needs to be + communicated is different from args.params_dtype. + Returns: + (tensor_recv_prev, tensor_recv_next) + """ + args = get_args() + + # Create placeholder tensors for receive in forward and backward directions + # if needed. + tensor_recv_prev = None + tensor_recv_next = None + + # Some legacy inference code doesn't set the tensor shape, do so now + # for the normal values for gpt/bert. This could be removed if inference + # code is changed to provide tensor_shape. + if tensor_shape is None: + tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) + + override_scatter_gather_tensors_in_pipeline = False + if args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: + tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) + if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0: + tensor_chunk_shape = tensor_chunk_shape // \ + mpu.get_tensor_model_parallel_world_size() + else: + tensor_chunk_shape = tensor_shape + override_scatter_gather_tensors_in_pipeline = True + else: + tensor_chunk_shape = tensor_shape + dtype = args.params_dtype + if args.fp32_residual_connection: + dtype = torch.float + + requires_grad = True + if dtype_ is not None: + dtype = dtype_ + requires_grad = False + + if recv_prev: + tensor_recv_prev = torch.empty(tensor_chunk_shape, + requires_grad=requires_grad, + device=torch.cuda.current_device(), + dtype=dtype) + if recv_next: + tensor_recv_next = torch.empty(tensor_chunk_shape, + requires_grad=requires_grad, + device=torch.cuda.current_device(), + dtype=dtype) + + # Split tensor into smaller chunks if using scatter-gather optimization. + if not override_scatter_gather_tensors_in_pipeline and \ + args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: + if tensor_send_next is not None: + tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next) + + if tensor_send_prev is not None: + tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev) + + # Send tensors in both the forward and backward directions as appropriate. + if use_ring_exchange: + torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + group=mpu.get_pipeline_model_parallel_group()) + else: + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_prev, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(recv_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + # If using scatter-gather optimization, gather smaller chunks. + if not override_scatter_gather_tensors_in_pipeline and \ + args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: + if recv_prev: + tensor_recv_prev = mpu.gather_split_1d_tensor( + tensor_recv_prev).view(tensor_shape).requires_grad_() + tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev, + requires_grad = True, + keep_graph = False) + + if recv_next: + tensor_recv_next = mpu.gather_split_1d_tensor( + tensor_recv_next).view(tensor_shape).requires_grad_() + tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next, + requires_grad = True, + keep_graph = False) + + return tensor_recv_prev, tensor_recv_next + + +def recv_forward(tensor_shape=None, dtype_=None, timers=None): + """Receive tensor from previous rank in pipeline (forward receive).""" + + if mpu.is_pipeline_first_stage(): + input_tensor = None + else: + if timers is not None: + timers('forward-recv').start() + input_tensor, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False, + tensor_shape=tensor_shape, + dtype_=dtype_) + if timers is not None: + timers('forward-recv').stop() + return input_tensor + + +def recv_backward(tensor_shape=None, timers=None): + """Receive tensor from next rank in pipeline (backward receive).""" + if mpu.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if timers is not None: + timers('backward-recv').start() + _, output_tensor_grad = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + tensor_shape=tensor_shape) + if timers is not None: + timers('backward-recv').stop() + return output_tensor_grad + + +def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None): + """Send tensor to next rank in pipeline (forward send).""" + + if not mpu.is_pipeline_last_stage(): + if timers is not None: + timers('forward-send').start() + _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False, + tensor_shape=tensor_shape, + dtype_=dtype_) + if timers is not None: + timers('forward-send').stop() + + +def send_backward(input_tensor_grad, tensor_shape=None, timers=None): + """Send tensor to previous rank in pipeline (backward send).""" + if not mpu.is_pipeline_first_stage(): + if timers is not None: + timers('backward-send').start() + _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False, + tensor_shape=tensor_shape) + if timers is not None: + timers('backward-send').stop() + + +def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None): + """Batched send and recv with next rank in pipeline.""" + if mpu.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if timers is not None: + timers('forward-send-backward-recv').start() + _, output_tensor_grad = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + tensor_shape=tensor_shape) + if timers is not None: + timers('forward-send-backward-recv').stop() + return output_tensor_grad + + +def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None): + """Batched send and recv with previous rank in pipeline.""" + if mpu.is_pipeline_first_stage(): + input_tensor = None + else: + if timers is not None: + timers('backward-send-forward-recv').start() + input_tensor, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False, + tensor_shape=tensor_shape) + if timers is not None: + timers('backward-send-forward-recv').stop() + return input_tensor + + +def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None): + """Batched recv from previous rank and send to next rank in pipeline.""" + if timers is not None: + timers('forward-send-forward-recv').start() + input_tensor, _ = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=recv_prev, + recv_next=False, + tensor_shape=tensor_shape) + if timers is not None: + timers('forward-send-forward-recv').stop() + return input_tensor + + +def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None): + """Batched recv from next rank and send to previous rank in pipeline.""" + if timers is not None: + timers('backward-send-backward-recv').start() + _, output_tensor_grad = _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=recv_next, + tensor_shape=tensor_shape) + if timers is not None: + timers('backward-send-backward-recv').stop() + return output_tensor_grad + + +def send_forward_backward_recv_forward_backward( + output_tensor, input_tensor_grad, recv_prev, + recv_next, tensor_shape=None, timers=None): + """Batched send and recv with previous and next ranks in pipeline.""" + if timers is not None: + timers('forward-backward-send-forward-backward-recv').start() + input_tensor, output_tensor_grad = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape) + if timers is not None: + timers('forward-backward-send-forward-backward-recv').stop() + return input_tensor, output_tensor_grad diff --git a/megatron/schedules.py b/megatron/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5ba6f67e89aa748aaa63a058d67e02033214d4 --- /dev/null +++ b/megatron/schedules.py @@ -0,0 +1,728 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +import torch +from torch.autograd.variable import Variable +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP + +from megatron import get_args +from megatron import get_num_microbatches +from megatron import get_timers +from megatron import mpu +from megatron import p2p_communication +from megatron.utils import unwrap_model +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.model import Float16Module +from megatron.model import ModelType + + +def get_forward_backward_func(): + args = get_args() + if mpu.get_pipeline_model_parallel_world_size() > 1: + if args.virtual_pipeline_model_parallel_size is not None: + forward_backward_func = forward_backward_pipelining_with_interleaving + assert get_num_microbatches() % \ + args.pipeline_model_parallel_size == 0, \ + 'number of microbatches (%d) is not divisible by pipeline-' \ + 'model-parallel-size (%d) when using interleaved schedule' % ( + get_num_microbatches(), + args.pipeline_model_parallel_size, + ) + else: + forward_backward_func = forward_backward_pipelining_without_interleaving + else: + forward_backward_func = forward_backward_no_pipelining + return forward_backward_func + +def deallocate_output_tensor(out): + '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. + + This method should be called right after the output tensor has been + sent to the next pipeline stage. At this point, the output tensor is + only useful for its '.grad_fn' field, and not its '.data'. + ''' + if out is None: + return + assert isinstance(out, torch.Tensor), \ + "expected Tensor, found %s." % type(out).__name__ + assert out._base is None, \ + "counter-productive to free a view of another tensor." + out.data = torch.empty( + (1,), + device = out.device, + dtype = out.dtype, + ) + +def custom_backward(output, grad_output): + '''Directly call C++ autograd engine. + + To make the 'deallocate_output_tensor' (above) optimization work, the C++ + autograd engine must be called directly, bypassing Pytorch's + torch.autograd.backward. Pytorch's 'backward' checks that the output and + grad have the same shape, while C++'s 'backward' does not. + ''' + + assert output.numel() == 1, \ + "output should be pseudo-'freed' in schedule, to optimize memory" + assert isinstance(output, torch.Tensor), \ + "output == '%s'." % type(output).__name__ + assert isinstance(grad_output, (torch.Tensor, type(None))), \ + "grad_output == '%s'." % type(grad_output).__name__ + + # Handle scalar output + if grad_output is None: + assert output.numel() == 1, "implicit grad requires scalar output." + grad_output = torch.ones_like( + output, + memory_format = torch.preserve_format, + ) + + # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] + Variable._execution_engine.run_backward( + tensors = (output,), + grad_tensors = (grad_output,), + keep_graph = False, + create_graph = False, + inputs = tuple(), + allow_unreachable=True, + accumulate_grad=True, + ) + + +def forward_step(forward_step_func, + data_iterator, + model, + input_tensor, + forward_data_store, + collect_non_loss_data=False): + """Forward step for passed-in model. + + If first stage, input tensor is obtained from data_iterator, otherwise + passed-in input_tensor is used. + + Returns output tensor.""" + args = get_args() + timers = get_timers() + + timers('forward-compute').start() + unwrapped_model = unwrap_model( + model, (torchDDP, LocalDDP, Float16Module)) + + unwrap_output_tensor = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_output_tensor = True + + unwrapped_model.set_input_tensor(input_tensor) + output_tensor, loss_func = forward_step_func(data_iterator, model) + if mpu.is_pipeline_last_stage(): + if not collect_non_loss_data: + output_tensor = loss_func(output_tensor) + loss, loss_reduced = output_tensor + output_tensor = loss / get_num_microbatches() + forward_data_store.append(loss_reduced) + else: + data = loss_func(output_tensor, non_loss_data=True) + forward_data_store.append(data) + + timers('forward-compute').stop() + + # If T5 model (or other model with encoder and decoder) + # and in decoder stack, then send encoder_hidden_state + # downstream as well. + if mpu.is_pipeline_stage_after_split() and \ + args.model_type == ModelType.encoder_and_decoder: + return [output_tensor, input_tensor[-1]] + if unwrap_output_tensor: + return output_tensor + return [output_tensor] + + +def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): + """Backward step through passed-in output tensor. + + If last stage, output_tensor_grad is None, otherwise gradient of loss + with respect to stage's output tensor. + + Returns gradient of loss with respect to input tensor (None if first + stage).""" + + # NOTE: This code currently can handle at most one skip connection. It + # needs to be modified slightly to support arbitrary numbers of skip + # connections. + args = get_args() + + timers = get_timers() + timers('backward-compute').start() + + # Retain the grad on the input_tensor. + unwrap_input_tensor_grad = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_input_tensor_grad = True + for x in input_tensor: + if x is not None: + x.retain_grad() + + if not isinstance(output_tensor, list): + output_tensor = [output_tensor] + if not isinstance(output_tensor_grad, list): + output_tensor_grad = [output_tensor_grad] + + # Backward pass. + if output_tensor_grad[0] is None: + output_tensor = optimizer.scale_loss(output_tensor[0]) + custom_backward(output_tensor[0], output_tensor_grad[0]) + + # Collect the grad of the input_tensor. + input_tensor_grad = [None] + if input_tensor is not None: + input_tensor_grad = [] + for x in input_tensor: + if x is None: + input_tensor_grad.append(None) + else: + input_tensor_grad.append(x.grad) + + # Handle single skip connection if it exists (encoder_hidden_state in + # model with encoder and decoder). + if mpu.get_pipeline_model_parallel_world_size() > 1 and \ + mpu.is_pipeline_stage_after_split() and \ + args.model_type == ModelType.encoder_and_decoder: + if output_tensor_grad[1] is not None: + input_tensor_grad[-1].add_(output_tensor_grad[1]) + if unwrap_input_tensor_grad: + input_tensor_grad = input_tensor_grad[0] + + timers('backward-compute').stop() + + return input_tensor_grad + + +@contextmanager +def dummy_handler(): + try: + yield + finally: + pass + + +def forward_backward_no_pipelining(forward_step_func, + data_iterator, model, + optimizer, + timers, + forward_only, + collect_non_loss_data=False): + """Run forward and backward passes with no pipeline parallelism + (no inter-stage communication). + + Returns dictionary with losses.""" + assert len(model) == 1 + model = model[0] + + context_handler = dummy_handler + if isinstance(model, torchDDP): + context_handler = model.no_sync + + forward_data_store = [] + input_tensor, output_tensor_grad = None, None + with context_handler(): + for i in range(get_num_microbatches() - 1): + output_tensor = forward_step(forward_step_func, data_iterator, + model, input_tensor, forward_data_store, + collect_non_loss_data) + if not forward_only: + backward_step(optimizer, input_tensor, output_tensor, + output_tensor_grad) + + # Run computation for last microbatch out of context handler (want to + # synchronize gradients). + output_tensor = forward_step(forward_step_func, data_iterator, + model, input_tensor, forward_data_store, + collect_non_loss_data) + if not forward_only: + backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) + + return forward_data_store + + +def forward_backward_pipelining_with_interleaving(forward_step_func, + data_iterator, model, + optimizer, + timers, + forward_only, + collect_non_loss_data=False): + """Run interleaved 1F1B schedule (model split into model chunks), with + communication between pipeline stages as needed. + + Returns dictionary with losses if the last stage, empty dict otherwise.""" + input_tensors = [[] for _ in range(len(model))] + output_tensors = [[] for _ in range(len(model))] + forward_data_store = [] + if not forward_only: + output_tensor_grads = [[] for _ in range(len(model))] + + pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size() + pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank() + + args = get_args() + if args.sequence_parallel: + seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() + else: + seq_length = args.seq_length + tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size) + + # Compute number of warmup and remaining microbatches. + num_model_chunks = len(model) + num_microbatches = get_num_microbatches() * num_model_chunks + all_warmup_microbatches = False + if forward_only: + num_warmup_microbatches = num_microbatches + else: + # Run all forward passes and then all backward passes if number of + # microbatches is just the number of pipeline stages. + # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # all workers, followed by more microbatches after depending on + # stage ID (more forward passes for earlier stages, later stages can + # immediately start with 1F1B). + if get_num_microbatches() == pipeline_parallel_size: + num_warmup_microbatches = num_microbatches + all_warmup_microbatches = True + else: + num_warmup_microbatches = \ + (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += ( + num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, + num_microbatches) + num_microbatches_remaining = \ + num_microbatches - num_warmup_microbatches + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = (num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def forward_step_helper(microbatch_id): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + # forward step + if mpu.is_pipeline_first_stage(): + if len(input_tensors[model_chunk_id]) == \ + len(output_tensors[model_chunk_id]): + input_tensors[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id][-1] + output_tensor = forward_step(forward_step_func, + data_iterator[model_chunk_id], + model[model_chunk_id], + input_tensor, + forward_data_store, + collect_non_loss_data) + output_tensors[model_chunk_id].append(output_tensor) + + # if forward-only, no need to save tensors for a backward pass + if forward_only: + input_tensors[model_chunk_id].pop() + output_tensors[model_chunk_id].pop() + + return output_tensor + + def backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + if mpu.is_pipeline_last_stage(): + if len(output_tensor_grads[model_chunk_id]) == 0: + output_tensor_grads[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id].pop(0) + output_tensor = output_tensors[model_chunk_id].pop(0) + output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) + input_tensor_grad = \ + backward_step(optimizer, + input_tensor, + output_tensor, + output_tensor_grad) + + return input_tensor_grad + + # Run warmup forward passes. + mpu.set_virtual_pipeline_model_parallel_rank(0) + input_tensors[0].append( + p2p_communication.recv_forward(tensor_shape, timers=timers)) + for k in range(num_warmup_microbatches): + output_tensor = forward_step_helper(k) + + # Determine if tensor should be received from previous stage. + next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True) + recv_prev = True + if mpu.is_pipeline_first_stage(ignore_virtual=True): + if next_forward_model_chunk_id == 0: + recv_prev = False + if k == (num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if mpu.is_pipeline_last_stage(): + output_tensor = None + + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration). + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if mpu.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + input_tensor, output_tensor_grad = \ + p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, input_tensor_grad, + recv_prev=recv_prev, recv_next=recv_next, + tensor_shape=tensor_shape, + timers=timers) + output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + else: + input_tensor = \ + p2p_communication.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev, + tensor_shape=tensor_shape, + timers=timers) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + deallocate_output_tensor(output_tensor) + + # Run 1F1B in steady state. + for k in range(num_microbatches_remaining): + # Forward pass. + forward_k = k + num_warmup_microbatches + output_tensor = forward_step_helper(forward_k) + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + # Send output_tensor and input_tensor_grad, receive input_tensor + # and output_tensor_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + if mpu.is_pipeline_last_stage(): + output_tensor = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + if mpu.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if mpu.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, + forward=True) + + recv_next = True + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, + forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Communicate tensors. + input_tensor, output_tensor_grad = \ + p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, input_tensor_grad, + recv_prev=recv_prev, recv_next=recv_next, + tensor_shape=tensor_shape, timers=timers) + deallocate_output_tensor(output_tensor) + + # Put input_tensor and output_tensor_grad in data structures in the + # right location. + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + output_tensor_grad) + + # Run cooldown backward passes (flush out pipeline). + if not forward_only: + if all_warmup_microbatches: + output_tensor_grads[num_model_chunks-1].append( + p2p_communication.recv_backward(tensor_shape, timers=timers)) + for k in range(num_microbatches_remaining, num_microbatches): + input_tensor_grad = backward_step_helper(k) + next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) + recv_next = True + if mpu.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_model_chunk_id == (num_model_chunks - 1): + recv_next = False + if k == (num_microbatches - 1): + recv_next = False + output_tensor_grads[next_backward_model_chunk_id].append( + p2p_communication.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next, + tensor_shape=tensor_shape, + timers=timers)) + + return forward_data_store + + +def get_tensor_shapes(rank, model_type): + # Determine right tensor sizes (based on position of rank with respect to split + # rank) and model size. + # Send two tensors if model is T5 and rank is in decoder stage: + # first tensor is decoder (pre-transpose), + # second tensor is encoder (post-transpose). + # If model is T5 and rank is at the boundary: + # send one tensor (post-transpose from encoder). + # Otherwise, send one tensor (pre-transpose). + args = get_args() + tensor_shapes = [] + + if args.sequence_parallel: + seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() + else: + seq_length = args.seq_length + + if model_type == ModelType.encoder_and_decoder: + if args.sequence_parallel: + decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size() + else: + decoder_seq_length = args.decoder_seq_length + + if mpu.is_pipeline_stage_before_split(rank): + tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size)) + else: + tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size)) + tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size)) + else: + tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size)) + return tensor_shapes + + +def recv_forward(tensor_shapes, timers): + input_tensors = [] + for tensor_shape in tensor_shapes: + if tensor_shape is None: + input_tensors.append(None) + else: + input_tensors.append(p2p_communication.recv_forward(tensor_shape, + timers=timers)) + return input_tensors + + +def recv_backward(tensor_shapes, timers): + output_tensor_grads = [] + for tensor_shape in tensor_shapes: + if tensor_shape is None: + output_tensor_grads.append(None) + else: + output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, + timers=timers)) + return output_tensor_grads + + +def send_forward(output_tensors, tensor_shapes, timers): + if not isinstance(output_tensors, list): + output_tensors = [output_tensors] + for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): + if tensor_shape is None: + continue + p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers) + + +def send_backward(input_tensor_grads, tensor_shapes, timers): + if not isinstance(input_tensor_grads, list): + input_tensor_grads = [input_tensor_grads] + for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): + if tensor_shape is None: + continue + p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers) + + +def send_forward_recv_backward(output_tensors, tensor_shapes, timers): + if not isinstance(output_tensors, list): + output_tensors = [output_tensors] + output_tensor_grads = [] + for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): + if tensor_shape is None: + output_tensor_grads.append(None) + continue + output_tensor_grad = p2p_communication.send_forward_recv_backward( + output_tensor, tensor_shape, timers=timers) + output_tensor_grads.append(output_tensor_grad) + return output_tensor_grads + + +def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers): + if not isinstance(input_tensor_grads, list): + input_tensor_grads = [input_tensor_grads] + input_tensors = [] + for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): + if tensor_shape is None: + input_tensors.append(None) + continue + input_tensor = p2p_communication.send_backward_recv_forward( + input_tensor_grad, tensor_shape, timers=timers) + input_tensors.append(input_tensor) + return input_tensors + + +def forward_backward_pipelining_without_interleaving(forward_step_func, + data_iterator, + model, + optimizer, + timers, + forward_only, + collect_non_loss_data=False): + """Run non-interleaved 1F1B schedule, with communication between pipeline + stages. + + Returns dictionary with losses if the last stage, empty dict otherwise.""" + args = get_args() + timers = get_timers() + + assert len(model) == 1 + model = model[0] + + # Compute number of warmup microbatches. + num_microbatches = get_num_microbatches() + num_warmup_microbatches = \ + (mpu.get_pipeline_model_parallel_world_size() - + mpu.get_pipeline_model_parallel_rank() - 1) + num_warmup_microbatches = min( + num_warmup_microbatches, + num_microbatches) + num_microbatches_remaining = \ + num_microbatches - num_warmup_microbatches + + unwrapped_model = unwrap_model( + model, (torchDDP, LocalDDP, Float16Module)) + model_type = unwrapped_model.model_type + rank = mpu.get_pipeline_model_parallel_rank() + recv_tensor_shapes = get_tensor_shapes(rank-1, model_type) + send_tensor_shapes = get_tensor_shapes(rank, model_type) + + # Input, output tensors only need to be saved when doing backward passes + input_tensors = None + output_tensors = None + if not forward_only: + input_tensors = [] + output_tensors = [] + forward_data_store = [] + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + input_tensor = recv_forward(recv_tensor_shapes, timers=timers) + output_tensor = forward_step(forward_step_func, data_iterator, model, + input_tensor, forward_data_store, + collect_non_loss_data) + send_forward(output_tensor, send_tensor_shapes, timers=timers) + + if not forward_only: + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + deallocate_output_tensor(output_tensor[0]) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + input_tensor = recv_forward(recv_tensor_shapes, timers=timers) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_tensor = forward_step(forward_step_func, data_iterator, model, + input_tensor, forward_data_store, + collect_non_loss_data) + if forward_only: + send_forward(output_tensor, send_tensor_shapes, timers=timers) + + if not last_iteration: + input_tensor = recv_forward(recv_tensor_shapes, timers=timers) + + else: + output_tensor_grad = \ + send_forward_recv_backward(output_tensor, + send_tensor_shapes, + timers=timers) + + # Add input_tensor and output_tensor to end of list. + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + deallocate_output_tensor(output_tensor[0]) + + # Pop input_tensor and output_tensor from the start of the list for + # the backward pass. + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + + input_tensor_grad = \ + backward_step(optimizer, input_tensor, output_tensor, + output_tensor_grad) + + if last_iteration: + input_tensor = None + send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers) + else: + input_tensor = \ + send_backward_recv_forward( + input_tensor_grad, recv_tensor_shapes, timers=timers) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + + output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers) + + input_tensor_grad = \ + backward_step(optimizer, input_tensor, output_tensor, + output_tensor_grad) + + send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers) + + return forward_data_store diff --git a/megatron/static/index.html b/megatron/static/index.html new file mode 100644 index 0000000000000000000000000000000000000000..590ae89db413292a8c4535fbb11c9c0a7d83069a --- /dev/null +++ b/megatron/static/index.html @@ -0,0 +1,136 @@ + + + + + + + + + + + + + + + + + + + +Megatron + + + +
+

Prompt Megatron

+ + + + + +
+0 +/ 1000 +
+ +
+ + + + + diff --git a/megatron/text_generation/__init__.py b/megatron/text_generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d145e9ba1119135e93497f88a45846a3f3a4d792 --- /dev/null +++ b/megatron/text_generation/__init__.py @@ -0,0 +1,20 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .api import ( + generate, + generate_and_post_process, + beam_search_and_post_process) diff --git a/megatron/text_generation/api.py b/megatron/text_generation/api.py new file mode 100644 index 0000000000000000000000000000000000000000..9362ea3a33796da6ce37467b277d58c9c6a6e684 --- /dev/null +++ b/megatron/text_generation/api.py @@ -0,0 +1,202 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference API.""" + + +import torch + +from megatron import mpu +from .communication import broadcast_float_list +from .generation import ( + generate_tokens_probs_and_return_on_first_stage, + score_and_return_on_first_stage, + beam_search_and_return_on_first_stage) +from .tokenization import ( + tokenize_prompts, + detokenize_generations) + +def generate_and_post_process(model, + prompts=None, + tokens_to_generate=0, + return_output_log_probs=False, + top_k_sampling=0, + top_p_sampling=0.0, + top_p_decay=0.0, + top_p_bound=0.0, + temperature=1.0, + add_BOS=False, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False, + random_seed=-1): + """Run inference and post-process outputs, i.e., detokenize, + move to cpu and convert to list.""" + + # Main inference. + tokens, lengths, output_log_probs = generate( + model, + prompts=prompts, + tokens_to_generate=tokens_to_generate, + return_output_log_probs=return_output_log_probs, + top_k_sampling=top_k_sampling, + top_p_sampling=top_p_sampling, + top_p_decay=top_p_decay, + top_p_bound=top_p_bound, + temperature=temperature, + add_BOS=add_BOS, + use_eod_token_for_early_termination=use_eod_token_for_early_termination, + stop_on_double_eol=stop_on_double_eol, + stop_on_eol=stop_on_eol, + random_seed=random_seed) + + # Only post-process on first stage. + if mpu.is_pipeline_first_stage(): + tokens, prompts_plus_generations, prompts_plus_generations_segments = \ + detokenize_generations(tokens, lengths, True) + + if return_output_log_probs: + output_log_probs = output_log_probs.cpu().numpy().tolist() + for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)): + output_log_probs[i] = prob[:len(seg)-1] + + return prompts_plus_generations, prompts_plus_generations_segments, \ + output_log_probs, tokens + + return None + +def generate(model, + prompts=None, + tokens_to_generate=0, + return_output_log_probs=False, + top_k_sampling=0, + top_p_sampling=0.0, + top_p_decay=0.0, + top_p_bound=0.0, + temperature=1.0, + add_BOS=False, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False, + random_seed=-1): + """Given prompts and input parameters, run inference and return: + tokens: prompts plus the generated tokens. + lengths: length of the prompt + generations. Note that we can + discard tokens in the tokens tensor that are after the + corresponding length. + output_log_probs: log probs of the tokens. + """ + + # Make sure input params are avaialble to all ranks. + values = [tokens_to_generate, + return_output_log_probs, + top_k_sampling, top_p_sampling, top_p_decay, top_p_bound, + temperature, add_BOS, use_eod_token_for_early_termination, + stop_on_double_eol, + stop_on_eol, + random_seed] + values_float_tensor = broadcast_float_list(12, float_list=values) + tokens_to_generate = int(values_float_tensor[0].item()) + return_output_log_probs = bool(values_float_tensor[1].item()) + top_k_sampling = int(values_float_tensor[2].item()) + top_p_sampling = values_float_tensor[3].item() + top_p_decay = values_float_tensor[4].item() + top_p_bound = values_float_tensor[5].item() + temperature = values_float_tensor[6].item() + add_BOS = bool(values_float_tensor[7].item()) + use_eod_token_for_early_termination = bool(values_float_tensor[8].item()) + stop_on_double_eol = bool(values_float_tensor[9].item()) + stop_on_eol = bool(values_float_tensor[10].item()) + random_seed = int(values_float_tensor[11].item()) + + if random_seed != -1: + torch.random.manual_seed(random_seed) + + # Tokenize prompts and get the batch. + # Note that these tensors are broadcaseted to all ranks. + if torch.distributed.get_rank() == 0: + assert prompts is not None + + context_tokens_tensor, context_length_tensor = tokenize_prompts( + prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS) + + if tokens_to_generate == 0: + return score_and_return_on_first_stage( + model, context_tokens_tensor, context_length_tensor) + + # Main inference function. + # Note that the outputs are available on the first stage. + return generate_tokens_probs_and_return_on_first_stage( + model, context_tokens_tensor, context_length_tensor, + return_output_log_probs=return_output_log_probs, + top_k=top_k_sampling, + top_p=top_p_sampling, + top_p_decay=top_p_decay, + top_p_bound=top_p_bound, + temperature=temperature, + use_eod_token_for_early_termination=use_eod_token_for_early_termination, + stop_on_double_eol=stop_on_double_eol, + stop_on_eol=stop_on_eol) + +def beam_search_and_post_process(model, + prompts=None, + tokens_to_generate=0, + beam_size=0, + add_BOS=False, + stop_token=50256, + num_return_gen=1, + length_penalty=1): + """Run beam search and post-process outputs, i.e., detokenize, + move to cpu and convert to list.""" + + # Main inference. + tokens, scores = beam_search(model, + prompts=prompts, + tokens_to_generate=tokens_to_generate, + beam_size=beam_size, + add_BOS=add_BOS, + stop_token=stop_token, + num_return_gen=num_return_gen, + length_penalty=length_penalty) + # Only post-process on first stage. + if mpu.is_pipeline_first_stage(): + lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device()) + tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True) + scores = scores.cpu().numpy().tolist() + return prompts_plus_generations, prompts_plus_generations_segments, scores + + return None + +def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1): + # Make sure input params are avaialble to all ranks. + values = [tokens_to_generate, + beam_size, + add_BOS, + stop_token, + num_return_gen, + length_penalty] + values_float_tensor = broadcast_float_list(6, float_list=values) + tokens_to_generate = int(values_float_tensor[0].item()) + beam_size = int(values_float_tensor[1].item()) + add_BOS = bool(values_float_tensor[2].item()) + stop_token = int(values_float_tensor[3].item()) + num_return_gen = int(values_float_tensor[4].item()) + length_penalty = values_float_tensor[5].item() + + context_tokens_tensor, context_length_tensor = tokenize_prompts( + prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS) + + return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, + beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty) diff --git a/megatron/text_generation/beam_utils.py b/megatron/text_generation/beam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..911a64143a86c8521abd9741df22de528a82f692 --- /dev/null +++ b/megatron/text_generation/beam_utils.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +## from huggingface beam search +class BeamHypotheses(object): + def __init__(self, num_beams, length_penalty=1.0, early_stopping=False): + """ + Initialize n-best list of hypotheses. + """ + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, hyp, sum_logprobs, length): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / length ** self.length_penalty + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp)) + if len(self) > self.num_beams: + sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) + del self.beams[sorted_scores[0][1]] + self.worst_score = sorted_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs, cur_len): + """ + If there are enough hypotheses and that none of the hypotheses being generated + can become better than the worst one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self.worst_score >= cur_score + return ret + diff --git a/megatron/text_generation/communication.py b/megatron/text_generation/communication.py new file mode 100644 index 0000000000000000000000000000000000000000..198ca1406501c623f956c5958a51967071938178 --- /dev/null +++ b/megatron/text_generation/communication.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Communications utilities.""" + + +import torch + +from megatron import mpu + + + +# TODO: use functions from megatron/p2p +def recv_from_prev_pipeline_rank_(recv_buffer=None): + """Receive from previous pipeline stage and update the + input buffer inplace.""" + if not mpu.is_pipeline_first_stage(): + assert recv_buffer is not None + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_buffer, + mpu.get_pipeline_model_parallel_prev_rank()) + reqs = torch.distributed.batch_isend_irecv([recv_prev_op]) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + + +# TODO: use functions from megatron/p2p +def send_to_next_pipeline_rank(tensor=None): + """Send output to the next pipeline stage.""" + if not mpu.is_pipeline_last_stage(): + assert tensor is not None + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, + mpu.get_pipeline_model_parallel_next_rank()) + reqs = torch.distributed.batch_isend_irecv([send_next_op]) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + + +def _is_cuda(tensor): + """Check if a tensor is not none and is cuda.""" + assert tensor is not None + assert tensor.is_cuda + + + +def _is_cuda_contiguous(tensor): + """Check if a tensor is not none, is cuda, and is contiguous.""" + _is_cuda(tensor) + assert tensor.is_contiguous() + + + +def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): + """Broadcast a tensor from last pipeline stage to all ranks.""" + + is_last_stage = mpu.is_pipeline_last_stage() + # If first stage and last state are the same, then there is no + # pipeline parallelism and no need to communicate. + if mpu.is_pipeline_first_stage() and is_last_stage: + return tensor + + if is_last_stage: + _is_cuda_contiguous(tensor) + else: + tensor = torch.empty(size, + dtype=dtype, + device=torch.cuda.current_device()) + # Get the group and corresponding source rank. + src = mpu.get_pipeline_model_parallel_last_rank() + group = mpu.get_pipeline_model_parallel_group() + torch.distributed.broadcast(tensor, src, group) + + return tensor + + + +def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): + """Broadcast tensor values from last stage into the first stage.""" + + is_last_stage = mpu.is_pipeline_last_stage() + is_first_stage = mpu.is_pipeline_first_stage() + # If first stage and last state are the same, then there is no + # pipeline parallelism and no need to communicate. + if is_first_stage and is_last_stage: + return tensor + # Only first and last stage pipeline stages need to be involved. + if is_last_stage or is_first_stage: + if is_last_stage: + _is_cuda_contiguous(tensor) + else: + tensor = torch.empty(size, + dtype=dtype, + device=torch.cuda.current_device()) + src = mpu.get_pipeline_model_parallel_last_rank() + group = mpu.get_embedding_group() + # Broadcast from last stage into the first stage. + torch.distributed.broadcast(tensor, src, group) + else: + tensor = None + + return tensor + + + +def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): + """Copy tensor values from last stage into the first stage. + Note that the input tensor is updated in place.""" + + is_last_stage = mpu.is_pipeline_last_stage() + is_first_stage = mpu.is_pipeline_first_stage() + # If first stage and last state are the same, then there is no + # pipeline parallelism and no need to communicate. + if is_first_stage and is_last_stage: + return + # Only first and last stage pipeline stages need to be involved. + if is_last_stage or is_first_stage: + _is_cuda(tensor) + is_contiguous = tensor.is_contiguous() + src = mpu.get_pipeline_model_parallel_last_rank() + group = mpu.get_embedding_group() + if is_contiguous: + tensor_ = tensor + else: + if is_last_stage: + tensor_ = tensor.contiguous() + else: + tensor_ = torch.empty(size, + dtype=dtype, + device=torch.cuda.current_device()) + # Broadcast from last stage into the first stage. + torch.distributed.broadcast(tensor_, src, group) + # Update the first stage tensor + if is_first_stage and not is_contiguous: + tensor[...] = tensor_ + + + +def broadcast_tensor(size, dtype, tensor=None, rank=0): + """ Given size and type of a tensor on all ranks and the tensor value + only on a specific rank, broadcast from that rank to all other ranks. + """ + + if torch.distributed.get_rank() == rank: + _is_cuda_contiguous(tensor) + else: + tensor = torch.empty(size, + dtype=dtype, + device=torch.cuda.current_device()) + + torch.distributed.broadcast(tensor, rank) + + return tensor + + + +def broadcast_list(size, dtype, list_values=None, rank=0): + """Broadcast a list of values with a given type.""" + + tensor = None + if torch.distributed.get_rank() == rank: + tensor = torch.tensor(list_values, dtype=dtype, + device=torch.cuda.current_device()) + + return broadcast_tensor(size, dtype, tensor=tensor, rank=rank) + + + +def broadcast_int_list(size, int_list=None, rank=0): + """Broadcast a list of interger values.""" + + return broadcast_list(size, torch.int64, list_values=int_list, rank=rank) + + + +def broadcast_float_list(size, float_list=None, rank=0): + """Broadcast a list of float values.""" + + return broadcast_list(size, torch.float32, list_values=float_list, + rank=rank) diff --git a/megatron/text_generation/forward_step.py b/megatron/text_generation/forward_step.py new file mode 100644 index 0000000000000000000000000000000000000000..763081dadada03442f80aaaae030074205900d4e --- /dev/null +++ b/megatron/text_generation/forward_step.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Forward step utilities.""" + +from collections.abc import Iterable + +import torch + +from megatron import ( + get_args, + mpu) +from .communication import ( + send_to_next_pipeline_rank, + recv_from_prev_pipeline_rank_) + + + +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + def __init__(self, max_batch_size, max_sequence_len): + """Note that offsets are set to zero and we always set the + flag to allocate memory. After the first call, make sure to + set this flag to False.""" + self.max_sequence_len = max_sequence_len + self.max_batch_size = max_batch_size + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.key_value_memory_dict = {} + + def swap_key_value_dict(self, batch_idx): + "swap between batches" + if len(self.key_value_memory_dict) == 0: + raise ValueError("should not swap when dict in empty") + + for layer_number in self.key_value_memory_dict.keys(): + inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] + assert len(batch_idx) == inference_key_memory.shape[1] ## make sure batch size is the same + new_inference_key_memory = inference_key_memory[:, batch_idx] + new_inference_value_memory = inference_value_memory[:, batch_idx] + self.key_value_memory_dict[layer_number] = ( + new_inference_key_memory, new_inference_value_memory) + +class ForwardStep: + """Forward step function with all the communications. + We use a class here to hide the inference parameters + from the outside caller.""" + + def __init__(self, model, max_batch_size, max_sequence_len): + """Set values so we don't need to do it multiple times.""" + # Make sure model is in eval mode. + assert not isinstance(model, Iterable), \ + 'interleaving schedule is not supported for inference' + model.eval() + self.model = model + # Initialize inference parameters. + self.inference_params = InferenceParams(max_batch_size, + max_sequence_len) + # Pipelining arguments. + args = get_args() + self.pipeline_size_larger_than_one = ( + args.pipeline_model_parallel_size > 1) + # Threshold of pipelining. + self.pipelining_batch_x_seqlen = \ + args.inference_batch_times_seqlen_threshold + + + def __call__(self, tokens, position_ids, attention_mask): + """Invocation of the forward methods. Note that self.inference_params + is being modified by the forward step.""" + # Pipelining case. + if self.pipeline_size_larger_than_one: + current_batch_x_seqlen = tokens.size(0) * tokens.size(1) + if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen: + micro_batch_size = \ + max(1, self.pipelining_batch_x_seqlen // tokens.size(1)) + return _with_pipelining_forward_step(self.model, + tokens, + position_ids, + attention_mask, + self.inference_params, + micro_batch_size) + + return _no_pipelining_forward_step(self.model, + tokens, + position_ids, + attention_mask, + self.inference_params) + + + +def _get_recv_buffer_dtype(args): + """Receive happens between the layers.""" + if args.fp32_residual_connection: + return torch.float + return args.params_dtype + + + +def _allocate_recv_buffer(batch_size, sequence_length): + """Receive happens between the layers with size [s, b, h].""" + if mpu.is_pipeline_first_stage(): + return None + args = get_args() + recv_size = (sequence_length, batch_size, args.hidden_size) + return torch.empty(recv_size, + dtype=_get_recv_buffer_dtype(args), + device=torch.cuda.current_device()) + + + +def _forward_step_helper(model, tokens, position_ids, attention_mask, + inference_params, recv_buffer=None): + """Single forward step. Update the allocate memory flag so + only the first time the memory is allocated.""" + batch_size = tokens.size(0) + sequence_length = tokens.size(1) + if recv_buffer is None: + recv_buffer = _allocate_recv_buffer(batch_size, sequence_length) + + # Receive from previous stage. + recv_from_prev_pipeline_rank_(recv_buffer) + + # Forward pass through the model. + model.set_input_tensor(recv_buffer) + output_tensor = model(tokens, position_ids, attention_mask, + inference_params=inference_params) + + # Send output to the next stage. + send_to_next_pipeline_rank(output_tensor) + + return output_tensor + + + +def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask, + inference_params, recv_buffer=None): + """If recv_buffer is none, we will allocate one on the fly.""" + # Run a simple forward pass. + output_tensor = _forward_step_helper(model, tokens, position_ids, + attention_mask, inference_params, + recv_buffer=recv_buffer) + # Update the sequence length offset. + inference_params.sequence_len_offset += tokens.size(1) + + logits = None + if mpu.is_pipeline_last_stage(): + logits = output_tensor + + return logits + + + +def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask, + inference_params, micro_batch_size): + """No interleaving is supported.""" + sequence_length = tokens.size(1) + batch_size = tokens.size(0) + + # Divide the batch dimension into micro batches. + num_micro_batches, last_chunk = divmod(batch_size, + micro_batch_size) + if last_chunk > 0: + num_micro_batches += 1 + + # Preallocate memory for output logits. + logits = None + if mpu.is_pipeline_last_stage(): + args = get_args() + logits = torch.empty( + (batch_size, sequence_length, args.padded_vocab_size), + dtype=torch.float32, device=torch.cuda.current_device()) + + # Preallocate recv buffer. + recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length) + + for micro_batch_index in range(num_micro_batches): + # Slice among the batch dimenion. + start = micro_batch_index * micro_batch_size + end = min(start + micro_batch_size, batch_size) + this_micro_batch_size = end - start + tokens2use = tokens[start:end, ...] + position_ids2use = position_ids[start:end, ...] + + # Run a simple forward pass. + if this_micro_batch_size != micro_batch_size: + recv_buffer = None + output = _forward_step_helper(model, tokens2use, position_ids2use, + attention_mask, inference_params, + recv_buffer=recv_buffer) + + # Adjust the batch size offset to account for the micro-batch. + inference_params.batch_size_offset += this_micro_batch_size + + # Copy logits. + if mpu.is_pipeline_last_stage(): + logits[start:end, ...] = output + + # Once we are done with all the micro-batches, we can + # adjust the sequence length offset. + inference_params.sequence_len_offset += sequence_length + # and reset the batch size offset + inference_params.batch_size_offset = 0 + + return logits diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..2b81e89d0a68a92da1cfdc5731a4bd96d368e949 --- /dev/null +++ b/megatron/text_generation/generation.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generation utilities.""" + +import torch +import torch.nn.functional as F + +from megatron import get_args, get_tokenizer, mpu +from megatron.utils import get_ltor_masks_and_position_ids +from .communication import ( + copy_from_last_to_first_pipeline_stage, + broadcast_from_last_pipeline_stage, + broadcast_from_last_to_first_pipeline_stage) +from .forward_step import ForwardStep +from .sampling import sample +from .beam_utils import BeamHypotheses + +MAX_TOKENS_TO_OOM = 12000 # (rprenger) Perfect value depends on hardware and network + +def score_and_return_on_first_stage(model, tokens, lengths): + """Function for just scoring. + Arguments: + model: no interleaving is supported. + tokens: prompt tokens extended to be of size [b, max_prompt_length] + lengths: original prompt length, size: [b] + Note: Outside of model, other parameters only need to be available on + rank 0. + Outputs: + output_log_probs: log probability of the selected tokens. size: [b, s] + """ + + args = get_args() + + batch_size = tokens.size(0) + max_prompt_length = lengths.max().item() + assert max_prompt_length == tokens.size(1) + max_sequence_length = min(max_prompt_length, args.max_position_embeddings) + + # forward step. + forward_step = ForwardStep(model, batch_size, max_sequence_length) + + # =================== + # Pre-allocate memory + # =================== + + # Log probability of the sequence (prompt + generated tokens). + output_log_probs = None + output_log_probs_size = (batch_size, max_sequence_length - 1) + + if mpu.is_pipeline_last_stage(): + output_log_probs = torch.empty(output_log_probs_size, + dtype=torch.float32, + device=torch.cuda.current_device()) + + # ============= + # Run infernece + # ============= + with torch.no_grad(): + attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens) + + # logits will be meanigful only in the last pipeline stage. + logits = forward_step(tokens, position_ids, attention_mask) + + if mpu.is_pipeline_last_stage(): + # Always the last stage should have an output. + assert logits is not None + log_probs = F.log_softmax(logits, dim=2) + + # Pick the tokens that we need to get the log + # probabilities for. Note that next input token is + # the token which we selected in the current logits, + # so shift by 1. + indices = torch.unsqueeze(tokens[:, 1:], 2) + output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2) + + # ====================================== + # Broadcast to the first pipeline stage. + # ====================================== + output_log_probs = broadcast_from_last_to_first_pipeline_stage( + output_log_probs_size, torch.float32, output_log_probs) + + return tokens, lengths, output_log_probs + +def generate_tokens_probs_and_return_on_first_stage( + model, tokens, lengths, + return_output_log_probs=False, + top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0, + temperature=1.0, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False + ): + """Main token generation function. + Arguments: + model: no interleaving is supported. + tokens: prompt tokens extended to be of size [b, max-sequence-length] + lengths: original prompt length, size: [b] + return_output_log_probs: flag to calculate the log probability of + the generated tokens. Note that the log probability is the one + from the original logit. + top_k, top_p: top-k and top-p sampling parameters. + Note that top-k = 1 is gready. Also, these paramters are + exclusive meaning that: + if top-k > 0 then we expect top-p=0. + if top-p > 0 then we check for top-k=0. + temperature: sampling temperature. + use_eod_token_for_early_termination: if True, do early termination if + all the sequences have reached this token. + Note: Outside of model, other parameters only need to be available on + rank 0. + Outputs: Note that is size is adjusted to a lower value than + max-sequence-length if generation is terminated early. + tokens: prompt and generated tokens. size: [b, :] + generated_sequence_lengths: total length (including prompt) of + the generated sequence. size: [b] + output_log_probs: log probability of the selected tokens. size: [b, s] + """ + + args = get_args() + tokenizer = get_tokenizer() + + batch_size = tokens.size(0) + min_prompt_length = lengths.min().item() + max_sequence_length = tokens.size(1) + + if max_sequence_length > args.max_position_embeddings: + raise ValueError("Length of prompt + tokens_to_generate longer than allowed") + + if max_sequence_length * batch_size >= MAX_TOKENS_TO_OOM: + raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(MAX_TOKENS_TO_OOM)) + + # forward step. + forward_step = ForwardStep(model, batch_size, max_sequence_length) + + # Added termination_id to support the case that we want to terminate the + # generation once that id is generated. + if hasattr(args, 'eos_id'): + termination_id = args.eos_id + else: + termination_id = tokenizer.eod + + # =================== + # Pre-allocate memory + # =================== + + # Log probability of the sequence (prompt + generated tokens). + output_log_probs = None + output_log_probs_size = (batch_size, max_sequence_length - 1) + # Lengths of generated seuquence including including prompts. + generated_sequence_lengths = None + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = torch.empty(output_log_probs_size, + dtype=torch.float32, + device=torch.cuda.current_device()) + generated_sequence_lengths = torch.ones( + batch_size, dtype=torch.int64, + device=torch.cuda.current_device()) * max_sequence_length + + # Whether we have reached a termination id. + is_generation_done = torch.zeros(batch_size, dtype=torch.uint8, + device=torch.cuda.current_device()) + + # ============= + # Run infernece + # ============= + + with torch.no_grad(): + attention_mask, position_ids = _build_attention_mask_and_position_ids( + tokens) + prev_context_length = 0 + for context_length in range(min_prompt_length, max_sequence_length): + + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length:context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. + logits = forward_step(tokens2use, positions2use, attention_mask2use) + + if mpu.is_pipeline_last_stage(): + # Always the last stage should have an output. + assert logits is not None + + # Sample. + last_token_logits = logits[:, -1, :] + new_sample = sample(last_token_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + vocab_size=tokenizer.vocab_size) + if top_p > 0.0 and top_p_decay > 0.0: + top_p = top_p * top_p_decay + if top_p_bound > 0.0: + top_p = max(top_p, top_p_bound) + + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = lengths <= context_length + # Update the tokens. + tokens[started, context_length] = new_sample[started] + + # Calculate the log probabilities. + if return_output_log_probs: + log_probs = F.log_softmax(logits, dim=2) + if return_output_log_probs: + # Pick the tokens that we need to get the log + # probabilities for. Note that next input token is + # the token which we selected in the current logits, + # so shift by 1. + indices = torch.unsqueeze( + tokens[ + :, + (prev_context_length + 1):(context_length + 1)], + 2) + output_log_probs[:, + prev_context_length:context_length] = \ + torch.gather(log_probs, 2, indices).squeeze(2) + + # Update the tokens on the first stage so the next input to + # the network is correct. + copy_from_last_to_first_pipeline_stage(batch_size, torch.int64, + tokens[:, context_length]) + + # Update the context length for the next token generation. + prev_context_length = context_length + + # Check if all the sequences have hit the termination_id. + done = None + if mpu.is_pipeline_last_stage(): + # TODO(rprenger) These stopping methods are tokenizer dependent + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + just_finished = (done_token & ~is_generation_done).bool() + generated_sequence_lengths[just_finished.view(-1)] = \ + context_length + 1 + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + done = broadcast_from_last_pipeline_stage(1, torch.uint8, + tensor=done) + if use_eod_token_for_early_termination and done: + break + + # =================================================== + # Update the length of based on max generated length. + # =================================================== + + tokens = tokens[:, :(context_length + 1)] + if mpu.is_pipeline_last_stage(): + if return_output_log_probs: + output_log_probs = output_log_probs[:, :context_length] + + # ====================================== + # Broadcast to the first pipeline stage. + # ====================================== + + generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage( + batch_size, torch.int64, generated_sequence_lengths) + if return_output_log_probs: + output_log_probs_size = (batch_size, context_length) + output_log_probs = broadcast_from_last_to_first_pipeline_stage( + output_log_probs_size, torch.float32, output_log_probs) + + return tokens, generated_sequence_lengths, output_log_probs + +def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty): + args = get_args() + tokenizer = get_tokenizer() + + batch_size = tokens.size(0) + assert(batch_size == 1) + prompt_length = lengths.item() + final_sequence_length = tokens.size(1) + final_sequence_length = min(final_sequence_length, args.max_position_embeddings) + + # If the context is too big, this happens + if prompt_length >= final_sequence_length: + raise ValueError("context length + tokens_to_generate too large") + + # forward step. + forward_step = ForwardStep(model, beam_size, final_sequence_length) + + beam_hyp = BeamHypotheses(beam_size, length_penalty) + best_batches = None + done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device()) + scores = torch.zeros(beam_size, + dtype=torch.float32, + device=torch.cuda.current_device()).unsqueeze(1) + scores_size_tensor, tokens_size_tensor = None, None + # ============= + # Run infernece + # ============= + with torch.no_grad(): + tokens = tokens.repeat(beam_size, 1) + attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens) + prev_context_length = 0 + for context_length in range(prompt_length, final_sequence_length): + + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length:context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. + logits = forward_step(tokens2use, positions2use, attention_mask2use) + + if mpu.is_pipeline_last_stage(): + vocab_size = logits.size(2) + log_probs = F.log_softmax(logits, dim=2) + new_scores = log_probs[:, -1, :] + scores + + if context_length == prompt_length: # if this is the first one + sorted_scores, indices = torch.sort(new_scores[0,:], descending=True) + else: + sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True) + + best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long() + best_words = indices[:2 * beam_size] % vocab_size + best_scores = sorted_scores[: 2 * beam_size] + + next_beams = [] + for beam_token_rank, (token_id, beam_score, beam_id) in enumerate( + zip(best_words, best_scores, best_beam_ids) + ): + if token_id.item() == stop_token: + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add( + tokens[beam_id].clone(), + beam_score, + context_length + 1 - prompt_length + ) + else: + # add next predicted token since it is not eos_token + next_beams.append((token_id, beam_score, beam_id)) + + if len(next_beams) == beam_size: + break + + if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length): + done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device()) + + best_batches = tokens.new([item[2] for item in next_beams]) + tokens = tokens[best_batches,:] + tokens[:, context_length] = tokens.new([item[0] for item in next_beams]) + scores = scores.new([item[1] for item in next_beams]).unsqueeze(1) + + # torch.distributed.barrier() + done = broadcast_from_last_pipeline_stage(1, torch.uint8, done) + if done: + break + + # Update the tokens on the first stage so the next input to + # the network is correct. + copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64, + tokens) + + # set inference key values to make it consistent with best beam index + best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches) + forward_step.inference_params.swap_key_value_dict(best_batches) + + # Update the context length for the next token generation. + prev_context_length = context_length + + if mpu.is_pipeline_last_stage(): + # if cannot find stop token, add open beams to hyps + if not done: + for beam_id in range(beam_size): + beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length) + + # rank based on scores + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True) + num_return_gen = min(num_return_gen, len(sorted_hyps)) + scores = [sorted_hyps[i][0] for i in range(num_return_gen)] + tokens = [sorted_hyps[i][1] for i in range(num_return_gen)] + scores = torch.stack(scores, dim=0) + tokens = torch.stack(tokens, dim=0) + scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device()) + tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device()) + + scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor) + tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor) + + scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores) + tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens) + + return tokens, scores + + +def _build_attention_mask_and_position_ids(tokens): + """Build the attention mask and postition ids for the input tokens.""" + + # Since we are not interested in loss-mask and reset attention/position + # is also False, eod_token is not used so it is safe to set it to None. + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + data=tokens, + eod_token=None, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False) + + return attention_mask, position_ids diff --git a/megatron/text_generation/sampling.py b/megatron/text_generation/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..4809ae3fc57330743772c3243679a25c0d870eda --- /dev/null +++ b/megatron/text_generation/sampling.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sampling utilities. +Part of this code is inspired by: + - https://github.com/ari-holtzman/degen/blob/master/gen.py + - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html +""" + + +import torch + + + +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float('-Inf')) + + + +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Filteration based on the cumulative sum. + filter_ = cumulative_probs > top_p + # This shift by 1 is weird and I cannot justify it. This existed + # in the original implementation: + # https://github.com/ari-holtzman/degen/blob/master/gen.py + # and I guess it is needed so keeping it for now. + filter_[:, 1:] = filter_[:, :-1].clone() + # Make sure we at least have one token to select from. + filter_[..., 0] = 0 + + # Fill in the filtered part + filter_ = filter_.scatter(1, sorted_indices, filter_) + logits.masked_fill_(filter_, float('-Inf')) + + + +def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): + """ Sample and generate a token. + Note: logits has the dimension [b, v] where b is the batch size + and v is the vocabulary size. + If vocab_size is provided, we will make sure the sample that is + generated is in [0, vocab-size). This will avoid out of vocabulary + generations due to padding. + """ + + # Check logits for consistency. + assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' + assert logits.type() == 'torch.cuda.FloatTensor', \ + 'input logits should be floats.' + + + # Greedy is just simple argmax. + if top_k == 1: + assert top_p == 0.0, 'cannot set both greedy and top-p samplings.' + samples = torch.argmax(logits, dim=-1) + + # Top-k or top-p sampling. + else: + # Clone so we do not modify the inputs, + logits = logits.clone() + # Apply temperature in place. + if temperature != 1.0: + logits.div_(temperature) + + if top_k > 1: + assert top_p == 0.0, 'cannot set both top-k and top-p samplings.' + assert top_k <= logits.size(1), 'top-k is larger than logit size.' + if vocab_size: + assert top_k < vocab_size, 'top-k is larger than vocab size.' + modify_logits_for_top_k_filtering(logits, top_k) + + elif top_p > 0.0: + assert top_p <= 1.0, 'top-p should be in (0, 1].' + modify_logits_for_top_p_filtering(logits, top_p) + + # After filtering, we need to recalculate the distribution. + probs = logits.softmax(dim=-1) + samples = torch.multinomial(probs, num_samples=1).view(-1) + + # If vocab size is provided, make sure the samples are in + # in the range [0, vocab-size). + if vocab_size: + samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) + + return samples diff --git a/megatron/text_generation/tokenization.py b/megatron/text_generation/tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..623a96c619023406179e38062c6f4bcfeacfdcf2 --- /dev/null +++ b/megatron/text_generation/tokenization.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization utilities.""" + + +import torch + + +from megatron import get_tokenizer +from .communication import broadcast_int_list, broadcast_tensor + + +def detokenize_generations(tokens_gpu_tensor, + lengths_gpu_tensor, + return_segments): + """Detokenize the generated tokens.""" + + tokenizer = get_tokenizer() + + prompts_plus_generations = [] + if return_segments: + prompts_plus_generations_segments = [] + + tokens = tokens_gpu_tensor.cpu().numpy().tolist() + lengths = lengths_gpu_tensor.cpu().numpy().tolist() + for sequence_tokens, length in zip(tokens, lengths): + sequence_tokens = sequence_tokens[:length] + prompts_plus_generations.append( + tokenizer.detokenize(sequence_tokens)) + if return_segments: + words = [] + for token in sequence_tokens: + word = tokenizer.tokenizer.decoder[token] + word = bytearray( + [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode( + 'utf-8', errors='replace') + words.append(word) + prompts_plus_generations_segments.append(words) + + if return_segments: + return tokens, prompts_plus_generations, \ + prompts_plus_generations_segments + + return tokens, prompts_plus_generations + + +def tokenize_prompts(prompts=None, tokens_to_generate=None, + add_BOS=None, rank=0): + """Tokenize prompts and make them avaiable on all ranks.""" + + # On all ranks set to None so we can pass them to functions + sizes_list = None + prompts_tokens_cuda_long_tensor = None + prompts_length_cuda_long_tensor = None + + # On the specified rank, build the above. + if torch.distributed.get_rank() == rank: + assert prompts is not None + assert tokens_to_generate is not None + # Tensor of tokens padded and their unpadded length. + prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \ + _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS) + # We need the sizes of these tensors for the boradcast + sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size + prompts_tokens_cuda_long_tensor.size(1)] # Sequence lenght + + # First, broadcast the sizes. + sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank) + + # Now that we have the sizes, we can boradcast the tokens + # and length tensors. + sizes = sizes_tensor.tolist() + prompts_tokens_cuda_long_tensor = broadcast_tensor( + sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank) + prompts_length_cuda_long_tensor = broadcast_tensor( + sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor, + rank=rank) + + return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor + + +def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS): + """Given a set of prompts and number of tokens to generate: + - tokenize prompts + - set the sequence length to be the max of length of prompts + plus the number of tokens we would like to generate + - pad all the sequences to this length so we can convert them + into a 2D tensor. + """ + + # Tokenize all the prompts. + tokenizer = get_tokenizer() + if add_BOS: + prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt) + for prompt in prompts] + else: + prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts] + + # Now we have a list of list of tokens which each list has a different + # size. We want to extend this list to: + # - incorporate the tokens that need to be generated + # - make all the sequences equal length. + # Get the prompts length. + prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens] + # Get the max prompts length. + max_prompt_len = max(prompts_length) + # Number of tokens in the each sample of the batch. + samples_length = max_prompt_len + tokens_to_generate + # Now update the list of list to be of the same size: samples_length. + for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length): + padding_size = samples_length - prompt_length + prompt_tokens.extend([tokenizer.eod] * padding_size) + + # Now we are in a structured format, we can convert to tensors. + prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens) + prompts_length_tensor = torch.cuda.LongTensor(prompts_length) + + return prompts_tokens_tensor, prompts_length_tensor diff --git a/megatron/text_generation_server.py b/megatron/text_generation_server.py new file mode 100644 index 0000000000000000000000000000000000000000..cad5c34bcfc58ba18ed769ff438596266a92e2e5 --- /dev/null +++ b/megatron/text_generation_server.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import torch +import json +import threading +from flask import Flask, request, jsonify, current_app +from flask_restful import Resource, Api +from megatron import get_args +from megatron.text_generation import generate_and_post_process +from megatron.text_generation import beam_search_and_post_process + + +GENERATE_NUM = 0 +BEAM_NUM = 1 +lock = threading.Lock() + +class MegatronGenerate(Resource): + def __init__(self, model): + self.model = model + + @staticmethod + def send_do_generate(): + choice = torch.cuda.LongTensor([GENERATE_NUM]) + torch.distributed.broadcast(choice, 0) + + @staticmethod + def send_do_beam_search(): + choice = torch.cuda.LongTensor([BEAM_NUM]) + torch.distributed.broadcast(choice, 0) + + def put(self): + args = get_args() + + if not "prompts" in request.get_json(): + return "prompts argument required", 400 + + if "max_len" in request.get_json(): + return "max_len is no longer used. Replace with tokens_to_generate", 400 + + if "sentences" in request.get_json(): + return "sentences is no longer used. Replace with prompts", 400 + + prompts = request.get_json()["prompts"] + if len(prompts) > 128: + return "Maximum number of prompts is 128", 400 + + tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow + if "tokens_to_generate" in request.get_json(): + tokens_to_generate = request.get_json()["tokens_to_generate"] + if not isinstance(tokens_to_generate, int): + return "tokens_to_generate must be an integer greater than 0" + if tokens_to_generate < 0: + return "tokens_to_generate must be an integer greater than or equal to 0" + + logprobs = False + if "logprobs" in request.get_json(): + logprobs = request.get_json()["logprobs"] + if not isinstance(logprobs, bool): + return "logprobs must be a boolean value" + + if tokens_to_generate == 0 and not logprobs: + return "tokens_to_generate=0 implies logprobs should be True" + + temperature = 1.0 + if "temperature" in request.get_json(): + temperature = request.get_json()["temperature"] + if not (type(temperature) == int or type(temperature) == float): + return "temperature must be a positive number less than or equal to 100.0" + if not (0.0 < temperature <= 100.0): + return "temperature must be a positive number less than or equal to 100.0" + + top_k = 0.0 + if "top_k" in request.get_json(): + top_k = request.get_json()["top_k"] + if not (type(top_k) == int): + return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000" + if not (0 <= top_k <= 1000): + return "top_k must be equal to or greater than 0 and less than or equal to 1000" + + top_p = 0.0 + if "top_p" in request.get_json(): + top_p = request.get_json()["top_p"] + if not (type(top_p) == float): + return "top_p must be a positive float less than or equal to 1.0" + if top_p > 0.0 and top_k > 0.0: + return "cannot set both top-k and top-p samplings." + if not (0 <= top_p <= 1.0): + return "top_p must be less than or equal to 1.0" + + top_p_decay = 0.0 + if "top_p_decay" in request.get_json(): + top_p_decay = request.get_json()["top_p_decay"] + if not (type(top_p_decay) == float): + return "top_p_decay must be a positive float less than or equal to 1.0" + if top_p == 0.0: + return "top_p_decay cannot be set without top_p" + if not (0 <= top_p_decay <= 1.0): + return "top_p_decay must be less than or equal to 1.0" + + top_p_bound = 0.0 + if "top_p_bound" in request.get_json(): + top_p_bound = request.get_json()["top_p_bound"] + if not (type(top_p_bound) == float): + return "top_p_bound must be a positive float less than or equal to top_p" + if top_p == 0.0: + return "top_p_bound cannot be set without top_p" + if not (0.0 < top_p_bound <= top_p): + return "top_p_bound must be greater than 0 and less than top_p" + + add_BOS = False + if "add_BOS" in request.get_json(): + add_BOS = request.get_json()["add_BOS"] + if not isinstance(add_BOS, bool): + return "add_BOS must be a boolean value" + + if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS: + return "Empty prompts require add_BOS=true" + + stop_on_double_eol = False + if "stop_on_double_eol" in request.get_json(): + stop_on_double_eol = request.get_json()["stop_on_double_eol"] + if not isinstance(stop_on_double_eol, bool): + return "stop_on_double_eol must be a boolean value" + + stop_on_eol = False + if "stop_on_eol" in request.get_json(): + stop_on_eol = request.get_json()["stop_on_eol"] + if not isinstance(stop_on_eol, bool): + return "stop_on_eol must be a boolean value" + + random_seed = -1 + if "random_seed" in request.get_json(): + random_seed = request.get_json()["random_seed"] + if not isinstance(random_seed, int): + return "random_seed must be integer" + if random_seed < 0: + return "random_seed must be a positive integer" + + no_log = False + if "no_log" in request.get_json(): + no_log = request.get_json()["no_log"] + if not isinstance(no_log, bool): + return "no_log must be a boolean value" + + beam_width = None + if "beam_width" in request.get_json(): + beam_width = request.get_json()["beam_width"] + if not isinstance(beam_width, int): + return "beam_width must be integer" + if beam_width < 1: + return "beam_width must be an integer > 1" + if len(prompts) > 1: + return "When doing beam_search, batch size must be 1" + + stop_token=50256 + if "stop_token" in request.get_json(): + stop_token = request.get_json()["stop_token"] + if not isinstance(stop_token, int): + return "stop_token must be an integer" + + length_penalty = 1 + if "length_penalty" in request.get_json(): + length_penalty = request.get_json()["length_penalty"] + if not isinstance(length_penalty, float): + return "length_penalty must be a float" + + with lock: # Need to get lock to keep multiple threads from hitting code + + if not no_log: + print("request IP: " + str(request.remote_addr)) + print(json.dumps(request.get_json()),flush=True) + print("start time: ", datetime.datetime.now()) + + try: + if beam_width is not None: + MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search + response, response_seg, response_scores = \ + beam_search_and_post_process( + self.model, + prompts=prompts, + tokens_to_generate=tokens_to_generate, + beam_size = beam_width, + add_BOS=add_BOS, + stop_token=stop_token, + num_return_gen=beam_width, # Returning whole beam + length_penalty=length_penalty + ) + + return jsonify({"text": response, + "segments": response_seg, + "scores": response_scores}) + else: + MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate + response, response_seg, response_logprobs, _ = \ + generate_and_post_process( + self.model, + prompts=prompts, + tokens_to_generate=tokens_to_generate, + return_output_log_probs=logprobs, + top_k_sampling=top_k, + top_p_sampling=top_p, + top_p_decay=top_p_decay, + top_p_bound=top_p_bound, + temperature=temperature, + add_BOS=add_BOS, + use_eod_token_for_early_termination=True, + stop_on_double_eol=stop_on_double_eol, + stop_on_eol=stop_on_eol, + random_seed=random_seed) + + return jsonify({"text": response, + "segments": response_seg, + "logprobs": response_logprobs}) + + except ValueError as ve: + return ve.args[0] + print("end time: ", datetime.datetime.now()) + + +class MegatronServer(object): + def __init__(self, model): + self.app = Flask(__name__, static_url_path='') + api = Api(self.app) + api.add_resource(MegatronGenerate, '/api', resource_class_args=[model]) + + def run(self, url): + self.app.run(url, threaded=True, debug=False) diff --git a/megatron/tokenizer/__init__.py b/megatron/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..311f2fdca99ee28b1bd2864649b626a8dad1c7e2 --- /dev/null +++ b/megatron/tokenizer/__init__.py @@ -0,0 +1,17 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .tokenizer import build_tokenizer diff --git a/megatron/tokenizer/bert_tokenization.py b/megatron/tokenizer/bert_tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..99f9a87958d93bc96b7e48cd75dba67ddaf7e038 --- /dev/null +++ b/megatron/tokenizer/bert_tokenization.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" + ] + + cased_models = [ + "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12" + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, + model_name, case_name, opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """ Converts a sequence of tokens (string) in a single string. """ + + def clean_up_tokenization(out_string): + """ Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/megatron/tokenizer/gpt2_tokenization.py b/megatron/tokenizer/gpt2_tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..3f37e449089b8f779c93ac7f034085fd7607bfb0 --- /dev/null +++ b/megatron/tokenizer/gpt2_tokenization.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for OpenAI GPT.""" + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import sys +import json +import logging +import os +import regex as re +from io import open + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE + # tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func + + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", +} +PRETRAINED_MERGES_ARCHIVE_MAP = { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'gpt2': 1024, +} +VOCAB_NAME = 'vocab.json' +MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + _chr = unichr if sys.version_info[0] == 2 else chr + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \ + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [_chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class GPT2Tokenizer(object): + """ + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE + """ + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] + special_tokens_file = None + else: + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) + # redirect to the cache, if necessary + try: + from .file_utils import cached_path + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find files {} and {} " + "at this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + pretrained_model_name_or_path, + vocab_file, merges_file)) + return None + if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + logger.info("loading merges file {}".format(merges_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + logger.info("loading merges file {} from cache at {}".format( + merges_file, resolved_merges_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls( + resolved_vocab_file, + resolved_merges_file, + special_tokens=special_tokens, + *inputs, + **kwargs) + return tokenizer + + def __init__(self, vocab_file, merges_file, errors='replace', + special_tokens=None, max_len=None): + self.max_len = max_len if max_len is not None else int(1e12) + self.encoder = json.load(open(vocab_file)) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for + # capitalized versions of contractions + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) + for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()} + logger.info("Special tokens {}".format(self.special_tokens)) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except BaseException: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def tokenize(self, text): + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + if sys.version_info[0] == 2: + token = ''.join(self.byte_encoder[ord(b)] for b in token) + else: + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. """ + ids = [] + if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this OpenAI GPT model ({} > {}). Running this" + " sequence through the model will result in indexing errors".format( + len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + return text + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + merge_file = os.path.join(vocab_path, MERGES_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 + + index = len(self.encoder) + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + + return vocab_file, merge_file, special_tokens_file diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..13085a81c9feb4f4ba2ffe3f254a3bd7d3f4c144 --- /dev/null +++ b/megatron/tokenizer/tokenizer.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron tokenizers.""" + +from abc import ABC +from abc import abstractmethod + +from .bert_tokenization import FullTokenizer as FullBertTokenizer +from .gpt2_tokenization import GPT2Tokenizer + + +def build_tokenizer(args): + """Initialize tokenizer.""" + if args.rank == 0: + print('> building {} tokenizer ...'.format(args.tokenizer_type), + flush=True) + + # Select and instantiate the tokenizer. + assert args.vocab_file is not None + if args.tokenizer_type == 'BertWordPieceLowerCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, + lower_case=True, + vocab_extra_ids=args.vocab_extra_ids) + elif args.tokenizer_type == 'BertWordPieceCase': + tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, + lower_case=False, + vocab_extra_ids=args.vocab_extra_ids) + elif args.tokenizer_type == 'GPT2BPETokenizer': + assert args.merge_file is not None + tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) + else: + raise NotImplementedError('{} tokenizer is not ' + 'implemented.'.format(args.tokenizer_type)) + + # Add vocab size. + args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, + args) + + return tokenizer + + +def _vocab_size_with_padding(orig_vocab_size, args): + """Pad vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + multiple = args.make_vocab_size_divisible_by * \ + args.tensor_model_parallel_size + while (after % multiple) != 0: + after += 1 + if args.rank == 0: + print(' > padded vocab (size: {}) with {} dummy tokens ' + '(new size: {})'.format( + orig_vocab_size, after - orig_vocab_size, after), flush=True) + return after + + +class AbstractTokenizer(ABC): + """Abstract class for tokenizer.""" + + def __init__(self, name): + self.name = name + super().__init__() + + @property + @abstractmethod + def vocab_size(self): + pass + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token.""" + pass + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + pass + + @abstractmethod + def tokenize(self, text): + pass + + def detokenize(self, token_ids): + raise NotImplementedError('detokenizer is not implemented for {} ' + 'tokenizer'.format(self.name)) + + @property + def cls(self): + raise NotImplementedError('CLS is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def sep(self): + raise NotImplementedError('SEP is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def pad(self): + raise NotImplementedError('PAD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def eod(self): + raise NotImplementedError('EOD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def mask(self): + raise NotImplementedError('MASK is not provided for {} ' + 'tokenizer'.format(self.name)) + + +class _BertWordPieceTokenizer(AbstractTokenizer): + """Original BERT wordpiece tokenizer.""" + + def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): + if lower_case: + name = 'BERT Lower Case' + else: + name = 'BERT Upper Case' + super().__init__(name) + self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) + self.cls_id = self.tokenizer.vocab['[CLS]'] + self.sep_id = self.tokenizer.vocab['[SEP]'] + self.pad_id = self.tokenizer.vocab['[PAD]'] + self.mask_id = self.tokenizer.vocab['[MASK]'] + self._additional_special_tokens = [] + + # (dsachan) Add BOS and EOS tokens + SPECIAL_TOKENS = {'eos_token': '[EOS]', + 'bos_token': '[BOS]'} + self._bos_token = '[BOS]' + self.add_token(self._bos_token) + self._bos_token_id = self.vocab.get(self._bos_token) + + self._eos_token = '[EOS]' + self.add_token(self._eos_token) + self._eos_token_id = self.vocab.get(self._eos_token) + + # (dsachan) Add additional special tokens + # These can be used as sentinel tokens in T5 model inputs + additional_special_tokens = [] + additional_special_tokens.extend( + ["".format(i) for i in range(vocab_extra_ids)]) + self.add_additional_special_tokens(additional_special_tokens) + + def add_token(self, token): + if token not in self.vocab: + self.inv_vocab[self.vocab_size] = token + # self.vocab_size comes from len(vocab) + # and it will increase as we add elements + self.vocab[token] = self.vocab_size + + def add_additional_special_tokens(self, tokens_list): + setattr(self, "additional_special_tokens", tokens_list) + for value in tokens_list: + self.add_token(value) + + @property + def vocab_size(self): + return self.tokenizer.vocab_size() + + @property + def vocab(self): + return self.tokenizer.vocab + + @property + def inv_vocab(self): + return self.tokenizer.inv_vocab + + def tokenize(self, text): + text_tokens = self.tokenizer.tokenize(text) + return self.tokenizer.convert_tokens_to_ids(text_tokens) + + def decode(self, ids): + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return self.tokenizer.convert_tokens_to_string(tokens) + + def decode_token_ids(self, token_ids): + tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + exclude_list = ['[PAD]', '[CLS]'] + non_pads = [t for t in tokens if t not in exclude_list] + + result = "" + for s in non_pads: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + + return result + + @property + def cls(self): + return self.cls_id + + @property + def sep(self): + return self.sep_id + + @property + def pad(self): + return self.pad_id + + @property + def mask(self): + return self.mask_id + + @property + def bos_token(self): + """ Beginning of sentence token id """ + return self._bos_token + + @property + def eos_token(self): + """ End of sentence token id """ + return self._eos_token + + @property + def additional_special_tokens(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self._additional_special_tokens + + @property + def bos_token_id(self): + """ Id of the beginning of sentence token in the vocabulary.""" + return self._bos_token_id + + @property + def eos_token_id(self): + """ Id of the end of sentence token in the vocabulary.""" + return self._eos_token_id + + @property + def additional_special_tokens_ids(self): + """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + return [self.vocab.get(token) for token in self._additional_special_tokens] + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + + +class _GPT2BPETokenizer(AbstractTokenizer): + """Original GPT2 BPE tokenizer.""" + + def __init__(self, vocab_file, merge_file): + name = 'GPT2 BPE' + super().__init__(name) + + self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace', + special_tokens=[], max_len=None) + self.eod_id = self.tokenizer.encoder['<|endoftext|>'] + + @property + def vocab_size(self): + return len(self.tokenizer.encoder) + + @property + def vocab(self): + return self.tokenizer.encoder + + @property + def inv_vocab(self): + return self.tokenizer.decoder + + def tokenize(self, text): + return self.tokenizer.encode(text) + + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def eod(self): + return self.eod_id diff --git a/megatron/training.py b/megatron/training.py new file mode 100644 index 0000000000000000000000000000000000000000..79e8550cbf3a6cc856ec384fc540b3e3e1a41cb7 --- /dev/null +++ b/megatron/training.py @@ -0,0 +1,1011 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pretrain utilities.""" + +from datetime import datetime +import math +import sys +import time +# The earliest we can measure the start time. +_TRAIN_START_TIME = time.time() +import torch +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +from torch.utils.tensorboard import SummaryWriter + +from megatron import get_args +from megatron import get_signal_handler +from megatron import get_timers +from megatron import get_tensorboard_writer +from megatron import get_current_global_batch_size +from megatron import get_num_microbatches +from megatron import is_last_rank +from megatron import update_num_microbatches +from megatron import mpu +from megatron import print_rank_0 +from megatron import print_rank_last +from megatron.checkpointing import load_checkpoint +from megatron.checkpointing import save_checkpoint +from megatron.model import Float16Module +from megatron.model import ModelType +from megatron.optimizer import get_megatron_optimizer +from megatron.initialize import initialize_megatron +from megatron.initialize import write_args_to_tensorboard +from megatron.initialize import set_jit_fusion_options +from megatron.optimizer_param_scheduler import OptimizerParamScheduler +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.utils import check_adlr_autoresume_termination +from megatron.utils import unwrap_model +from megatron.data.data_samplers import build_pretraining_data_loader +from megatron.utils import calc_params_l2_norm +from megatron.schedules import get_forward_backward_func +from megatron.utils import report_memory +from megatron.model.vision.knn_monitor import compute_feature_bank + + +def print_datetime(string): + """Note that this call will sync across all ranks.""" + torch.distributed.barrier() + time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + print_rank_0('[' + string + '] datetime: {} '.format(time_str)) + + +def pretrain(train_valid_test_dataset_provider, + model_provider, + model_type, + forward_step_func, + process_non_loss_data_func=None, + extra_args_provider=None, + args_defaults={}): + """Main training program. + + This function will run the followings in the order provided: + 1) initialize Megatron. + 2) setup model, optimizer and lr schedule using the model_provider. + 3) call train_val_test_data_provider to get train/val/test datasets. + 4) train the modle using the forward_step_func. + + Arguments: + train_valid_test_dataset_provider: a function that takes the size of + train/valid/test dataset and returns `train, valid, test` datasets. + model_provider: a function that returns a vanilla version of the + model. By vanilla we mean a simple model on cpu with no fp16 or ddp. + model_type: an enum that specifies the type of model being trained. + forward_step_func: a function that takes a `data iterator` and `model`, + and returns a `loss` scalar with a dictionary with key:values being + the info we would like to monitor during training, for example + `lm-loss: value`. We also require that this function add + `batch generator` to the timers class. + process_non_loss_data_func: a function to post process outputs of the + network. It can be used for dumping output tensors (e.g images) to + tensorboard. It takes `collected data`(list of tensors), + `current iteration index` and `tensorboard writer` as arguments. + extra_args_provider: a function that takes a parser and adds arguments + to it. It is used for programs to add their own arguments. + args_defaults: a dictionary from argument-name to argument-value. It + to set already parse arguments. + """ + + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron(extra_args_provider=extra_args_provider, + args_defaults=args_defaults) + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options() + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + global _TRAIN_START_TIME + start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME]) + torch.distributed.all_reduce(start_time_tensor, + op=torch.distributed.ReduceOp.MIN) + _TRAIN_START_TIME = start_time_tensor.item() + print_rank_0('time to initialize megatron (seconds): {:.3f}'.format( + time.time() - _TRAIN_START_TIME)) + print_datetime('after megatron is initialized') + + args = get_args() + timers = get_timers() + + # Model, optimizer, and learning rate. + timers('model-and-optimizer-setup').start() + model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, + model_type) + timers('model-and-optimizer-setup').stop() + print_datetime('after model, optimizer, and learning rate ' + 'scheduler are built') + + # Data stuff. + timers('train/valid/test-data-iterators-setup').start() + if args.virtual_pipeline_model_parallel_size is not None: + all_data_iterators = [ + build_train_valid_test_data_iterators(train_valid_test_dataset_provider) + for _ in range(len(model)) + ] + train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators] + valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators] + test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators] + else: + train_data_iterator, valid_data_iterator, test_data_iterator \ + = build_train_valid_test_data_iterators( + train_valid_test_dataset_provider) + timers('train/valid/test-data-iterators-setup').stop() + print_datetime('after dataloaders are built') + + # Print setup timing. + print_rank_0('done with setup ...') + timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup']) + print_rank_0('training ...') + + iteration = 0 + if args.do_train and args.train_iters > 0: + iteration = train(forward_step_func, + model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, + process_non_loss_data_func) + print_datetime('after training is done') + + if args.do_valid: + prefix = 'the end of training for val data' + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, + False) + + if args.save and iteration != 0: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + + if args.do_test: + # Run on test data. + prefix = 'the end of training for test data' + evaluate_and_print_results(prefix, forward_step_func, + test_data_iterator, model, + 0, process_non_loss_data_func, + True) + +def update_train_iters(args): + + # For iteration-based training, we don't need to do anything + if args.train_iters: + return + + # Constant batch size with sample-based training. + if args.rampup_batch_size is None: + args.train_iters = args.train_samples // args.global_batch_size + + else: + # Sample based training with rampup batch size. + iterations = 0 + consumed_samples = 0 + # Rampup phase. + while consumed_samples <= int(args.rampup_batch_size[2]): + update_num_microbatches(consumed_samples, consistency_check=False) + consumed_samples += get_current_global_batch_size() + iterations += 1 + # Reset + update_num_microbatches(0, consistency_check=False) + # Constant phase + # Note that we throw away any partial last batch. + iterations += (args.train_samples - consumed_samples) // \ + args.global_batch_size + args.train_iters = iterations + + print_rank_0('setting training iterations to {}'.format(args.train_iters)) + + +def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): + """Build the model.""" + args = get_args() + args.model_type = model_type + + # Build model. + if mpu.get_pipeline_model_parallel_world_size() > 1 and \ + args.virtual_pipeline_model_parallel_size is not None: + assert model_type != ModelType.encoder_and_decoder, \ + "Interleaved schedule not supported for model with both encoder and decoder" + model = [] + for i in range(args.virtual_pipeline_model_parallel_size): + mpu.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + this_model = model_provider_func( + pre_process=pre_process, + post_process=post_process + ) + this_model.model_type = model_type + model.append(this_model) + else: + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + add_encoder = True + add_decoder = True + if model_type == ModelType.encoder_and_decoder: + if mpu.get_pipeline_model_parallel_world_size() > 1: + assert args.pipeline_model_parallel_split_rank is not None, \ + "Split rank needs to be specified for model with both encoder and decoder" + rank = mpu.get_pipeline_model_parallel_rank() + split_rank = args.pipeline_model_parallel_split_rank + world_size = mpu.get_pipeline_model_parallel_world_size() + pre_process = rank == 0 or rank == split_rank + post_process = (rank == (split_rank - 1)) or ( + rank == (world_size - 1)) + add_encoder = mpu.is_pipeline_stage_before_split() + add_decoder = mpu.is_pipeline_stage_after_split() + model = model_provider_func( + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder) + else: + model = model_provider_func( + pre_process=pre_process, + post_process=post_process + ) + model.model_type = model_type + + if not isinstance(model, list): + model = [model] + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # Print number of parameters. + if mpu.get_data_parallel_rank() == 0: + print(' > number of parameters on (tensor, pipeline) ' + 'model parallel rank ({}, {}): {}'.format( + mpu.get_tensor_model_parallel_rank(), + mpu.get_pipeline_model_parallel_rank(), + sum([sum([p.nelement() for p in model_module.parameters()]) + for model_module in model])), flush=True) + + # GPU allocation. + for model_module in model: + model_module.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if args.fp16 or args.bf16: + model = [Float16Module(model_module, args) for model_module in model] + + if wrap_with_ddp: + if args.DDP_impl == 'torch': + i = torch.cuda.current_device() + model = [torchDDP(model_module, device_ids=[i], output_device=i, + process_group=mpu.get_data_parallel_group()) + for model_module in model] + + elif args.DDP_impl == 'local': + model = [LocalDDP(model_module, + args.accumulate_allreduce_grads_in_fp32, + args.use_contiguous_buffers_in_local_ddp) + for model_module in model] + # broad cast params from data parallel src rank to other data parallel ranks + if args.data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + else: + raise NotImplementedError('Unknown DDP implementation specified: ' + '{}. Exiting.'.format(args.DDP_impl)) + + return model + + +def get_optimizer_param_scheduler(optimizer): + """Build the learning rate scheduler.""" + args = get_args() + + # Iteration-based training. + if args.train_iters: + if args.lr_decay_iters is None: + args.lr_decay_iters = args.train_iters + lr_decay_steps = args.lr_decay_iters * args.global_batch_size + wd_incr_steps = args.train_iters * args.global_batch_size + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size + # Sample-based training. + elif args.train_samples: + # We need to set training iters for later use. Technically + # we need to adjust the training samples too (due to last + # batch being incomplete) but we leave it as is for now. + update_train_iters(args) + if args.lr_decay_samples is None: + args.lr_decay_samples = args.train_samples + lr_decay_steps = args.lr_decay_samples + wd_incr_steps = args.train_samples + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_samples + else: + raise Exception( + 'either train-iters or train-samples should be provided.') + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + max_lr=args.lr, + min_lr=args.min_lr, + lr_warmup_steps=lr_warmup_steps, + lr_decay_steps=lr_decay_steps, + lr_decay_style=args.lr_decay_style, + start_wd=args.start_weight_decay, + end_wd=args.end_weight_decay, + wd_incr_steps=wd_incr_steps, + wd_incr_style=args.weight_decay_incr_style, + use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler, + override_opt_param_scheduler=args.override_opt_param_scheduler) + + return opt_param_scheduler + + +def setup_model_and_optimizer(model_provider_func, + model_type, + no_wd_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0): + """Setup model and optimizer.""" + args = get_args() + + model = get_model(model_provider_func, model_type) + unwrapped_model = unwrap_model(model, + (torchDDP, LocalDDP, Float16Module)) + + optimizer = get_megatron_optimizer(model, no_wd_decay_cond, + scale_lr_cond, lr_mult) + opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + + if args.load is not None: + timers = get_timers() + # Extra barrier is added to make sure all ranks report the + # max time. + torch.distributed.barrier() + timers('load-checkpoint').start() + args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler) + torch.distributed.barrier() + timers('load-checkpoint').stop() + timers.log(['load-checkpoint']) + else: + args.iteration = 0 + + # We only support local DDP with multiple micro-batches. + if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1: + assert args.DDP_impl == 'local' + + # get model without FP16 and/or TorchDDP wrappers + if args.iteration == 0 and len(unwrapped_model) == 1 \ + and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'): + print_rank_0("Initializing ICT from pretrained BERT model") + unwrapped_model[0].init_state_dict_from_bert() + if args.fp16: + optimizer.reload_model_params() + + return model, optimizer, opt_param_scheduler + + +def train_step(forward_step_func, data_iterator, + model, optimizer, opt_param_scheduler): + """Single training step.""" + args = get_args() + timers = get_timers() + + # Set grad to zero. + if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp: + for partition in model: + partition.zero_grad_buffer() + optimizer.zero_grad() + + # Forward pass. + forward_backward_func = get_forward_backward_func() + losses_reduced = forward_backward_func( + forward_step_func, data_iterator, model, + optimizer, timers, forward_only=False) + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + # Reduce gradients. + timers('backward-reduce-model-grads').start() + optimizer.reduce_model_grads(args, timers) + timers('backward-reduce-model-grads').stop() + + # Vision gradients. + if args.vision_pretraining and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0], + (torchDDP, LocalDDP, Float16Module)) + unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) + + # Update parameters. + timers('optimizer').start() + update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers) + timers('optimizer').stop() + + # Gather params. + if update_successful: + timers('backward-gather-model-params').start() + optimizer.gather_model_params(args, timers) + timers('backward-gather-model-params').stop() + + # Vision momentum. + if args.vision_pretraining and args.vision_pretraining_type == "dino": + unwrapped_model = unwrap_model(model[0], + (torchDDP, LocalDDP, Float16Module)) + unwrapped_model.update_momentum(args.curr_iteration) + + # Update learning rate. + if update_successful: + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + for key in losses_reduced[0]: + if key == "describe": + continue + losses_reduced_for_key = [x[key] for x in losses_reduced] + loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + return {}, skipped_iter, grad_norm, num_zeros_in_grad + + +def training_log(loss_dict, total_loss_dict, learning_rate, iteration, + loss_scale, report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad, my_writer): + """Log training information such as losses, timing, ....""" + args = get_args() + timers = get_timers() + writer = get_tensorboard_writer() + + # Advanced, skipped, and Nan iterations. + advanced_iters_key = 'advanced iterations' + skipped_iters_key = 'skipped iterations' + nan_iters_key = 'nan iterations' + # Advanced iterations. + if not skipped_iter: + total_loss_dict[advanced_iters_key] = total_loss_dict.get( + advanced_iters_key, 0) + 1 + else: + if advanced_iters_key not in total_loss_dict: + total_loss_dict[advanced_iters_key] = 0 + # Skipped iterations. + total_loss_dict[skipped_iters_key] = total_loss_dict.get( + skipped_iters_key, 0) + skipped_iter + # Update losses and set nan iterations + got_nan = False + for key in loss_dict: + if not skipped_iter: + total_loss_dict[key] = total_loss_dict.get( + key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] + else: + value = loss_dict[key].float().sum().item() + is_nan = value == float('inf') or \ + value == -float('inf') or \ + value != value + got_nan = got_nan or is_nan + total_loss_dict[nan_iters_key] = total_loss_dict.get( + nan_iters_key, 0) + int(got_nan) + + # Logging. + timers_to_log = [] + + def add_to_logging(name): + if name in timers.timers: + timers_to_log.append(name) + # add_to_logging('forward-compute') + # add_to_logging('forward-recv') + # add_to_logging('forward-send') + # add_to_logging('forward-backward-send-forward-backward-recv') + # add_to_logging('backward-compute') + # add_to_logging('backward-recv') + # add_to_logging('backward-send') + # add_to_logging('backward-send-forward-recv') + # add_to_logging('backward-send-backward-recv') + # add_to_logging('backward-params-all-reduce') + # add_to_logging('backward-layernorm-all-reduce') + # add_to_logging('backward-embedding-all-reduce') + # add_to_logging('backward-reduce-model-grads') + # add_to_logging('backward-gather-model-params') + # add_to_logging('optimizer-copy-to-main-grad') + # add_to_logging('optimizer-unscale-and-check-inf') + # add_to_logging('optimizer-clip-main-grad') + # add_to_logging('optimizer-count-zeros') + # add_to_logging('optimizer-inner-step') + # add_to_logging('optimizer-copy-main-to-model-params') + # add_to_logging('optimizer') + # add_to_logging('batch-generator') + + # Calculate batch size. + batch_size = args.micro_batch_size * args.data_parallel_size * \ + get_num_microbatches() + + total_iterations = total_loss_dict[advanced_iters_key] + \ + total_loss_dict[skipped_iters_key] + + # Tensorboard values. + # if writer and (iteration % args.tensorboard_log_interval == 0 ) and \ + # is_last_rank(): + # if args.log_learning_rate_to_tensorboard: + # writer.add_scalar('learning-rate', learning_rate, iteration) + # writer.add_scalar('learning-rate vs samples', learning_rate, + # args.consumed_train_samples) + # if args.log_batch_size_to_tensorboard: + # writer.add_scalar('batch-size', batch_size, iteration) + # writer.add_scalar('batch-size vs samples', batch_size, + # args.consumed_train_samples) + # for key in loss_dict: + # writer.add_scalar(key , loss_dict[key], iteration) + # writer.add_scalar(key + ' vs samples', loss_dict[key], + # args.consumed_train_samples) + # if args.log_loss_scale_to_tensorboard: + # writer.add_scalar('loss-scale', loss_scale, iteration) + # writer.add_scalar('loss-scale vs samples', loss_scale, + # args.consumed_train_samples) + # if args.log_world_size_to_tensorboard: + # writer.add_scalar('world-size', args.world_size, iteration) + # writer.add_scalar('world-size vs samples', args.world_size, + # args.consumed_train_samples) + # if grad_norm is not None: + # writer.add_scalar('grad-norm', grad_norm, iteration) + # writer.add_scalar('grad-norm vs samples', grad_norm, + # args.consumed_train_samples) + # if num_zeros_in_grad is not None: + # writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) + # writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, + # args.consumed_train_samples) + # if params_norm is not None: + # writer.add_scalar('params-norm', params_norm, iteration) + # writer.add_scalar('params-norm vs samples', params_norm, + # args.consumed_train_samples) + # if args.log_timers_to_tensorboard: + # timers.write(timers_to_log, writer, iteration, + # normalizer=total_iterations) + # if args.log_memory_to_tensorboard: + # mem_stats = torch.cuda.memory_stats() + # writer.add_scalar( + # "mem-reserved-bytes", + # mem_stats["reserved_bytes.all.current"], + # iteration, + # ) + # writer.add_scalar( + # "mem-allocated-bytes", + # mem_stats["allocated_bytes.all.current"], + # iteration, + # ) + # writer.add_scalar( + # "mem-allocated-count", + # mem_stats["allocation.all.current"], + # iteration, + # ) + + if my_writer and iteration % args.log_interval == 0 and is_last_rank(): + # record learning rate + my_writer.add_scalar('train/learning-rate', learning_rate, iteration) + + # record loss and ppl + total_train_loss = 0 + for key in loss_dict: + my_writer.add_scalar("train/" + key , loss_dict[key], iteration) + ppl = math.exp(min(20, loss_dict[key])) + my_writer.add_scalar("train/" + key + "_ppl" , ppl, iteration) + total_train_loss += loss_dict[key] + my_writer.add_scalar("train/total-loss", total_train_loss, iteration) + # record loss scaling + my_writer.add_scalar('train/loss-scale', loss_scale, iteration) + + if iteration % args.log_interval == 0: + elapsed_time = timers('interval-time').elapsed() + elapsed_time_per_iteration = elapsed_time / total_iterations + if writer: + if args.log_timers_to_tensorboard: + writer.add_scalar('iteration-time', + elapsed_time_per_iteration, iteration) + log_string = ' iteration {:8d}/{:8d} |'.format( + iteration, args.train_iters) + log_string += ' consumed samples: {:12d} |'.format( + args.consumed_train_samples) + log_string += ' Task: {} |'.format( + args.task) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( + elapsed_time_per_iteration * 1000.0) + log_string += ' learning rate: {:.3E} |'.format(learning_rate) + log_string += ' global batch size: {:5d} |'.format(batch_size) + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, + nan_iters_key]: + avg = total_loss_dict[key].item() / \ + float(max(1, total_loss_dict[advanced_iters_key])) + if avg > 0.0: + log_string += ' {}: {:.6E} |'.format(key, avg) + total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) + log_string += ' loss scale: {:.1f} |'.format(loss_scale) + if grad_norm is not None: + log_string += ' grad norm: {:.3f} |'.format(grad_norm) + if num_zeros_in_grad is not None: + log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) + if params_norm is not None: + log_string += ' params norm: {:.3f} |'.format(params_norm) + log_string += ' number of skipped iterations: {:3d} |'.format( + total_loss_dict[skipped_iters_key]) + log_string += ' number of nan iterations: {:3d} |'.format( + total_loss_dict[nan_iters_key]) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 + print_rank_last(log_string) + if report_memory_flag and learning_rate > 0.: + # Report memory after optimizer state has been initialized. + report_memory('(after {} iterations)'.format(iteration)) + report_memory_flag = False + timers.log(timers_to_log, normalizer=args.log_interval) + + return report_memory_flag + + +def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): + timers = get_timers() + # Extra barrier is added to make sure + # all ranks report the max time. + torch.distributed.barrier() + timers('save-checkpoint').start() + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + torch.distributed.barrier() + timers('save-checkpoint').stop() + timers.log(['save-checkpoint']) + + +def train(forward_step_func, model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, + process_non_loss_data_func): + """Train the model function.""" + args = get_args() + timers = get_timers() + + # Write args to tensorboard + write_args_to_tensorboard() + + # Turn on training mode which enables dropout. + for model_module in model: + model_module.train() + + # Tracking loss. + total_loss_dict = {} + + # Iterations. + iteration = args.iteration + + timers('interval-time').start() + print_datetime('before the start of training step') + if is_last_rank(): + my_writer = SummaryWriter(args.save + "/tb_res") + else: + my_writer = None + report_memory_flag = True + while iteration < args.train_iters: + update_num_microbatches(args.consumed_train_samples) + args.curr_iteration = iteration + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_step_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler) + iteration += 1 + args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * \ + get_num_microbatches() + + # Logging. + loss_scale = optimizer.get_loss_scale().item() + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + report_memory_flag = training_log(loss_dict, total_loss_dict, + optimizer.param_groups[0]['lr'], + iteration, loss_scale, + report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad, my_writer) + + # Autoresume + if args.adlr_autoresume and \ + (iteration % args.adlr_autoresume_interval == 0): + check_adlr_autoresume_termination(iteration, model, optimizer, + opt_param_scheduler) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0 and \ + args.do_valid: + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results(prefix, forward_step_func, + valid_data_iterator, model, + iteration, process_non_loss_data_func, my_writer, + False) + + # Checkpointing + saved_checkpoint = False + if args.exit_signal_handler: + signal_handler = get_signal_handler() + if any(signal_handler.signals_received()): + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + print_datetime('exiting program after receiving SIGTERM.') + sys.exit() + + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + saved_checkpoint = True + + # Exiting based on duration + if args.exit_duration_in_mins: + train_time = (time.time() - _TRAIN_START_TIME) / 60.0 + done_cuda = torch.cuda.IntTensor( + [train_time > args.exit_duration_in_mins]) + torch.distributed.all_reduce( + done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + if not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + print_datetime('exiting program after {} minutes'.format(train_time)) + sys.exit() + + # Exiting based on iterations + if args.exit_interval and iteration % args.exit_interval == 0: + if not saved_checkpoint: + save_checkpoint_and_time(iteration, model, optimizer, + opt_param_scheduler) + torch.distributed.barrier() + print_datetime('exiting program at iteration {}'.format(iteration)) + sys.exit() + + + return iteration + + +def evaluate(forward_step_func, + data_iterator, + model, + process_non_loss_data_func, + verbose=False): + """Evaluation.""" + args = get_args() + + if args.vision_pretraining and args.vision_pretraining_type == "dino": + compute_feature_bank(model) + + # Turn on evaluation mode which disables dropout. + for model_module in model: + model_module.eval() + + total_loss_dict = {} + + with torch.no_grad(): + iteration = 0 + while iteration < args.eval_iters: + iteration += 1 + if verbose and iteration % args.log_interval == 0: + print_rank_0('Evaluating iter {}/{}'.format(iteration, + args.eval_iters)) + + forward_backward_func = get_forward_backward_func() + loss_dicts = forward_backward_func( + forward_step_func, data_iterator, model, optimizer=None, + timers=None, forward_only=True) + + # Empty unused memory + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Reduce across processes. + for loss_dict in loss_dicts: + for key in loss_dict: + if key == "describe": + continue + total_loss_dict[key] = total_loss_dict.get( + key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] + + args.consumed_valid_samples += mpu.get_data_parallel_world_size() \ + * args.micro_batch_size \ + * get_num_microbatches() + collected_non_loss_data = None + if process_non_loss_data_func is not None and is_last_rank(): + collected_non_loss_data = forward_backward_func( + forward_step_func, data_iterator, model, optimizer=None, + timers=None, forward_only=True, collect_non_loss_data=True) + + # Move model back to the train mode. + for model_module in model: + model_module.train() + + for key in total_loss_dict: + total_loss_dict[key] /= args.eval_iters * get_num_microbatches() + if "describe" in loss_dict: + total_loss_dict["describe"] = loss_dict["describe"] + + return total_loss_dict, collected_non_loss_data + +def evaluate_and_print_results(prefix, forward_step_func, + data_iterator, model, + iteration, process_non_loss_data_func, my_writer, + verbose=False): + """Helper function to evaluate and dump results on screen.""" + args = get_args() + writer = get_tensorboard_writer() + + total_loss_dict, collected_non_loss_data = evaluate( + forward_step_func, data_iterator, model, + process_non_loss_data_func, verbose) + string = ' validation loss at {} | '.format(prefix) + total_val_loss = 0 + for key in total_loss_dict: + if key == "describe": + if isinstance(total_loss_dict["describe"], str): + string += total_loss_dict["describe"] + continue + elif isinstance(total_loss_dict["describe"], dict): + continue + else: + raise "Not Imp" + string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) + # ppl = math.exp(min(20, total_loss_dict[key].item())) + # string += '{} PPL: {:.6E} | '.format(key, ppl) + + if my_writer and is_last_rank(): + my_writer.add_scalar('val/' + key, total_loss_dict[key].item(), iteration) + my_writer.add_scalar('val/' + key + '_ppl', ppl, iteration) + total_val_loss += total_loss_dict[key].item() + # if writer: + # writer.add_scalar('{} validation'.format(key), + # total_loss_dict[key].item(), + # iteration) + # writer.add_scalar('{} validation vs samples'.format(key), + # total_loss_dict[key].item(), + # args.consumed_train_samples) + # if args.log_validation_ppl_to_tensorboard: + # writer.add_scalar('{} validation ppl'.format(key), ppl, + # iteration) + # writer.add_scalar('{} validation ppl vs samples'.format(key), + # ppl, args.consumed_train_samples) + + if process_non_loss_data_func is not None and writer and is_last_rank(): + process_non_loss_data_func(collected_non_loss_data, iteration, writer) + + if my_writer and is_last_rank(): + my_writer.add_scalar("val/total-loss", total_val_loss, iteration) + + length = len(string) + 1 + print_rank_last('-' * length) + print_rank_last(string) + if "describe" in total_loss_dict and isinstance(total_loss_dict["describe"], dict): + for k, v in total_loss_dict["describe"].items(): + out_str = " : ".join([k, v]) + print_rank_last(out_str) + print_rank_last('-' * length) + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x + +def build_train_valid_test_data_iterators( + build_train_valid_test_datasets_provider): + """XXX""" + args = get_args() + + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + + print_rank_0('> building train, validation, and test datasets ...') + + # Backward compatibility, assume fixed batch size. + if args.iteration > 0 and args.consumed_train_samples == 0: + assert args.train_samples is None, \ + 'only backward compatiblity support for iteration-based training' + args.consumed_train_samples = args.iteration * args.global_batch_size + if args.iteration > 0 and args.consumed_valid_samples == 0: + if args.train_samples is None: + args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ + args.eval_iters * args.global_batch_size + + # Data loader only on rank 0 of each model parallel group. + if mpu.get_tensor_model_parallel_rank() == 0: + + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print_rank_0(' > datasets target sizes (minimum size):') + print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) + print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) + print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) + + # Build the datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider( + train_val_test_num_samples) + + # Build dataloders. + train_dataloader = build_pretraining_data_loader( + train_ds, args.consumed_train_samples) + valid_dataloader = build_pretraining_data_loader( + valid_ds, args.consumed_valid_samples) + test_dataloader = build_pretraining_data_loader(test_ds, 0) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and args.train_iters > 0 + do_valid = valid_dataloader is not None and args.eval_iters > 0 + do_test = test_dataloader is not None and args.eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor( + [int(do_train), int(do_valid), int(do_test)]) + else: + flags = torch.cuda.LongTensor([0, 0, 0]) + + # Broadcast num tokens. + torch.distributed.broadcast(flags, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + args.do_train = flags[0].item() + args.do_valid = flags[1].item() + args.do_test = flags[2].item() + + # Build iterators. + dl_type = args.dataloader_type + assert dl_type in ['single', 'cyclic'] + + if train_dataloader is not None: + train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(train_dataloader)) + else: + train_data_iterator = None + + if valid_dataloader is not None: + valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(valid_dataloader)) + else: + valid_data_iterator = None + + if test_dataloader is not None: + test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ + else iter(cyclic_iter(test_dataloader)) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterator, test_data_iterator diff --git a/megatron/utils.py b/megatron/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..89bdba94aa2a90d92660c23a237072cfbba05191 --- /dev/null +++ b/megatron/utils.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General utilities.""" + +import sys + +import torch +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C + +from megatron import get_args +from megatron import get_adlr_autoresume +from megatron import mpu +from megatron.model.module import param_is_not_shared +from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate + + +def unwrap_model(model, module_instances=(torchDDP)): + return_list = True + if not isinstance(model, list): + model = [model] + return_list = False + unwrapped_model = [] + for model_module in model: + while isinstance(model_module, module_instances): + model_module = model_module.module + unwrapped_model.append(model_module) + if not return_list: + return unwrapped_model[0] + return unwrapped_model + + +def calc_params_l2_norm(model): + """Calculate l2 norm of parameters """ + args = get_args() + if not isinstance(model, list): + model = [model] + # Remove duplicate params. + params_data = [] + for model_ in model: + for param in model_.parameters(): + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if is_not_shared and is_not_tp_duplicate: + if args.bf16: + params_data.append(param.data.float()) + else: + params_data.append(param.data) + # Calculate norm + dummy_overflow_buf = torch.cuda.IntTensor([0]) + norm, _ = multi_tensor_applier( + amp_C.multi_tensor_l2norm, + dummy_overflow_buf, + [params_data], + False # no per-parameter norm + ) + norm_2 = norm * norm + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(norm_2, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_model_parallel_group()) + return norm_2.item() ** 0.5 + + +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat( + [loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, + group=mpu.get_data_parallel_group()) + averaged_losses = averaged_losses / \ + torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) + + return averaged_losses + + +def report_memory(name): + """Simple GPU memory report.""" + mega_bytes = 1024.0 * 1024.0 + string = name + ' memory (MB)' + string += ' | allocated: {}'.format( + torch.cuda.memory_allocated() / mega_bytes) + string += ' | max allocated: {}'.format( + torch.cuda.max_memory_allocated() / mega_bytes) + string += ' | reserved: {}'.format( + torch.cuda.memory_reserved() / mega_bytes) + string += ' | max reserved: {}'.format( + torch.cuda.max_memory_reserved() / mega_bytes) + if mpu.get_data_parallel_rank() == 0: + print("[Rank {}] {}".format(torch.distributed.get_rank(), string), + flush=True) + + +def print_params_min_max_norm(optimizer, iteration): + """Print min, max, and norm of all parameters.""" + index = 0 + rank = torch.distributed.get_rank() + string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n' + optimizer_ = optimizer.optimizer + for param_group in optimizer_.param_groups: + for param in param_group['params']: + index += 1 + min_ = param.data.min() + max_ = param.data.max() + norm = torch.linalg.norm(param.data) + string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( + iteration, rank, index, int(param.tensor_model_parallel)) + string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) + print(string, flush=True) + + +def check_adlr_autoresume_termination(iteration, model, + optimizer, opt_param_scheduler): + """Check for autoresume signal and exit if it is received.""" + from megatron.checkpointing import save_checkpoint + + args = get_args() + autoresume = get_adlr_autoresume() + # Add barrier to ensure consistnecy. + torch.distributed.barrier() + if autoresume.termination_requested(): + if args.save: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + print_rank_0(">>> autoresume termination request found!") + if torch.distributed.get_rank() == 0: + autoresume.request_resume() + print_rank_0(">>> training terminated. Returning") + sys.exit(0) + + +def get_ltor_masks_and_position_ids(data, + eod_token, + reset_position_ids, + reset_attention_mask, + eod_mask_loss): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril(torch.ones( + (att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, + device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(micro_batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1):] -= (i + 1 - prev_index) + prev_index = i + 1 + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + + return attention_mask, loss_mask, position_ids + + +def print_rank_0(message): + """If distributed is initialized, print only on rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + +def is_last_rank(): + return torch.distributed.get_rank() == ( + torch.distributed.get_world_size() - 1) + +def print_rank_last(message): + """If distributed is initialized, print only on last rank.""" + if torch.distributed.is_initialized(): + if is_last_rank(): + print(message, flush=True) + else: + print(message, flush=True)