Shawn001 committed on
Commit 23bd7af
1 Parent(s): bb5d7de

Upload 131 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +5 -0
  2. megatron/__init__.py +31 -0
  3. megatron/arguments.py +1018 -0
  4. megatron/checkpointing.py +675 -0
  5. megatron/data/Makefile +9 -0
  6. megatron/data/__init__.py +1 -0
  7. megatron/data/autoaugment.py +320 -0
  8. megatron/data/bert_dataset.py +234 -0
  9. megatron/data/biencoder_dataset_utils.py +208 -0
  10. megatron/data/blendable_dataset.py +68 -0
  11. megatron/data/data_samplers.py +199 -0
  12. megatron/data/dataset_utils.py +938 -0
  13. megatron/data/glm_dataset.py +377 -0
  14. megatron/data/gpt_dataset.py +430 -0
  15. megatron/data/helpers.cpp +717 -0
  16. megatron/data/helpers.cpython-38-x86_64-linux-gnu.so +0 -0
  17. megatron/data/helpers.cpython-39-x86_64-linux-gnu.so +0 -0
  18. megatron/data/ict_dataset.py +156 -0
  19. megatron/data/image_folder.py +302 -0
  20. megatron/data/indexed_dataset.py +576 -0
  21. megatron/data/orqa_wiki_dataset.py +205 -0
  22. megatron/data/realm_dataset_utils.py +198 -0
  23. megatron/data/realm_index.py +224 -0
  24. megatron/data/t5_dataset.py +270 -0
  25. megatron/data/test/test_indexed_dataset.py +125 -0
  26. megatron/data/test/test_preprocess_data.sh +10 -0
  27. megatron/data/vit_dataset.py +262 -0
  28. megatron/dist_signal_handler.py +81 -0
  29. megatron/fp16_deprecated/loss_scaler.py +39 -0
  30. megatron/fused_kernels/__init__.py +125 -0
  31. megatron/fused_kernels/build/.ninja_deps +0 -0
  32. megatron/fused_kernels/build/.ninja_log +99 -0
  33. megatron/fused_kernels/build/build.ninja +28 -0
  34. megatron/fused_kernels/build/fused_mix_prec_layer_norm_cuda.so +0 -0
  35. megatron/fused_kernels/build/layer_norm_cuda.o +0 -0
  36. megatron/fused_kernels/build/layer_norm_cuda_kernel.cuda.o +0 -0
  37. megatron/fused_kernels/build/scaled_masked_softmax.o +0 -0
  38. megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o +3 -0
  39. megatron/fused_kernels/build/scaled_masked_softmax_cuda.so +3 -0
  40. megatron/fused_kernels/build/scaled_softmax.o +0 -0
  41. megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o +3 -0
  42. megatron/fused_kernels/build/scaled_softmax_cuda.so +3 -0
  43. megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o +0 -0
  44. megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.cuda.o +0 -0
  45. megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so +3 -0
  46. megatron/fused_kernels/compat.h +31 -0
  47. megatron/fused_kernels/fused_weight_gradient_dense.cpp +47 -0
  48. megatron/fused_kernels/fused_weight_gradient_dense.cu +157 -0
  49. megatron/fused_kernels/layer_norm_cuda.cpp +201 -0
  50. megatron/fused_kernels/layer_norm_cuda_kernel.cu +832 -0
.gitattributes CHANGED
@@ -36,3 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  clue_data/csl/train.json filter=lfs diff=lfs merge=lfs -text
37
  clue_data/iflytek/train.json filter=lfs diff=lfs merge=lfs -text
38
  clue_data/ocnli/train.json filter=lfs diff=lfs merge=lfs -text
39
+ megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
40
+ megatron/fused_kernels/build/scaled_masked_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text
41
+ megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
42
+ megatron/fused_kernels/build/scaled_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text
43
+ megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so filter=lfs diff=lfs merge=lfs -text
megatron/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from .global_vars import get_args
18
+ from .global_vars import get_current_global_batch_size
19
+ from .global_vars import get_num_microbatches
20
+ from .global_vars import get_signal_handler
21
+ from .global_vars import update_num_microbatches
22
+ from .global_vars import get_tokenizer
23
+ from .global_vars import get_tensorboard_writer
24
+ from .global_vars import get_adlr_autoresume
25
+ from .global_vars import get_timers
26
+ from .global_vars import get_global_memory_buffer
27
+ from .initialize import initialize_megatron
28
+
29
+ from .utils import (print_rank_0,
30
+ is_last_rank,
31
+ print_rank_last)
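
For context, a minimal sketch (not part of this commit) of how a driver script typically consumes the helpers re-exported above; it assumes a working Megatron-LM install and the usual launcher environment variables (RANK, WORLD_SIZE).

# Hypothetical driver script -- illustrative only, not part of the upload.
from megatron import initialize_megatron, get_args, print_rank_0

def main():
    # Parses arguments, sets up distributed state, and fills the global vars
    # exposed by megatron/__init__.py above.
    initialize_megatron()
    args = get_args()
    # print_rank_0 logs once instead of once per rank.
    print_rank_0('micro batch size: {}'.format(args.micro_batch_size))

if __name__ == '__main__':
    main()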
megatron/arguments.py ADDED
@@ -0,0 +1,1018 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Megatron arguments."""
17
+
18
+ import argparse
19
+ import os
20
+
21
+ import torch
22
+
23
+ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
24
+ """Parse all arguments."""
25
+ parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
26
+ allow_abbrev=False)
27
+
28
+ # Standard arguments.
29
+ parser = _add_network_size_args(parser)
30
+ parser = _add_regularization_args(parser)
31
+ parser = _add_training_args(parser)
32
+ parser = _add_initialization_args(parser)
33
+ parser = _add_learning_rate_args(parser)
34
+ parser = _add_checkpointing_args(parser)
35
+ parser = _add_mixed_precision_args(parser)
36
+ parser = _add_distributed_args(parser)
37
+ parser = _add_validation_args(parser)
38
+ parser = _add_data_args(parser)
39
+ parser = _add_autoresume_args(parser)
40
+ parser = _add_biencoder_args(parser)
41
+ parser = _add_vision_args(parser)
42
+ parser = _add_logging_args(parser)
43
+ parser = _add_inference_args(parser)
44
+
45
+ # Custom arguments.
46
+ if extra_args_provider is not None:
47
+ parser = extra_args_provider(parser)
48
+
49
+ # Parse.
50
+ if ignore_unknown_args:
51
+ args, _ = parser.parse_known_args()
52
+ else:
53
+ args = parser.parse_args()
54
+
55
+ # Args from environment
56
+ args.rank = int(os.getenv('RANK', '0'))
57
+ args.world_size = int(os.getenv("WORLD_SIZE", '1'))
58
+
59
+ return args
60
+
61
+ def validate_args(args, defaults={}):
62
+ # Tensor model parallel size.
63
+ args.tensor_model_parallel_size = min(
64
+ args.tensor_model_parallel_size, args.world_size)
65
+ assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
66
+ ' ({}) is not divisible by tensor model parallel size ({})'.format(
67
+ args.world_size, args.tensor_model_parallel_size)
68
+ # Pipeline model parallel size.
69
+ args.pipeline_model_parallel_size = min(
70
+ args.pipeline_model_parallel_size,
71
+ (args.world_size // args.tensor_model_parallel_size))
72
+ args.transformer_pipeline_model_parallel_size = (
73
+ args.pipeline_model_parallel_size - 1
74
+ if args.standalone_embedding_stage else
75
+ args.pipeline_model_parallel_size
76
+ )
77
+ # Checks.
78
+ model_parallel_size = args.pipeline_model_parallel_size * \
79
+ args.tensor_model_parallel_size
80
+ assert args.world_size % model_parallel_size == 0, 'world size is not'\
81
+ ' divisible by tensor parallel size ({}) times pipeline parallel ' \
82
+ 'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
83
+ args.pipeline_model_parallel_size)
84
+ args.data_parallel_size = args.world_size // model_parallel_size
85
+ if args.rank == 0:
86
+ print('using world size: {}, data-parallel-size: {}, '
87
+ 'tensor-model-parallel size: {}, '
88
+ 'pipeline-model-parallel size: {} '.format(
89
+ args.world_size, args.data_parallel_size,
90
+ args.tensor_model_parallel_size,
91
+ args.pipeline_model_parallel_size), flush=True)
92
+ if args.pipeline_model_parallel_size > 1:
93
+ if args.pipeline_model_parallel_split_rank is not None:
94
+ assert args.pipeline_model_parallel_split_rank < \
95
+ args.pipeline_model_parallel_size, 'split rank needs'\
96
+ ' to be less than pipeline model parallel size ({})'.format(
97
+ args.pipeline_model_parallel_size)
98
+ if args.data_path:
99
+ # Dataset arguments
100
+ data_path = args.data_path
101
+ processed_data_path = []
102
+ for path in data_path:
103
+ files = os.listdir(path)
104
+ idx_files = [fn[:-4] for fn in files if fn.endswith('.idx')]
105
+ bin_files = [fn[:-4] for fn in files if fn.endswith('.bin')]
106
+ for idx_fn in idx_files:
107
+ if idx_fn in bin_files:
108
+ # add weight and data path
109
+ processed_data_path.append('1')
110
+ processed_data_path.append(os.path.join(path, idx_fn))
111
+ args.raw_data_path = data_path
112
+ args.data_path = processed_data_path
113
+
114
+
115
+ # Deprecated arguments
116
+ assert args.batch_size is None, '--batch-size argument is no longer ' \
117
+ 'valid, use --micro-batch-size instead'
118
+ del args.batch_size
119
+ assert args.warmup is None, '--warmup argument is no longer valid, use ' \
120
+ '--lr-warmup-fraction instead'
121
+ del args.warmup
122
+ assert args.model_parallel_size is None, '--model-parallel-size is no ' \
123
+ 'longer valid, use --tensor-model-parallel-size instead'
124
+ del args.model_parallel_size
125
+
126
+ if args.checkpoint_activations:
127
+ args.recompute_granularity = 'full'
128
+ args.recompute_method = 'uniform'
129
+ if args.rank == 0:
130
+ print('--checkpoint-activations is no longer valid, '
131
+ 'use --recompute-granularity and --recompute-method instead. '
132
+ 'Defaulting to recompute-granularity=full and recompute-method=uniform.')
133
+ del args.checkpoint_activations
134
+
135
+ if args.recompute_activations:
136
+ args.recompute_granularity = 'selective'
137
+ del args.recompute_activations
138
+
139
+ # Set input defaults.
140
+ for key in defaults:
141
+ # For default to be valid, it should not be provided in the
142
+ # arguments that are passed to the program. We check this by
143
+ # ensuring the arg is set to None.
144
+ if getattr(args, key) is not None:
145
+ if args.rank == 0:
146
+ print('WARNING: overriding default arguments for {key}:{v} \
147
+ with {key}:{v2}'.format(key=key, v=defaults[key],
148
+ v2=getattr(args, key)),
149
+ flush=True)
150
+ else:
151
+ setattr(args, key, defaults[key])
152
+
153
+ # Batch size.
154
+ assert args.micro_batch_size is not None
155
+ assert args.micro_batch_size > 0
156
+ if args.global_batch_size is None:
157
+ args.global_batch_size = args.micro_batch_size * args.data_parallel_size
158
+ if args.rank == 0:
159
+ print('setting global batch size to {}'.format(
160
+ args.global_batch_size), flush=True)
161
+ assert args.global_batch_size > 0
162
+ if args.num_layers_per_virtual_pipeline_stage is not None:
163
+ assert args.pipeline_model_parallel_size > 2, \
164
+ 'pipeline-model-parallel size should be greater than 2 with ' \
165
+ 'interleaved schedule'
166
+ assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
167
+ 'number of layers is not divisible by number of layers per virtual ' \
168
+ 'pipeline stage'
169
+ args.virtual_pipeline_model_parallel_size = \
170
+ (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
171
+ args.num_layers_per_virtual_pipeline_stage
172
+ else:
173
+ args.virtual_pipeline_model_parallel_size = None
174
+
175
+ # Parameters dtype.
176
+ args.params_dtype = torch.float
177
+ if args.fp16:
178
+ assert not args.bf16
179
+ args.params_dtype = torch.half
180
+ if args.bf16:
181
+ assert not args.fp16
182
+ args.params_dtype = torch.bfloat16
183
+ # bfloat16 requires gradient accumulation and all-reduce to
184
+ # be done in fp32.
185
+ if not args.accumulate_allreduce_grads_in_fp32:
186
+ args.accumulate_allreduce_grads_in_fp32 = True
187
+ if args.rank == 0:
188
+ print('accumulate and all-reduce gradients in fp32 for '
189
+ 'bfloat16 data type.', flush=True)
190
+
191
+ if args.rank == 0:
192
+ print('using {} for parameters ...'.format(args.params_dtype),
193
+ flush=True)
194
+
195
+ # If we do accumulation and all-reduces in fp32, we need to have local DDP
196
+ # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
197
+ if args.accumulate_allreduce_grads_in_fp32:
198
+ assert args.DDP_impl == 'local'
199
+ assert args.use_contiguous_buffers_in_local_ddp
200
+ else:
201
+ if args.gradient_accumulation_fusion:
202
+ args.gradient_accumulation_fusion = False
203
+ if args.rank == 0:
204
+ print('Gradient accumulation fusion to linear layer weight '
205
+ 'gradient computation is supported only with fp32 '
206
+ 'gradient accumulation. Setting gradient_accumulation_fusion '
207
+ 'to False', flush=True)
208
+
209
+ # If we use the distributed optimizer, we need to have local DDP
210
+ # and we should make sure use-contiguous-buffers-in-local-ddp is on.
211
+ if args.use_distributed_optimizer:
212
+ assert args.DDP_impl == 'local'
213
+ assert args.use_contiguous_buffers_in_local_ddp
214
+
215
+ # For torch DDP, we do not use contiguous buffer
216
+ if args.DDP_impl == 'torch':
217
+ args.use_contiguous_buffers_in_local_ddp = False
218
+
219
+ if args.dataloader_type is None:
220
+ args.dataloader_type = 'single'
221
+
222
+ # Consumed tokens.
223
+ args.consumed_train_samples = 0
224
+ args.consumed_valid_samples = 0
225
+
226
+ # Iteration-based training.
227
+ if args.train_iters:
228
+ # If we use iteration-based training, make sure the
229
+ # sample-based options are off.
230
+ assert args.train_samples is None, \
231
+ 'expected iteration-based training'
232
+ assert args.lr_decay_samples is None, \
233
+ 'expected iteration-based learning rate decay'
234
+ assert args.lr_warmup_samples == 0, \
235
+ 'expected iteration-based learning rate warmup'
236
+ assert args.rampup_batch_size is None, \
237
+ 'expected no batch-size rampup for iteration-based training'
238
+ if args.lr_warmup_fraction is not None:
239
+ assert args.lr_warmup_iters == 0, \
240
+ 'can only specify one of lr-warmup-fraction and lr-warmup-iters'
241
+
242
+ # Sample-based training.
243
+ if args.train_samples:
244
+ # If we use sample-based training, make sure the
245
+ # iteration-based options are off.
246
+ assert args.train_iters is None, \
247
+ 'expected sample-based training'
248
+ assert args.lr_decay_iters is None, \
249
+ 'expected sample-based learning rate decay'
250
+ assert args.lr_warmup_iters == 0, \
251
+ 'expected sample-based learning rate warmup'
252
+ if args.lr_warmup_fraction is not None:
253
+ assert args.lr_warmup_samples == 0, \
254
+ 'can only specify one of lr-warmup-fraction ' \
255
+ 'and lr-warmup-samples'
256
+
257
+ # Check required arguments.
258
+ required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
259
+ 'max_position_embeddings']
260
+ for req_arg in required_args:
261
+ _check_arg_is_not_none(args, req_arg)
262
+
263
+ # Checks.
264
+ if args.ffn_hidden_size is None:
265
+ args.ffn_hidden_size = 4 * args.hidden_size
266
+
267
+ if args.kv_channels is None:
268
+ assert args.hidden_size % args.num_attention_heads == 0
269
+ args.kv_channels = args.hidden_size // args.num_attention_heads
270
+
271
+ if args.seq_length is not None:
272
+ assert args.encoder_seq_length is None
273
+ args.encoder_seq_length = args.seq_length
274
+ else:
275
+ assert args.encoder_seq_length is not None
276
+ args.seq_length = args.encoder_seq_length
277
+
278
+ if args.seq_length is not None:
279
+ assert args.max_position_embeddings >= args.seq_length
280
+ if args.decoder_seq_length is not None:
281
+ assert args.max_position_embeddings >= args.decoder_seq_length
282
+ if args.lr is not None:
283
+ assert args.min_lr <= args.lr
284
+ if args.save is not None:
285
+ assert args.save_interval is not None
286
+ # Mixed precision checks.
287
+ if args.fp16_lm_cross_entropy:
288
+ assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
289
+ if args.fp32_residual_connection:
290
+ assert args.fp16 or args.bf16, \
291
+ 'residual connection in fp32 only supported when using fp16 or bf16.'
292
+
293
+ if args.weight_decay_incr_style == 'constant':
294
+ assert args.start_weight_decay is None
295
+ assert args.end_weight_decay is None
296
+ args.start_weight_decay = args.weight_decay
297
+ args.end_weight_decay = args.weight_decay
298
+ else:
299
+ assert args.start_weight_decay is not None
300
+ assert args.end_weight_decay is not None
301
+
302
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
303
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
304
+ # Persistent fused layer norm.
305
+ if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
306
+ args.no_persist_layer_norm = True
307
+ if args.rank == 0:
308
+ print('Persistent fused layer norm kernel is supported from '
309
+ 'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
310
+ 'Defaulting to no_persist_layer_norm=True')
311
+
312
+ # Activation recomputing.
313
+ if args.distribute_saved_activations:
314
+ assert args.tensor_model_parallel_size > 1, 'can distribute ' \
315
+ 'recomputed activations only across tensor model ' \
316
+ 'parallel groups'
317
+ assert args.recompute_granularity == 'full', \
318
+ 'distributed recompute activations is only '\
319
+ 'applicable to full recompute granularity'
320
+ assert args.recompute_method is not None, \
321
+ 'for distributed recompute activations to work you '\
322
+ 'need to use a recompute method '
323
+ assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
324
+ 'distributed recompute activations are supported for pytorch ' \
325
+ 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
326
+ 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
327
+
328
+ if args.recompute_granularity == 'selective':
329
+ assert args.recompute_method is None, \
330
+ 'recompute method is not yet supported for ' \
331
+ 'selective recomputing granularity'
332
+
333
+ # disable sequence parallelism when tp=1
334
+ # to avoid change in numerics when
335
+ # sequence_parallelism is enabled.
336
+ if args.tensor_model_parallel_size == 1:
337
+ args.sequence_parallel = False
338
+
339
+ # disable async_tensor_model_parallel_allreduce when
340
+ # model parallel memory optimization is enabled
341
+ if args.sequence_parallel:
342
+ args.async_tensor_model_parallel_allreduce = False
343
+
344
+ _print_args(args)
345
+ return args
346
+
347
+
348
+ def _print_args(args):
349
+ """Print arguments."""
350
+ if args.rank == 0:
351
+ print('------------------------ arguments ------------------------',
352
+ flush=True)
353
+ str_list = []
354
+ for arg in vars(args):
355
+ dots = '.' * (48 - len(arg))
356
+ str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
357
+ for arg in sorted(str_list, key=lambda x: x.lower()):
358
+ print(arg, flush=True)
359
+ print('-------------------- end of arguments ---------------------',
360
+ flush=True)
361
+
362
+
363
+ def _check_arg_is_not_none(args, arg):
364
+ assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
365
+
366
+
367
+ def _add_inference_args(parser):
368
+ group = parser.add_argument_group(title='inference')
369
+
370
+ group.add_argument('--inference-batch-times-seqlen-threshold',
371
+ type=int, default=512,
372
+ help='During inference, if batch-size times '
373
+ 'sequence-length is smaller than this threshold '
374
+ 'then we will not use pipelining, otherwise we will.')
375
+
376
+ return parser
377
+
378
+
379
+ def _add_network_size_args(parser):
380
+ group = parser.add_argument_group(title='network size')
381
+
382
+ group.add_argument('--num-layers', type=int, default=None,
383
+ help='Number of transformer layers.')
384
+ group.add_argument('--num-layers-decoder', type=int, default=None,
385
+ help='Number of transformer layers decoder.')
386
+ group.add_argument('--hidden-size', type=int, default=None,
387
+ help='Transformer hidden size.')
388
+ group.add_argument('--ffn-hidden-size', type=int, default=None,
389
+ help='Transformer Feed-Forward Network hidden size. '
390
+ 'This is set to 4*hidden-size if not provided')
391
+ group.add_argument('--num-attention-heads', type=int, default=None,
392
+ help='Number of transformer attention heads.')
393
+ group.add_argument('--kv-channels', type=int, default=None,
394
+ help='Projection weights dimension in multi-head '
395
+ 'attention. This is set to '
396
+ ' args.hidden_size // args.num_attention_heads '
397
+ 'if not provided.')
398
+ group.add_argument('--max-position-embeddings', type=int, default=None,
399
+ help='Maximum number of position embeddings to use. '
400
+ 'This is the size of position embedding.')
401
+ group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
402
+ help='Pad the vocab size to be divisible by this value.'
403
+ 'This is added for computational efficiency reasons.')
404
+ group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
405
+ help='Layer norm epsilon.')
406
+ group.add_argument('--apply-residual-connection-post-layernorm',
407
+ action='store_true',
408
+ help='If set, use original BERT residual connection '
409
+ 'ordering.')
410
+ group.add_argument('--openai-gelu', action='store_true',
411
+ help='Use OpenAIs GeLU implementation. This option '
412
+ 'should not be used unless for backward compatibility '
413
+ 'reasons.')
414
+ group.add_argument('--onnx-safe', type=bool, required=False,
415
+ help='Use workarounds for known problems with '
416
+ 'Torch ONNX exporter')
417
+ group.add_argument('--bert-no-binary-head', action='store_false',
418
+ help='Disable BERT binary head.',
419
+ dest='bert_binary_head')
420
+ group.add_argument('--num-experts', type=int, default=None,
421
+ help='Number of Experts in Switch Transformer (None means no Switch)')
422
+ return parser
423
+
424
+
425
+ def _add_logging_args(parser):
426
+ group = parser.add_argument_group(title='logging')
427
+
428
+ group.add_argument('--log-params-norm', action='store_true',
429
+ help='If set, calculate and log parameters norm.')
430
+ group.add_argument('--log-num-zeros-in-grad', action='store_true',
431
+ help='If set, calculate and log the number of zeros in gradient.')
432
+ group.add_argument('--tensorboard-log-interval', type=int, default=1,
433
+ help='Report to tensorboard interval.')
434
+ group.add_argument('--tensorboard-queue-size', type=int, default=1000,
435
+ help='Size of the tensorboard queue for pending events '
436
+ 'and summaries before one of the ‘add’ calls forces a '
437
+ 'flush to disk.')
438
+ group.add_argument('--log-timers-to-tensorboard', action='store_true',
439
+ help='If set, write timers to tensorboard.')
440
+ group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
441
+ help='If set, write batch-size to tensorboard.')
442
+ group.add_argument('--no-log-learnig-rate-to-tensorboard',
443
+ action='store_false',
444
+ help='Disable learning rate logging to tensorboard.',
445
+ dest='log_learning_rate_to_tensorboard')
446
+ group.add_argument('--no-log-loss-scale-to-tensorboard',
447
+ action='store_false',
448
+ help='Disable loss-scale logging to tensorboard.',
449
+ dest='log_loss_scale_to_tensorboard')
450
+ group.add_argument('--log-validation-ppl-to-tensorboard',
451
+ action='store_true',
452
+ help='If set, write validation perplexity to '
453
+ 'tensorboard.')
454
+ group.add_argument('--log-memory-to-tensorboard',
455
+ action='store_true',
456
+ help='Enable memory logging to tensorboard.')
457
+ group.add_argument('--log-world-size-to-tensorboard',
458
+ action='store_true',
459
+ help='Enable world size logging to tensorboard.')
460
+
461
+ return parser
462
+
463
+
464
+ def _add_regularization_args(parser):
465
+ group = parser.add_argument_group(title='regularization')
466
+
467
+ group.add_argument('--attention-dropout', type=float, default=0.1,
468
+ help='Post attention dropout probability.')
469
+ group.add_argument('--hidden-dropout', type=float, default=0.1,
470
+ help='Dropout probability for hidden state transformer.')
471
+ group.add_argument('--weight-decay', type=float, default=0.01,
472
+ help='Weight decay coefficient for L2 regularization.')
473
+ group.add_argument('--start-weight-decay', type=float,
474
+ help='Initial weight decay coefficient for L2 regularization.')
475
+ group.add_argument('--end-weight-decay', type=float,
476
+ help='End of run weight decay coefficient for L2 regularization.')
477
+ group.add_argument('--weight-decay-incr-style', type=str, default='constant',
478
+ choices=['constant', 'linear', 'cosine'],
479
+ help='Weight decay increment function.')
480
+ group.add_argument('--clip-grad', type=float, default=1.0,
481
+ help='Gradient clipping based on global L2 norm.')
482
+ group.add_argument('--adam-beta1', type=float, default=0.9,
483
+ help='First coefficient for computing running averages '
484
+ 'of gradient and its square')
485
+ group.add_argument('--adam-beta2', type=float, default=0.999,
486
+ help='Second coefficient for computing running averages '
487
+ 'of gradient and its square')
488
+ group.add_argument('--adam-eps', type=float, default=1e-08,
489
+ help='Term added to the denominator to improve '
490
+ 'numerical stability')
491
+ group.add_argument('--sgd-momentum', type=float, default=0.9,
492
+ help='Momentum factor for sgd')
493
+
494
+ return parser
495
+
496
+
497
+ def _add_training_args(parser):
498
+ group = parser.add_argument_group(title='training')
499
+
500
+ group.add_argument('--micro-batch-size', type=int, default=None,
501
+ help='Batch size per model instance (local batch size). '
502
+ 'Global batch size is local batch size times data '
503
+ 'parallel size times number of micro batches.')
504
+ group.add_argument('--batch-size', type=int, default=None,
505
+ help='Old batch size parameter, do not use. '
506
+ 'Use --micro-batch-size instead')
507
+ group.add_argument('--global-batch-size', type=int, default=None,
508
+ help='Training batch size. If set, it should be a '
509
+ 'multiple of micro-batch-size times data-parallel-size. '
510
+ 'If this value is None, then '
511
+ 'use micro-batch-size * data-parallel-size as the '
512
+ 'global batch size. This choice will result in 1 for '
513
+ 'number of micro-batches.')
514
+ group.add_argument('--rampup-batch-size', nargs='*', default=None,
515
+ help='Batch size ramp up with the following values:'
516
+ ' --rampup-batch-size <start batch size> '
517
+ ' <batch size increment> '
518
+ ' <ramp-up samples> '
519
+ 'For example:'
520
+ ' --rampup-batch-size 16 8 300000 \ '
521
+ ' --global-batch-size 1024 '
522
+ 'will start with global batch size 16 and over '
523
+ ' (1024 - 16) / 8 = 126 intervals will increase '
524
+ 'the batch size linearly to 1024. In each interval '
525
+ 'we will use approximately 300000 / 126 = 2380 samples.')
526
+ group.add_argument('--recompute-activations', action='store_true',
527
+ help='recompute activation to allow for training '
528
+ 'with larger models, sequences, and batch sizes.')
529
+ group.add_argument('--recompute-granularity', type=str, default=None,
530
+ choices=['full', 'selective'],
531
+ help='Checkpoint activations to allow for training '
532
+ 'with larger models, sequences, and batch sizes. '
533
+ 'It is supported at two granularities 1) full: '
534
+ 'whole transformer layer is recomputed, '
535
+ '2) selective: core attention part of the transformer '
536
+ 'layer is recomputed.')
537
+ group.add_argument('--distribute-saved-activations',
538
+ action='store_true',
539
+ help='If set, distribute recomputed activations '
540
+ 'across model parallel group.')
541
+ group.add_argument('--recompute-method', type=str, default=None,
542
+ choices=['uniform', 'block'],
543
+ help='1) uniform: uniformly divide the total number of '
544
+ 'Transformer layers and recompute the input activation of '
545
+ 'each divided chunk at specified granularity, '
546
+ '2) recompute the input activations of only a set number of '
547
+ 'individual Transformer layers per pipeline stage and do the '
548
+ 'rest without any recomputing at specified granularity'
549
+ 'default) do not apply activations recompute to any layers')
550
+ group.add_argument('--recompute-num-layers', type=int, default=1,
551
+ help='1) uniform: the number of Transformer layers in each '
552
+ 'uniformly divided recompute unit, '
553
+ '2) block: the number of individual Transformer layers '
554
+ 'to recompute within each pipeline stage.')
555
+
556
+ # deprecated
557
+ group.add_argument('--checkpoint-activations', action='store_true',
558
+ help='Checkpoint activation to allow for training '
559
+ 'with larger models, sequences, and batch sizes.')
560
+ group.add_argument('--train-iters', type=int, default=None,
561
+ help='Total number of iterations to train over all '
562
+ 'training runs. Note that either train-iters or '
563
+ 'train-samples should be provided.')
564
+ group.add_argument('--train-samples', type=int, default=None,
565
+ help='Total number of samples to train over all '
566
+ 'training runs. Note that either train-iters or '
567
+ 'train-samples should be provided.')
568
+ group.add_argument('--log-interval', type=int, default=100,
569
+ help='Report loss and timing interval.')
570
+ group.add_argument('--exit-interval', type=int, default=None,
571
+ help='Exit the program after the iteration is divisible '
572
+ 'by this value.')
573
+ group.add_argument('--exit-duration-in-mins', type=int, default=None,
574
+ help='Exit the program after this many minutes.')
575
+ group.add_argument('--exit-signal-handler', action='store_true',
576
+ help='Dynamically save the checkpoint and shutdown the '
577
+ 'training if SIGTERM is received')
578
+ group.add_argument('--tensorboard-dir', type=str, default=None,
579
+ help='Write TensorBoard logs to this directory.')
580
+ group.add_argument('--no-masked-softmax-fusion',
581
+ action='store_false',
582
+ help='Disable fusion of query_key_value scaling, '
583
+ 'masking, and softmax.',
584
+ dest='masked_softmax_fusion')
585
+ group.add_argument('--no-bias-gelu-fusion', action='store_false',
586
+ help='Disable bias and gelu fusion.',
587
+ dest='bias_gelu_fusion')
588
+ group.add_argument('--no-bias-dropout-fusion', action='store_false',
589
+ help='Disable bias and dropout fusion.',
590
+ dest='bias_dropout_fusion')
591
+ group.add_argument('--optimizer', type=str, default='adam',
592
+ choices=['adam', 'sgd'],
593
+ help='Optimizer function')
594
+ group.add_argument('--dataloader-type', type=str, default=None,
595
+ choices=['single', 'cyclic'],
596
+ help='Single pass vs multiple pass data loader')
597
+ group.add_argument('--no-async-tensor-model-parallel-allreduce',
598
+ action='store_false',
599
+ help='Disable asynchronous execution of '
600
+ 'tensor-model-parallel all-reduce with weight '
601
+ 'gradient computation of a column-linear layer.',
602
+ dest='async_tensor_model_parallel_allreduce')
603
+ group.add_argument('--no-persist-layer-norm', action='store_true',
604
+ help='Disable using persistent fused layer norm kernel. '
605
+ 'This kernel supports only a set of hidden sizes. Please '
606
+ 'check persist_ln_hidden_sizes if your hidden '
607
+ 'size is supported.')
608
+ group.add_argument('--sequence-parallel', action='store_true',
609
+ help='Enable sequence parallel optimization.')
610
+ group.add_argument('--no-gradient-accumulation-fusion',
611
+ action='store_false',
612
+ help='Disable fusing gradient accumulation to weight '
613
+ 'gradient computation of linear layers',
614
+ dest='gradient_accumulation_fusion')
615
+ return parser
616
+
617
+
618
+ def _add_initialization_args(parser):
619
+ group = parser.add_argument_group(title='initialization')
620
+
621
+ group.add_argument('--seed', type=int, default=1234,
622
+ help='Random seed used for python, numpy, '
623
+ 'pytorch, and cuda.')
624
+ group.add_argument('--data-parallel-random-init', action='store_true',
625
+ help='Enable random initialization of params '
626
+ 'across data parallel ranks')
627
+ group.add_argument('--init-method-std', type=float, default=0.02,
628
+ help='Standard deviation of the zero mean normal '
629
+ 'distribution used for weight initialization.')
630
+ group.add_argument('--init-method-xavier-uniform', action='store_true',
631
+ help='Enable Xavier uniform parameter initialization')
632
+
633
+ return parser
634
+
635
+
636
+ def _add_learning_rate_args(parser):
637
+ group = parser.add_argument_group(title='learning rate')
638
+
639
+ group.add_argument('--lr', type=float, default=None,
640
+ help='Initial learning rate. Depending on decay style '
641
+ 'and initial warmup, the learning rate at each '
642
+ 'iteration would be different.')
643
+ group.add_argument('--lr-decay-style', type=str, default='linear',
644
+ choices=['constant', 'linear', 'cosine'],
645
+ help='Learning rate decay function.')
646
+ group.add_argument('--lr-decay-iters', type=int, default=None,
647
+ help='number of iterations to decay learning rate over,'
648
+ ' If None defaults to `--train-iters`')
649
+ group.add_argument('--lr-decay-samples', type=int, default=None,
650
+ help='number of samples to decay learning rate over,'
651
+ ' If None defaults to `--train-samples`')
652
+ group.add_argument('--lr-warmup-fraction', type=float, default=None,
653
+ help='fraction of lr-warmup-(iters/samples) to use '
654
+ 'for warmup (as a float)')
655
+ group.add_argument('--lr-warmup-iters', type=int, default=0,
656
+ help='number of iterations to linearly warmup '
657
+ 'learning rate over.')
658
+ group.add_argument('--lr-warmup-samples', type=int, default=0,
659
+ help='number of samples to linearly warmup '
660
+ 'learning rate over.')
661
+ group.add_argument('--warmup', type=int, default=None,
662
+ help='Old lr warmup argument, do not use. Use one of the '
663
+ '--lr-warmup-* arguments above')
664
+ group.add_argument('--min-lr', type=float, default=0.0,
665
+ help='Minimum value for learning rate. The scheduler '
666
+ 'clips values below this threshold.')
667
+ group.add_argument('--override-opt_param-scheduler', action='store_true',
668
+ help='Reset the values of the scheduler (learning rate,'
669
+ 'warmup iterations, minimum learning rate, maximum '
670
+ 'number of iterations, and decay style from input '
671
+ 'arguments and ignore values from checkpoints. Note'
672
+ 'that all the above values will be reset.')
673
+ group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
674
+ help='Use checkpoint to set the values of the scheduler '
675
+ '(learning rate, warmup iterations, minimum learning '
676
+ 'rate, maximum number of iterations, and decay style '
677
+ 'from checkpoint and ignore input arguments.')
678
+
679
+ return parser
680
+
681
+
682
+ def _add_checkpointing_args(parser):
683
+ group = parser.add_argument_group(title='checkpointing')
684
+
685
+ group.add_argument('--save', type=str, default=None,
686
+ help='Output directory to save checkpoints to.')
687
+ group.add_argument('--save-interval', type=int, default=None,
688
+ help='Number of iterations between checkpoint saves.')
689
+ group.add_argument('--no-save-optim', action='store_true', default=None,
690
+ help='Do not save current optimizer.')
691
+ group.add_argument('--no-save-rng', action='store_true', default=None,
692
+ help='Do not save current rng state.')
693
+ group.add_argument('--load', type=str, default=None,
694
+ help='Directory containing a model checkpoint.')
695
+ group.add_argument('--no-load-optim', action='store_true', default=None,
696
+ help='Do not load optimizer when loading checkpoint.')
697
+ group.add_argument('--no-load-rng', action='store_true', default=None,
698
+ help='Do not load rng state when loading checkpoint.')
699
+ group.add_argument('--finetune', action='store_true',
700
+ help='Load model for finetuning. Do not load optimizer '
701
+ 'or rng state from checkpoint and set iteration to 0. '
702
+ 'Assumed when loading a release checkpoint.')
703
+ group.add_argument('--no-initialization', action='store_false',
704
+ help='Do not perform initialization when building model, '
705
+ 'can reduce startup time when definitely loading from a '
706
+ 'checkpoint',
707
+ dest='perform_initialization')
708
+ group.add_argument('--use-checkpoint-args', action='store_true',
709
+ help='Override any command line arguments with arguments '
710
+ 'from the checkpoint')
711
+
712
+ return parser
713
+
714
+
715
+ def _add_mixed_precision_args(parser):
716
+ group = parser.add_argument_group(title='mixed precision')
717
+
718
+ group.add_argument('--fp16', action='store_true',
719
+ help='Run model in fp16 mode.')
720
+ group.add_argument('--bf16', action='store_true',
721
+ help='Run model in bfloat16 mode.')
722
+ group.add_argument('--loss-scale', type=float, default=None,
723
+ help='Static loss scaling, positive power of 2 '
724
+ 'values can improve fp16 convergence. If None, dynamic'
725
+ 'loss scaling is used.')
726
+ group.add_argument('--initial-loss-scale', type=float, default=2**32,
727
+ help='Initial loss-scale for dynamic loss scaling.')
728
+ group.add_argument('--min-loss-scale', type=float, default=1.0,
729
+ help='Minimum loss scale for dynamic loss scale.')
730
+ group.add_argument('--loss-scale-window', type=float, default=1000,
731
+ help='Window over which to raise/lower dynamic scale.')
732
+ group.add_argument('--hysteresis', type=int, default=2,
733
+ help='hysteresis for dynamic loss scaling')
734
+ group.add_argument('--fp32-residual-connection', action='store_true',
735
+ help='Move residual connections to fp32.')
736
+ group.add_argument('--no-query-key-layer-scaling', action='store_false',
737
+ help='Do not scale Q * K^T by 1 / layer-number.',
738
+ dest='apply_query_key_layer_scaling')
739
+ group.add_argument('--attention-softmax-in-fp32', action='store_true',
740
+ help='Run attention masking and softmax in fp32. '
741
+ 'This flag is ignored unless '
742
+ '--no-query-key-layer-scaling is specified.')
743
+ group.add_argument('--accumulate-allreduce-grads-in-fp32',
744
+ action='store_true',
745
+ help='Gradient accumulation and all-reduce in fp32.')
746
+ group.add_argument('--fp16-lm-cross-entropy', action='store_true',
747
+ help='Move the cross entropy unreduced loss calculation'
748
+ 'for lm head to fp16.')
749
+
750
+ return parser
751
+
752
+
753
+ def _add_distributed_args(parser):
754
+ group = parser.add_argument_group(title='distributed')
755
+
756
+ group.add_argument('--tensor-model-parallel-size', type=int, default=1,
757
+ help='Degree of tensor model parallelism.')
758
+ group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
759
+ help='Degree of pipeline model parallelism.')
760
+ group.add_argument('--pipeline-model-parallel-split-rank',
761
+ type=int, default=None,
762
+ help='Rank where encoder and decoder should be split.')
763
+ group.add_argument('--model-parallel-size', type=int, default=None,
764
+ help='Old model parallel argument, do not use. Use '
765
+ '--tensor-model-parallel-size instead.')
766
+ group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
767
+ help='Number of layers per virtual pipeline stage')
768
+ group.add_argument('--distributed-backend', default='nccl',
769
+ choices=['nccl', 'gloo'],
770
+ help='Which backend to use for distributed training.')
771
+ group.add_argument('--DDP-impl', default='local',
772
+ choices=['local', 'torch'],
773
+ help='which DistributedDataParallel implementation '
774
+ 'to use.')
775
+ group.add_argument('--no-contiguous-buffers-in-local-ddp',
776
+ action='store_false', help='If set, dont use '
777
+ 'contiguous buffer in local DDP.',
778
+ dest='use_contiguous_buffers_in_local_ddp')
779
+ group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
780
+ help='Use scatter/gather to optimize communication of tensors in pipeline',
781
+ dest='scatter_gather_tensors_in_pipeline')
782
+ group.add_argument('--local_rank', type=int, default=None,
783
+ help='local rank passed from distributed launcher.')
784
+ group.add_argument('--lazy-mpu-init', type=bool, required=False,
785
+ help='If set to True, initialize_megatron() '
786
+ 'skips DDP initialization and returns function to '
787
+ 'complete it instead.Also turns on '
788
+ '--use-cpu-initialization flag. This is for '
789
+ 'external DDP manager.' )
790
+ group.add_argument('--use-cpu-initialization', action='store_true',
791
+ default=None, help='If set, affine parallel weights '
792
+ 'initialization uses CPU' )
793
+ group.add_argument('--empty-unused-memory-level', default=0, type=int,
794
+ choices=[0, 1, 2],
795
+ help='Call torch.cuda.empty_cache() each iteration '
796
+ '(training and eval), to reduce fragmentation.'
797
+ '0=off, 1=moderate, 2=aggressive.')
798
+ group.add_argument('--standalone-embedding-stage', action='store_true',
799
+ default=False, help='If set, *input* embedding layer '
800
+ 'is placed on its own pipeline stage, without any '
801
+ 'transformer layers. (For T5, this flag currently only '
802
+ 'affects the encoder embedding.)')
803
+ group.add_argument('--use-distributed-optimizer', action='store_true',
804
+ help='Use distributed optimizer.')
805
+
806
+ return parser
807
+
808
+
809
+ def _add_validation_args(parser):
810
+ group = parser.add_argument_group(title='validation')
811
+
812
+ group.add_argument('--eval-iters', type=int, default=100,
813
+ help='Number of iterations to run for evaluation on '
814
+ 'the validation/test set.')
815
+ group.add_argument('--eval-interval', type=int, default=1000,
816
+ help='Interval between running evaluation on '
817
+ 'validation set.')
818
+
819
+ return parser
820
+
821
+
822
+ def _add_data_args(parser):
823
+ group = parser.add_argument_group(title='data and dataloader')
824
+
825
+ group.add_argument('--data-path', nargs='*', default=None,
826
+ help='Path to the training dataset. Accepted format:'
827
+ '1) a single data path, 2) multiple datasets in the'
828
+ 'form: dataset1-weight dataset1-path dataset2-weight '
829
+ 'dataset2-path ...')
830
+ group.add_argument('--split', type=str, default='969, 30, 1',
831
+ help='Comma-separated list of proportions for training,'
832
+ ' validation, and test split. For example the split '
833
+ '`90,5,5` will use 90%% of data for training, 5%% for '
834
+ 'validation and 5%% for test.')
835
+ group.add_argument('--vocab-file', type=str, default=None,
836
+ help='Path to the vocab file.')
837
+ group.add_argument('--merge-file', type=str, default=None,
838
+ help='Path to the BPE merge file.')
839
+ group.add_argument('--vocab-extra-ids', type=int, default=0,
840
+ help='Number of additional vocabulary tokens. '
841
+ 'They are used for span masking in the T5 model')
842
+ group.add_argument('--seq-length', type=int, default=None,
843
+ help='Maximum sequence length to process.')
844
+ group.add_argument('--encoder-seq-length', type=int, default=None,
845
+ help='Maximum encoder sequence length to process.'
846
+ 'This should be exclusive of --seq-length')
847
+ group.add_argument('--decoder-seq-length', type=int, default=None,
848
+ help="Maximum decoder sequence length to process.")
849
+ group.add_argument('--retriever-seq-length', type=int, default=256,
850
+ help='Maximum sequence length for the biencoder model '
851
+ ' for retriever')
852
+ group.add_argument('--sample-rate', type=float, default=1.0,
853
+ help='sample rate for training data. Supposed to be 0 '
854
+ ' < sample_rate < 1')
855
+ group.add_argument('--mask-prob', type=float, default=0.15,
856
+ help='Probability of replacing a token with mask.')
857
+ group.add_argument('--short-seq-prob', type=float, default=0.1,
858
+ help='Probability of producing a short sequence.')
859
+ group.add_argument('--mmap-warmup', action='store_true',
860
+ help='Warm up mmap files.')
861
+ group.add_argument('--num-workers', type=int, default=2,
862
+ help="Dataloader number of workers.")
863
+ group.add_argument('--tokenizer-type', type=str,
864
+ default=None,
865
+ choices=['BertWordPieceLowerCase',
866
+ 'BertWordPieceCase',
867
+ 'GPT2BPETokenizer'],
868
+ help='What type of tokenizer to use.')
869
+ group.add_argument('--data-impl', type=str, default='infer',
870
+ choices=['lazy', 'cached', 'mmap', 'infer'],
871
+ help='Implementation of indexed datasets.')
872
+ group.add_argument('--reset-position-ids', action='store_true',
873
+ help='Reset position ids after end-of-document token.')
874
+ group.add_argument('--reset-attention-mask', action='store_true',
875
+ help='Reset self attention mask after '
876
+ 'end-of-document token.')
877
+ group.add_argument('--eod-mask-loss', action='store_true',
878
+ help='Mask loss for the end of document tokens.')
879
+
880
+ return parser
881
+
882
+
883
+ def _add_autoresume_args(parser):
884
+ group = parser.add_argument_group(title='autoresume')
885
+
886
+ group.add_argument('--adlr-autoresume', action='store_true',
887
+ help='Enable autoresume on adlr cluster.')
888
+ group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
889
+ help='Interval over which to check for the autoresume '
890
+ 'termination signal')
891
+
892
+ return parser
893
+
894
+
895
+ def _add_biencoder_args(parser):
896
+ group = parser.add_argument_group(title='biencoder')
897
+
898
+ # network size
899
+ group.add_argument('--ict-head-size', type=int, default=None,
900
+ help='Size of block embeddings to be used in ICT and '
901
+ 'REALM (paper default: 128)')
902
+ group.add_argument('--biencoder-projection-dim', type=int, default=0,
903
+ help='Size of projection head used in biencoder (paper'
904
+ ' default: 128)')
905
+ group.add_argument('--biencoder-shared-query-context-model', action='store_true',
906
+ help='Whether to share the parameters of the query '
907
+ 'and context models or not')
908
+
909
+ # checkpointing
910
+ group.add_argument('--ict-load', type=str, default=None,
911
+ help='Directory containing an ICTBertModel checkpoint')
912
+ group.add_argument('--bert-load', type=str, default=None,
913
+ help='Directory containing an BertModel checkpoint '
914
+ '(needed to start ICT and REALM)')
915
+
916
+ # data
917
+ group.add_argument('--titles-data-path', type=str, default=None,
918
+ help='Path to titles dataset used for ICT')
919
+ group.add_argument('--query-in-block-prob', type=float, default=0.1,
920
+ help='Probability of keeping query in block for '
921
+ 'ICT dataset')
922
+ group.add_argument('--use-one-sent-docs', action='store_true',
923
+ help='Whether to use one sentence documents in ICT')
924
+ group.add_argument('--evidence-data-path', type=str, default=None,
925
+ help='Path to Wikipedia Evidence from DPR paper')
926
+
927
+ # training
928
+ group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
929
+ default=[], help="Which top-k accuracies to report "
930
+ "(e.g. '1 5 20')")
931
+ group.add_argument('--retriever-score-scaling', action='store_true',
932
+ help='Whether to scale retriever scores by inverse '
933
+ 'square root of hidden size')
934
+
935
+ # faiss index
936
+ group.add_argument('--block-data-path', type=str, default=None,
937
+ help='Where to save/load BlockData to/from')
938
+ group.add_argument('--embedding-path', type=str, default=None,
939
+ help='Where to save/load Open-Retrieval Embedding'
940
+ ' data to/from')
941
+
942
+ # indexer
943
+ group.add_argument('--indexer-batch-size', type=int, default=128,
944
+ help='How large of batches to use when doing indexing '
945
+ 'jobs')
946
+ group.add_argument('--indexer-log-interval', type=int, default=1000,
947
+ help='After how many batches should the indexer '
948
+ 'report progress')
949
+ return parser
950
+
951
+
952
+ def _add_vision_args(parser):
953
+ group = parser.add_argument_group(title="vision")
954
+
955
+ # general vision arguments
956
+ group.add_argument('--num-classes', type=int, default=1000,
957
+ help='number of classes in vision classification task')
958
+ group.add_argument('--img-h', type=int, default=224,
959
+ help='Image height for vision classification task')
960
+ group.add_argument('--img-w', type=int, default=224,
961
+ help='Image width for vision classification task')
962
+ group.add_argument('--num-channels', type=int, default=3,
963
+ help='Number of channels in input image data')
964
+ group.add_argument('--patch-dim', type=int, default=16,
965
+ help='patch dimension')
966
+ group.add_argument('--classes-fraction', type=float, default=1.0,
967
+ help='training with fraction of classes.')
968
+ group.add_argument('--data-per-class-fraction', type=float, default=1.0,
969
+ help='training with fraction of data per class.')
970
+ group.add_argument('--no-data-sharding', action='store_false',
971
+ help='Disable data sharding.',
972
+ dest='data_sharding')
973
+ group.add_argument('--head-lr-mult', type=float, default=1.0,
974
+ help='learning rate multiplier for head during finetuning')
975
+
976
+ # pretraining type and backbone selection`
977
+ group.add_argument('--vision-pretraining', action='store_true',
978
+ help='flag to indicate vision pretraining')
979
+ group.add_argument('--vision-pretraining-type', type=str, default='classify',
980
+ choices=['classify', 'inpaint', 'dino'],
981
+ help='pretraining objectives')
982
+ group.add_argument('--vision-backbone-type', type=str, default='vit',
983
+ choices=['vit', 'mit', 'swin'],
984
+ help='backbone type')
985
+ group.add_argument('--swin-backbone-type', type=str, default='tiny',
986
+ choices=['tiny', 'base', 'h3'],
987
+ help='pretraining objectives')
988
+
989
+ # inpainting arguments
990
+ group.add_argument('--mask-type', type=str, default='random',
991
+ choices=['random', 'row'],
992
+ help='mask types')
993
+ group.add_argument('--mask-factor', type=float, default=1.0,
994
+ help='mask size scaling parameter')
995
+
996
+ # dino arguments
997
+ group.add_argument('--iter-per-epoch', type=int, default=1250,
998
+ help='iterations per epoch')
999
+ group.add_argument('--dino-local-img-size', type=int, default=96,
1000
+ help='Image size for vision classification task')
1001
+ group.add_argument('--dino-local-crops-number', type=int, default=10,
1002
+ help='Number of local crops')
1003
+ group.add_argument('--dino-head-hidden-size', type=int, default=2048,
1004
+ help='Hidden dimension size in dino head')
1005
+ group.add_argument('--dino-bottleneck-size', type=int, default=256,
1006
+ help='Bottleneck dimension in dino head')
1007
+ group.add_argument('--dino-freeze-last-layer', type=float, default=1,
1008
+ help='Freezing last layer weights')
1009
+ group.add_argument('--dino-norm-last-layer', action='store_true',
1010
+ help='Disable Norm in last layer.')
1011
+ group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
1012
+ help='warmup teacher temperature')
1013
+ group.add_argument('--dino-teacher-temp', type=float, default=0.07,
1014
+ help='teacher temperature')
1015
+ group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
1016
+ help='warmup teacher temperature epochs')
1017
+
1018
+ return parser
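
A minimal sketch (not part of this commit) of how parse_args and validate_args above compose with an extra_args_provider hook; the flag name below is hypothetical, and the usual required flags (--num-layers, --hidden-size, --num-attention-heads, --max-position-embeddings, --micro-batch-size, ...) still have to be supplied on the command line for validate_args to pass.

# Illustrative only -- a downstream task adding its own flag on top of the
# standard Megatron argument groups defined above.
from megatron.arguments import parse_args, validate_args

def my_task_args(parser):
    # Hypothetical extra group; parse_args only requires the provider to
    # return the (possibly extended) parser.
    group = parser.add_argument_group(title='my task')
    group.add_argument('--my-task-flag', action='store_true')
    return parser

args = parse_args(extra_args_provider=my_task_args, ignore_unknown_args=True)
# Defaults are only applied to arguments that were left at None on the CLI.
args = validate_args(args, defaults={'seq_length': 1024})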
megatron/checkpointing.py ADDED
@@ -0,0 +1,675 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Input/output checkpointing."""
17
+
18
+ import os
19
+ import random
20
+ import sys
21
+ import numpy as np
22
+
23
+ import torch
24
+
25
+ from megatron import (mpu,
26
+ update_num_microbatches)
27
+ from .global_vars import get_args
28
+ from .utils import (unwrap_model,
29
+ print_rank_0)
30
+
31
+
32
+ _CHECKPOINT_VERSION = None
33
+
34
+ def set_checkpoint_version(value):
35
+ global _CHECKPOINT_VERSION
36
+ if _CHECKPOINT_VERSION is not None:
37
+ assert _CHECKPOINT_VERSION == value, \
38
+ "checkpoint versions do not match"
39
+ _CHECKPOINT_VERSION = value
40
+
41
+ def get_checkpoint_version():
42
+ global _CHECKPOINT_VERSION
43
+ return _CHECKPOINT_VERSION
44
+
45
+ def check_checkpoint_args(checkpoint_args):
46
+ """Ensure fixed arguments for a model are the same for the input
47
+ arguments and the ones retrieved from the checkpoint."""
48
+ args = get_args()
49
+
50
+ def _compare(arg_name, old_arg_name=None):
51
+ if old_arg_name is not None:
52
+ checkpoint_value = getattr(checkpoint_args, old_arg_name)
53
+ else:
54
+ checkpoint_value = getattr(checkpoint_args, arg_name)
55
+ args_value = getattr(args, arg_name)
56
+ error_message = '{} value from checkpoint ({}) is not equal to the ' \
57
+ 'input argument value ({}).'.format(
58
+ arg_name, checkpoint_value, args_value)
59
+ assert checkpoint_value == args_value, error_message
60
+
61
+ _compare('num_layers')
62
+ _compare('hidden_size')
63
+ _compare('num_attention_heads')
64
+ if args.vocab_file:
65
+ _compare('max_position_embeddings')
66
+ _compare('make_vocab_size_divisible_by')
67
+ _compare('padded_vocab_size')
68
+ _compare('tokenizer_type')
69
+ if args.data_parallel_random_init:
70
+ _compare('data_parallel_random_init')
71
+ if get_checkpoint_version() < 3.0:
72
+ _compare('tensor_model_parallel_size',
73
+ old_arg_name='model_parallel_size')
74
+ if get_checkpoint_version() >= 3.0:
75
+ _compare('tensor_model_parallel_size')
76
+ _compare('pipeline_model_parallel_size')
77
+
78
+ def ensure_directory_exists(filename):
79
+ """Build filename's path if it does not already exist."""
80
+ dirname = os.path.dirname(filename)
81
+ if not os.path.exists(dirname):
82
+ os.makedirs(dirname)
83
+
84
+
85
+ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
86
+ pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
87
+ """Determine the directory name for this rank's checkpoint."""
88
+ if release:
89
+ directory = 'release'
90
+ else:
91
+ directory = 'iter_{:07d}'.format(iteration)
92
+
93
+ # Use both the tensor and pipeline MP rank.
94
+ if pipeline_parallel is None:
95
+ pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
96
+ if tensor_rank is None:
97
+ tensor_rank = mpu.get_tensor_model_parallel_rank()
98
+ if pipeline_rank is None:
99
+ pipeline_rank = mpu.get_pipeline_model_parallel_rank()
100
+
101
+ # Use both the tensor and pipeline MP rank. If using the distributed
102
+ # optimizer, then the optimizer's path must additionally include the
103
+ # data parallel rank.
104
+ if not pipeline_parallel:
105
+ common_path = os.path.join(checkpoints_path, directory,
106
+ f'mp_rank_{tensor_rank:02d}')
107
+ else:
108
+ common_path = os.path.join(checkpoints_path, directory,
109
+ f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
110
+
111
+ if use_distributed_optimizer:
112
+ model_name = os.path.join(common_path, "model_rng.pt")
113
+ optim_name = os.path.join(
114
+ common_path + "_%03d" % mpu.get_data_parallel_rank(),
115
+ "optim.pt")
116
+ else:
117
+ model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
118
+ return model_name, optim_name
119
+
120
+ def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False):
121
+ """Finds the checkpoint for rank 0 without knowing if we are using
122
+ pipeline parallelism or not.
123
+
124
+ Since the checkpoint naming scheme changes if pipeline parallelism
125
+ is present, we need to look for both naming schemes if we don't
126
+ know if the checkpoint has pipeline parallelism.
127
+
128
+ """
129
+
130
+ # Look for checkpoint with no pipelining
131
+ filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
132
+ pipeline_parallel=False,
133
+ tensor_rank=0, pipeline_rank=0)
134
+ if os.path.isfile(filenames[0]):
135
+ return filenames
136
+
137
+ # Look for checkpoint with pipelining
138
+ filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
139
+ pipeline_parallel=True,
140
+ tensor_rank=0, pipeline_rank=0)
141
+ if os.path.isfile(filenames[0]):
142
+ return filenames
143
+
144
+ return None, None
145
+
146
+ def get_checkpoint_tracker_filename(checkpoints_path):
147
+
148
+ """Tracker file records the latest checkpoint during
149
+ training to restart from."""
150
+ return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
151
+
152
+
153
+ def read_metadata(tracker_filename):
154
+ # Read the tracker file and either set the iteration or
155
+ # mark it as a release checkpoint.
156
+ iteration = 0
157
+ release = False
158
+ with open(tracker_filename, 'r') as f:
159
+ metastring = f.read().strip()
160
+ try:
161
+ iteration = int(metastring)
162
+ except ValueError:
163
+ release = metastring == 'release'
164
+ if not release:
165
+ print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
166
+ tracker_filename))
167
+ sys.exit()
168
+ assert iteration > 0 or release, 'error parsing metadata file {}'.format(
169
+ tracker_filename)
170
+
171
+ # Get the max iteration retrieved across the ranks.
172
+ if torch.distributed.is_initialized():
173
+ iters_cuda = torch.cuda.LongTensor([iteration])
174
+ torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
175
+ max_iter = iters_cuda[0].item()
176
+
177
+ # We should now have all the same iteration.
178
+ # If not, print a warning and chose the maximum
179
+ # iteration across all ranks.
180
+ if iteration != max_iter:
181
+ print('WARNING: on rank {} found iteration {} in the '
182
+ 'metadata while max iteration across the ranks '
183
+ 'is {}, replacing it with max iteration.'.format(
184
+ torch.distributed.get_rank(), iteration, max_iter), flush=True)
185
+ else:
186
+ # When loading a checkpoint outside of training (for example,
187
+ # when editing it), we might not have torch distributed
188
+ # initialized, in this case, just assume we have the latest
189
+ max_iter = iteration
190
+ return max_iter, release
191
+
192
+
193
+ def get_rng_state():
194
+ """ collect rng state across data parallel ranks """
195
+ args = get_args()
196
+ rng_state = {
197
+ 'random_rng_state': random.getstate(),
198
+ 'np_rng_state': np.random.get_state(),
199
+ 'torch_rng_state': torch.get_rng_state(),
200
+ 'cuda_rng_state': torch.cuda.get_rng_state(),
201
+ 'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
202
+
203
+ rng_state_list = None
204
+ if torch.distributed.is_initialized() and \
205
+ mpu.get_data_parallel_world_size() > 1 and \
206
+ args.data_parallel_random_init:
207
+ rng_state_list = \
208
+ [None for i in range(mpu.get_data_parallel_world_size())]
209
+ torch.distributed.all_gather_object(
210
+ rng_state_list,
211
+ rng_state,
212
+ group=mpu.get_data_parallel_group())
213
+ else:
214
+ rng_state_list = [rng_state]
215
+
216
+ return rng_state_list
217
+
218
+
219
+ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
220
+ """Save a model checkpoint."""
221
+ args = get_args()
222
+
223
+ # Only rank zero of the data parallel writes to the disk.
224
+ model = unwrap_model(model)
225
+
226
+ print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
227
+ iteration, args.save))
228
+
229
+ # Collect rng state across data parallel ranks.
230
+ rng_state = get_rng_state()
231
+
232
+ # Checkpoint file names.
233
+ model_checkpoint_name, optim_checkpoint_name = \
234
+ get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
235
+
236
+ # Collect args, model, RNG.
237
+ model_state_dict = {}
238
+ if not torch.distributed.is_initialized() \
239
+ or mpu.get_data_parallel_rank() == 0:
240
+
241
+ # Arguments, iteration, and model.
242
+ model_state_dict['args'] = args
243
+ model_state_dict['checkpoint_version'] = 3.0
244
+ model_state_dict['iteration'] = iteration
245
+ if len(model) == 1:
246
+ model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
247
+ else:
248
+ for i in range(len(model)):
249
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
250
+ model_state_dict['model%d' % i] = \
251
+ model[i].state_dict_for_save_checkpoint()
252
+
253
+ # RNG states.
254
+ if not args.no_save_rng:
255
+ model_state_dict["rng_state"] = rng_state
256
+
257
+ # Collect optimizer state. (Optimizer is saved separately from the model, due
258
+ # to the conflicting data pattern when using the distributed optimizer.)
259
+ optim_state_dict = {}
260
+ if not args.no_save_optim \
261
+ and (not torch.distributed.is_initialized()
262
+ or mpu.get_data_parallel_rank() == 0
263
+ or args.use_distributed_optimizer):
264
+
265
+ # Optimizer stuff.
266
+ if optimizer is not None:
267
+ optim_state_dict['optimizer'] = optimizer.state_dict()
268
+ if opt_param_scheduler is not None:
269
+ optim_state_dict['opt_param_scheduler'] = \
270
+ opt_param_scheduler.state_dict()
271
+
272
+ # Save.
273
+ if args.use_distributed_optimizer:
274
+ # Save model separate from optimizer.
275
+ if model_state_dict:
276
+ ensure_directory_exists(model_checkpoint_name)
277
+ torch.save(model_state_dict, model_checkpoint_name)
278
+ if optim_state_dict:
279
+ ensure_directory_exists(optim_checkpoint_name)
280
+ torch.save(optim_state_dict, optim_checkpoint_name)
281
+ else:
282
+ # Save model and optimizer together.
283
+ state_dict = {**model_state_dict, **optim_state_dict}
284
+ if state_dict: # only saves if populated (i.e., inherits conditions above)
285
+ ensure_directory_exists(model_checkpoint_name)
286
+ torch.save(state_dict, model_checkpoint_name)
287
+
288
+ # Wait so everyone is done (necessary)
289
+ if torch.distributed.is_initialized():
290
+ torch.distributed.barrier()
291
+
292
+ print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
293
+ iteration, args.save))
294
+
295
+ # And update the latest iteration
296
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
297
+ tracker_filename = get_checkpoint_tracker_filename(args.save)
298
+ with open(tracker_filename, 'w') as f:
299
+ f.write(str(iteration))
300
+
301
+ # Wait so everyone is done (not necessary)
302
+ if torch.distributed.is_initialized():
303
+ torch.distributed.barrier()
304
+
305
+ def _transpose_first_dim(t, num_splits, num_splits_first, model):
306
+ input_shape = t.size()
307
+ # We use a self_attention module but the values extracted aren't
308
+ # specific to self attention so should work for cross attention as well
309
+ while hasattr(model, 'module'):
310
+ model = model.module
311
+ attention_module = model.language_model.encoder.layers[0].self_attention
312
+ hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
313
+ num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
314
+ if num_splits_first:
315
+ """[num_splits * np * hn, h]
316
+ -->(view) [num_splits, np, hn, h]
317
+ -->(transpose) [np, num_splits, hn, h]
318
+ -->(view) [np * num_splits * hn, h] """
319
+
320
+ intermediate_shape = \
321
+ (num_splits, num_attention_heads_per_partition,
322
+ hidden_size_per_attention_head) + input_shape[1:]
323
+
324
+ t = t.view(*intermediate_shape)
325
+ t = t.transpose(0, 1).contiguous()
326
+ else:
327
+ """[np * hn * num_splits, h]
328
+ -->(view) [np, hn, num_splits, h]
329
+ -->(transpose) [np, num_splits, hn, h]
330
+ -->(view) [np * num_splits * hn, h] """
331
+
332
+ intermediate_shape = \
333
+ (num_attention_heads_per_partition,
334
+ hidden_size_per_attention_head, num_splits) +\
335
+ input_shape[1:]
336
+
337
+ t = t.view(*intermediate_shape)
338
+ t = t.transpose(1, 2).contiguous()
339
+ t = t.view(*input_shape)
340
+
341
+ return t
342
+
343
+ def fix_query_key_value_ordering(model, checkpoint_version):
344
+ """Fix up query/key/value matrix ordering if checkpoint
345
+ version is smaller than 2.0
346
+ """
347
+ if checkpoint_version < 2.0:
348
+ if isinstance(model, list):
349
+ assert len(model)==1
350
+ model = model[0]
351
+ for name, param in model.named_parameters():
352
+ if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
353
+ if checkpoint_version == 0:
354
+ fixed_param = _transpose_first_dim(param.data, 3, True, model)
355
+ elif checkpoint_version == 1.0:
356
+ fixed_param = _transpose_first_dim(param.data, 3, False, model)
357
+ else:
358
+ print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
359
+ sys.exit()
360
+ param.data.copy_(fixed_param)
361
+ if name.endswith(('.key_value.weight', '.key_value.bias')):
362
+ if checkpoint_version == 0:
363
+ fixed_param = _transpose_first_dim(param.data, 2, True, model)
364
+ elif checkpoint_version == 1.0:
365
+ fixed_param = _transpose_first_dim(param.data, 2, False, model)
366
+ else:
367
+ print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
368
+ sys.exit()
369
+ param.data.copy_(fixed_param)
370
+ print_rank_0(" successfully fixed query-key-values ordering for"
371
+ " checkpoint version {}".format(checkpoint_version))
372
+
373
+ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
374
+ """ Load the base state_dict from the given directory
375
+
376
+ If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
377
+ """
378
+
379
+
380
+ # Read the tracker file and set the iteration.
381
+ tracker_filename = get_checkpoint_tracker_filename(load_dir)
382
+
383
+ # If no tracker file, return nothing
384
+ if not os.path.isfile(tracker_filename):
385
+ if not rank0:
386
+ print_rank_0('WARNING: could not find the metadata file {} '.format(
387
+ tracker_filename))
388
+ print_rank_0(' will not load any checkpoints and will start from '
389
+ 'random')
390
+ return None, None, False
391
+
392
+ # Otherwise, read the tracker file and either set the iteration or
393
+ # mark it as a release checkpoint.
394
+ iteration, release = read_metadata(tracker_filename)
395
+
396
+ # Checkpoint.
397
+ if rank0:
398
+ checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer,
399
+ release)
400
+ else:
401
+ checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
402
+ release)
403
+ if release:
404
+ print_rank_0(f' loading release checkpoint from {load_dir}')
405
+ else:
406
+ print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')
407
+
408
+ model_checkpoint_name, optim_checkpoint_name = checkpoint_names
409
+
410
+ # Load the checkpoint.
411
+ try:
412
+ model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
413
+ if use_distributed_optimizer:
414
+ optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
415
+ else:
416
+ optim_state_dict = model_state_dict
417
+ except ModuleNotFoundError:
418
+ from megatron.fp16_deprecated import loss_scaler
419
+ # For backward compatibility.
420
+ if not rank0:
421
+ print_rank_0(' > deserializing using the old code structure ...')
422
+ sys.modules['fp16.loss_scaler'] = sys.modules[
423
+ 'megatron.fp16_deprecated.loss_scaler']
424
+ sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
425
+ 'megatron.fp16_deprecated.loss_scaler']
426
+ model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
427
+ optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
428
+ sys.modules.pop('fp16.loss_scaler', None)
429
+ sys.modules.pop('megatron.fp16.loss_scaler', None)
430
+ except BaseException as e:
431
+ print_rank_0('could not load the checkpoint')
432
+ print_rank_0(e)
433
+ sys.exit()
434
+
435
+ return model_state_dict, optim_state_dict, release
436
+
437
+ def load_args_from_checkpoint(args, load_arg='load'):
438
+ """Set required arguments from the checkpoint specified in the
439
+ arguments.
440
+
441
+ Will overwrite arguments that have a non-None default value, but
442
+ will leave any arguments that default to None as set.
443
+
444
+ Returns the same args NameSpace with the new values added/updated.
445
+
446
+ If no checkpoint is specified in args, or if the checkpoint is
447
+ there but invalid, the arguments will not be modified
448
+
449
+ """
450
+ load_dir = getattr(args, load_arg)
451
+
452
+ if load_dir is None:
453
+ print_rank_0('No load directory specified, using provided arguments.')
454
+ return args
455
+
456
+ model_state_dict, optim_state_dict, release = \
457
+ _load_base_checkpoint(load_dir,
458
+ use_distributed_optimizer=args.use_distributed_optimizer,
459
+ rank0=True)
460
+
461
+ # For args we only care about model state dict
462
+ state_dict = model_state_dict
463
+
464
+ if not state_dict:
465
+ print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
466
+ return args
467
+
468
+ if 'args' not in state_dict:
469
+ print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
470
+ return args
471
+
472
+ checkpoint_args = state_dict['args']
473
+ checkpoint_version = state_dict.get('checkpoint_version', 0)
474
+ args.iteration = state_dict['iteration']
475
+
476
+ def _set_arg(arg_name, old_arg_name=None, force=False):
477
+ if not force and getattr(args, arg_name, None) is not None:
478
+ return
479
+
480
+ if old_arg_name is not None:
481
+ checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
482
+ else:
483
+ checkpoint_value = getattr(checkpoint_args, arg_name, None)
484
+
485
+ if checkpoint_value is not None:
486
+ print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
487
+ setattr(args, arg_name, checkpoint_value)
488
+
489
+ _set_arg('num_layers')
490
+ _set_arg('hidden_size')
491
+ _set_arg('ffn_hidden_size')
492
+ _set_arg('seq_length')
493
+ _set_arg('num_attention_heads')
494
+ _set_arg('kv_channels')
495
+ _set_arg('max_position_embeddings')
496
+ _set_arg('tokenizer_type')
497
+ _set_arg('padded_vocab_size')
498
+ if checkpoint_version < 3.0:
499
+ _set_arg('tensor_model_parallel_size',
500
+ 'model_parallel_size')
501
+ else:
502
+ _set_arg('tensor_model_parallel_size', force=True)
503
+ _set_arg('pipeline_model_parallel_size', force=True)
504
+ _set_arg('num_layers_per_virtual_pipeline_stage')
505
+ return args
506
+
507
+
508
+ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
509
+ """Load a model checkpoint and return the iteration.
510
+ strict (bool): whether to strictly enforce that the keys in
511
+ :attr:`state_dict` of the checkpoint match the names of
512
+ parameters and buffers in model.
513
+ """
514
+ args = get_args()
515
+ load_dir = getattr(args, load_arg)
516
+
517
+ model = unwrap_model(model)
518
+
519
+ model_state_dict, optim_state_dict, release = \
520
+ _load_base_checkpoint(load_dir,
521
+ use_distributed_optimizer=args.use_distributed_optimizer,
522
+ rank0=False)
523
+
524
+ if model_state_dict is None:
525
+ return 0
526
+
527
+ # set checkpoint version
528
+ set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
529
+
530
+ # Set iteration.
531
+ if args.finetune or release:
532
+ iteration = 0
533
+ else:
534
+ try:
535
+ iteration = model_state_dict['iteration']
536
+ except KeyError:
537
+ try: # Backward compatible with older checkpoints
538
+ iteration = model_state_dict['total_iters']
539
+ except KeyError:
540
+ print_rank_0('A metadata file exists but unable to load '
541
+ 'iteration from checkpoint {}, exiting'.format(
542
+ load_dir))
543
+ sys.exit()
544
+
545
+ # Check arguments.
546
+ assert args.consumed_train_samples == 0
547
+ assert args.consumed_valid_samples == 0
548
+ if 'args' in model_state_dict:
549
+ checkpoint_args = model_state_dict['args']
550
+ check_checkpoint_args(checkpoint_args)
551
+ args.consumed_train_samples = getattr(checkpoint_args,
552
+ 'consumed_train_samples', 0)
553
+ update_num_microbatches(consumed_samples=args.consumed_train_samples)
554
+ args.consumed_valid_samples = getattr(checkpoint_args,
555
+ 'consumed_valid_samples', 0)
556
+ else:
557
+ print_rank_0('could not find arguments in the checkpoint ...')
558
+
559
+ # Model.
560
+ if len(model) == 1:
561
+ model[0].load_state_dict(model_state_dict['model'], strict=strict)
562
+ else:
563
+ for i in range(len(model)):
564
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
565
+ model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
566
+
567
+ # Fix up query/key/value matrix ordering if needed
568
+ checkpoint_version = get_checkpoint_version()
569
+ print_rank_0(f' checkpoint version {checkpoint_version}')
570
+ fix_query_key_value_ordering(model, checkpoint_version)
571
+
572
+ # Optimizer.
573
+ if not release and not args.finetune and not args.no_load_optim:
574
+ try:
575
+ if optimizer is not None:
576
+ optimizer.load_state_dict(optim_state_dict['optimizer'])
577
+ if opt_param_scheduler is not None:
578
+ if 'lr_scheduler' in optim_state_dict: # backward compatibility
579
+ opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
580
+ else:
581
+ opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
582
+ except KeyError:
583
+ print_rank_0('Unable to load optimizer from checkpoint {}. '
584
+ 'Specify --no-load-optim or --finetune to prevent '
585
+ 'attempting to load the optimizer state, '
586
+ 'exiting ...'.format(load_dir))
587
+ sys.exit()
588
+
589
+ # rng states.
590
+ if not release and not args.finetune and not args.no_load_rng:
591
+ try:
592
+ if 'rng_state' in model_state_dict:
593
+ # access rng_state for data parallel rank
594
+ if args.data_parallel_random_init:
595
+
596
+ rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
597
+ else:
598
+ rng_state = model_state_dict['rng_state'][0]
599
+ random.setstate(rng_state['random_rng_state'])
600
+ np.random.set_state(rng_state['np_rng_state'])
601
+ torch.set_rng_state(rng_state['torch_rng_state'])
602
+ torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
603
+ # Check for empty states array
604
+ if not rng_state['rng_tracker_states']:
605
+ raise KeyError
606
+ mpu.get_cuda_rng_tracker().set_states(
607
+ rng_state['rng_tracker_states'])
608
+ else: # backward compatibility
609
+ random.setstate(model_state_dict['random_rng_state'])
610
+ np.random.set_state(model_state_dict['np_rng_state'])
611
+ torch.set_rng_state(model_state_dict['torch_rng_state'])
612
+ torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
613
+ # Check for empty states array
614
+ if not model_state_dict['rng_tracker_states']:
615
+ raise KeyError
616
+ mpu.get_cuda_rng_tracker().set_states(
617
+ model_state_dict['rng_tracker_states'])
618
+ except KeyError:
619
+ print_rank_0('Unable to load rng state from checkpoint {}. '
620
+ 'Specify --no-load-rng or --finetune to prevent '
621
+ 'attempting to load the rng state, '
622
+ 'exiting ...'.format(load_dir))
623
+ sys.exit()
624
+
625
+ # Some utilities want to load a checkpoint without distributed being initialized
626
+ if torch.distributed.is_initialized():
627
+ torch.distributed.barrier()
628
+
629
+ print_rank_0(f' successfully loaded checkpoint from {args.load} '
630
+ f'at iteration {iteration}')
631
+
632
+ return iteration
633
+
634
+
635
+ def load_biencoder_checkpoint(model, only_query_model=False,
636
+ only_context_model=False, custom_load_path=None):
637
+ """
638
+ selectively load retrieval models for indexing/retrieving
639
+ from saved checkpoints
640
+ """
641
+
642
+ args = get_args()
643
+
644
+ model = unwrap_model(model)
645
+
646
+ load_path = custom_load_path if custom_load_path is not None else args.load
647
+
648
+ tracker_filename = get_checkpoint_tracker_filename(load_path)
649
+ with open(tracker_filename, 'r') as f:
650
+ iteration = int(f.read().strip())
651
+
652
+ checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
653
+ args.use_distributed_optimizer,
654
+ release=False)
655
+
656
+ if mpu.get_data_parallel_rank() == 0:
657
+ print('global rank {} is loading checkpoint {}'.format(
658
+ torch.distributed.get_rank(), checkpoint_name))
659
+
660
+ state_dict = torch.load(checkpoint_name, map_location='cpu')
661
+ ret_state_dict = state_dict['model']
662
+
663
+ if only_query_model:
664
+ ret_state_dict.pop('context_model')
665
+ if only_context_model:
666
+ ret_state_dict.pop('query_model')
667
+
668
+ assert len(model) == 1
669
+ model[0].load_state_dict(ret_state_dict)
670
+ torch.distributed.barrier()
671
+
672
+ if mpu.get_data_parallel_rank() == 0:
673
+ print(' successfully loaded {}'.format(checkpoint_name))
674
+
675
+ return model
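
For orientation, a minimal sketch of the on-disk layout these helpers read and write, following get_checkpoint_names() and get_checkpoint_tracker_filename() above (the save directory is hypothetical; no pipeline parallelism or distributed optimizer assumed):

    import os

    save_dir = '/tmp/megatron_ckpt'   # hypothetical location
    iteration = 1000

    # Per-rank checkpoint file for tensor-parallel rank 0, as assembled by
    # get_checkpoint_names() when pipeline parallelism and the distributed
    # optimizer are both off.
    ckpt = os.path.join(save_dir, 'iter_{:07d}'.format(iteration),
                        'mp_rank_00', 'model_optim_rng.pt')

    # Tracker file consumed by read_metadata(): plain text containing the
    # latest iteration number, or the literal string 'release'.
    tracker = os.path.join(save_dir, 'latest_checkpointed_iteration.txt')
    print(ckpt)     # .../iter_0001000/mp_rank_00/model_optim_rng.pt
    print(tracker)  # .../latest_checkpointed_iteration.txt
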
megatron/data/Makefile ADDED
@@ -0,0 +1,9 @@
1
+ CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
2
+ CPPFLAGS += $(shell python3 -m pybind11 --includes)
3
+ LIBNAME = helpers
4
+ LIBEXT = $(shell python3-config --extension-suffix)
5
+
6
+ default: $(LIBNAME)$(LIBEXT)
7
+
8
+ %$(LIBEXT): %.cpp
9
+ $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
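
For reference, the extension suffix queried via python3-config above can also be obtained from within Python, which is a quick way to check that the built helpers module matches the interpreter in use:

    import sysconfig
    # Same value as `python3-config --extension-suffix`,
    # e.g. '.cpython-39-x86_64-linux-gnu.so'.
    print(sysconfig.get_config_var('EXT_SUFFIX'))
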
megatron/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import indexed_dataset
megatron/data/autoaugment.py ADDED
@@ -0,0 +1,320 @@
1
+ """AutoAugment data augmentation policy for ImageNet.
2
+
3
+ -- Begin license text.
4
+
5
+ MIT License
6
+
7
+ Copyright (c) 2018 Philip Popien
8
+
9
+ Permission is hereby granted, free of charge, to any person obtaining a copy
10
+ of this software and associated documentation files (the "Software"), to deal
11
+ in the Software without restriction, including without limitation the rights
12
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13
+ copies of the Software, and to permit persons to whom the Software is
14
+ furnished to do so, subject to the following conditions:
15
+
16
+ The above copyright notice and this permission notice shall be included in all
17
+ copies or substantial portions of the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25
+ SOFTWARE.
26
+
27
+ -- End license text.
28
+
29
+ Code adapted from https://github.com/DeepVoltaire/AutoAugment.
30
+
31
+ This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
32
+ Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
33
+ policies.
34
+
35
+ Reference:
36
+ [1] https://arxiv.org/abs/1805.09501
37
+ """
38
+
39
+ import random
40
+
41
+ import numpy as np
42
+ from PIL import Image
43
+ from PIL import ImageEnhance
44
+ from PIL import ImageOps
45
+
46
+ _MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
47
+
48
+
49
+ class ImageNetPolicy:
50
+ """Definition of an ImageNetPolicy.
51
+
52
+ Implements a fixed AutoAugment data augmentation policy targeted at
53
+ ImageNet training by randomly applying at runtime one of the 25 pre-defined
54
+ data augmentation sub-policies provided in Reference [1].
55
+
56
+ Usage example as a Pytorch Transform:
57
+ >>> transform=transforms.Compose([transforms.Resize(256),
58
+ >>> ImageNetPolicy(),
59
+ >>> transforms.ToTensor()])
60
+ """
61
+
62
+ def __init__(self, fillcolor=(128, 128, 128)):
63
+ """Initialize an ImageNetPolicy.
64
+
65
+ Args:
66
+ fillcolor (tuple): RGB color components of the color to be used for
67
+ filling when needed (default: (128, 128, 128), which
68
+ corresponds to gray).
69
+ """
70
+ # Instantiate a list of sub-policies.
71
+ # Each entry of the list is a SubPolicy which consists of
72
+ # two augmentation operations,
73
+ # each of those parametrized as operation, probability, magnitude.
74
+ # Those two operations are applied sequentially on the image upon call.
75
+ self.policies = [
76
+ SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
77
+ SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
78
+ SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
79
+ SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
80
+ SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
81
+ SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
82
+ SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
83
+ SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
84
+ SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
85
+ SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
86
+ SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
87
+ SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
88
+ SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
89
+ SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
90
+ SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
91
+ SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
92
+ SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
93
+ SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
94
+ SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
95
+ SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
96
+ SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
97
+ SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
98
+ SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
99
+ SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
100
+ SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
101
+ ]
102
+
103
+ def __call__(self, img):
104
+ """Define call method for ImageNetPolicy class."""
105
+ policy_idx = random.randint(0, len(self.policies) - 1)
106
+ return self.policies[policy_idx](img)
107
+
108
+ def __repr__(self):
109
+ """Define repr method for ImageNetPolicy class."""
110
+ return "ImageNetPolicy"
111
+
112
+
113
+ class SubPolicy:
114
+ """Definition of a SubPolicy.
115
+
116
+ A SubPolicy consists of two augmentation operations,
117
+ each of those parametrized as operation, probability, magnitude.
118
+ The two operations are applied sequentially on the image upon call.
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ operation1,
124
+ probability1,
125
+ magnitude_idx1,
126
+ operation2,
127
+ probability2,
128
+ magnitude_idx2,
129
+ fillcolor,
130
+ ):
131
+ """Initialize a SubPolicy.
132
+
133
+ Args:
134
+ operation1 (str): Key specifying the first augmentation operation.
135
+ There are fourteen key values altogether (see supported_ops below
136
+ listing supported operations). probability1 (float): Probability
137
+ within [0., 1.] of applying the first augmentation operation.
138
+ magnitude_idx1 (int): Integer specifying the strength of the first
139
+ operation as an index further used to derive the magnitude from a
140
+ range of possible values.
141
+ operation2 (str): Key specifying the second augmentation operation.
142
+ probability2 (float): Probability within [0., 1.] of applying the
143
+ second augmentation operation.
144
+ magnitude_idx2 (int): Integer specifying the strength of the
145
+ second operation as an index further used to derive the magnitude
146
+ from a range of possible values.
147
+ fillcolor (tuple): RGB color components of the color to be used for
148
+ filling.
149
+ Returns:
150
+ """
151
+ # List of supported operations for operation1 and operation2.
152
+ supported_ops = [
153
+ "shearX",
154
+ "shearY",
155
+ "translateX",
156
+ "translateY",
157
+ "rotate",
158
+ "color",
159
+ "posterize",
160
+ "solarize",
161
+ "contrast",
162
+ "sharpness",
163
+ "brightness",
164
+ "autocontrast",
165
+ "equalize",
166
+ "invert",
167
+ ]
168
+ assert (operation1 in supported_ops) and (
169
+ operation2 in supported_ops
170
+ ), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
171
+
172
+ assert (
173
+ 0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
174
+ ), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
175
+
176
+ assert (
177
+ isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
178
+ ), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
179
+
180
+ assert (
181
+ isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
182
+ ), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
183
+
184
+ # Define a dictionary where each key refers to a specific type of
185
+ # augmentation and the corresponding value is a range of ten possible
186
+ # magnitude values for that augmentation.
187
+ num_levels = _MAX_LEVEL + 1
188
+ ranges = {
189
+ "shearX": np.linspace(0, 0.3, num_levels),
190
+ "shearY": np.linspace(0, 0.3, num_levels),
191
+ "translateX": np.linspace(0, 150 / 331, num_levels),
192
+ "translateY": np.linspace(0, 150 / 331, num_levels),
193
+ "rotate": np.linspace(0, 30, num_levels),
194
+ "color": np.linspace(0.0, 0.9, num_levels),
195
+ "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
196
+ int  # builtin int: np.int was removed in recent NumPy releases
197
+ ),
198
+ "solarize": np.linspace(256, 0, num_levels), # range [0, 256]
199
+ "contrast": np.linspace(0.0, 0.9, num_levels),
200
+ "sharpness": np.linspace(0.0, 0.9, num_levels),
201
+ "brightness": np.linspace(0.0, 0.9, num_levels),
202
+ "autocontrast": [0]
203
+ * num_levels, # This augmentation doesn't use magnitude parameter.
204
+ "equalize": [0]
205
+ * num_levels, # This augmentation doesn't use magnitude parameter.
206
+ "invert": [0]
207
+ * num_levels, # This augmentation doesn't use magnitude parameter.
208
+ }
209
+
210
+ def rotate_with_fill(img, magnitude):
211
+ """Define rotation transformation with fill.
212
+
213
+ The input image is first rotated, then it is blended together with
214
+ a gray mask of the same size. Note that fillcolor as defined
215
+ elsewhere in this module doesn't apply here.
216
+
217
+ Args:
218
+ magnitude (float): rotation angle in degrees.
219
+ Returns:
220
+ rotated_filled (PIL Image): rotated image with gray filling for
221
+ disoccluded areas unveiled by the rotation.
222
+ """
223
+ rotated = img.convert("RGBA").rotate(magnitude)
224
+ rotated_filled = Image.composite(
225
+ rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
226
+ )
227
+ return rotated_filled.convert(img.mode)
228
+
229
+ # Define a dictionary of augmentation functions where each key refers
230
+ # to a specific type of augmentation and the corresponding value defines
231
+ # the augmentation itself using a lambda function.
232
+ # pylint: disable=unnecessary-lambda
233
+ func_dict = {
234
+ "shearX": lambda img, magnitude: img.transform(
235
+ img.size,
236
+ Image.AFFINE,
237
+ (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
238
+ Image.BICUBIC,
239
+ fillcolor=fillcolor,
240
+ ),
241
+ "shearY": lambda img, magnitude: img.transform(
242
+ img.size,
243
+ Image.AFFINE,
244
+ (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
245
+ Image.BICUBIC,
246
+ fillcolor=fillcolor,
247
+ ),
248
+ "translateX": lambda img, magnitude: img.transform(
249
+ img.size,
250
+ Image.AFFINE,
251
+ (
252
+ 1,
253
+ 0,
254
+ magnitude * img.size[0] * random.choice([-1, 1]),
255
+ 0,
256
+ 1,
257
+ 0,
258
+ ),
259
+ fillcolor=fillcolor,
260
+ ),
261
+ "translateY": lambda img, magnitude: img.transform(
262
+ img.size,
263
+ Image.AFFINE,
264
+ (
265
+ 1,
266
+ 0,
267
+ 0,
268
+ 0,
269
+ 1,
270
+ magnitude * img.size[1] * random.choice([-1, 1]),
271
+ ),
272
+ fillcolor=fillcolor,
273
+ ),
274
+ "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
275
+ "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
276
+ 1 + magnitude * random.choice([-1, 1])
277
+ ),
278
+ "posterize": lambda img, magnitude: ImageOps.posterize(
279
+ img, magnitude
280
+ ),
281
+ "solarize": lambda img, magnitude: ImageOps.solarize(
282
+ img, magnitude
283
+ ),
284
+ "contrast": lambda img, magnitude: ImageEnhance.Contrast(
285
+ img
286
+ ).enhance(1 + magnitude * random.choice([-1, 1])),
287
+ "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
288
+ img
289
+ ).enhance(1 + magnitude * random.choice([-1, 1])),
290
+ "brightness": lambda img, magnitude: ImageEnhance.Brightness(
291
+ img
292
+ ).enhance(1 + magnitude * random.choice([-1, 1])),
293
+ "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
294
+ "equalize": lambda img, magnitude: ImageOps.equalize(img),
295
+ "invert": lambda img, magnitude: ImageOps.invert(img),
296
+ }
297
+
298
+ # Store probability, function and magnitude of the first augmentation
299
+ # for the sub-policy.
300
+ self.probability1 = probability1
301
+ self.operation1 = func_dict[operation1]
302
+ self.magnitude1 = ranges[operation1][magnitude_idx1]
303
+
304
+ # Store probability, function and magnitude of the second augmentation
305
+ # for the sub-policy.
306
+ self.probability2 = probability2
307
+ self.operation2 = func_dict[operation2]
308
+ self.magnitude2 = ranges[operation2][magnitude_idx2]
309
+
310
+ def __call__(self, img):
311
+ """Define call method for SubPolicy class."""
312
+ # Randomly apply operation 1.
313
+ if random.random() < self.probability1:
314
+ img = self.operation1(img, self.magnitude1)
315
+
316
+ # Randomly apply operation 2.
317
+ if random.random() < self.probability2:
318
+ img = self.operation2(img, self.magnitude2)
319
+
320
+ return img
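
A minimal usage sketch for the policy defined above, on a synthetic gray image (assuming the megatron package from this repo is importable; wrapping the policy in torchvision's transforms.Compose, as in the class docstring, is optional):

    from PIL import Image
    from megatron.data.autoaugment import ImageNetPolicy

    policy = ImageNetPolicy()                            # 25 fixed sub-policies
    img = Image.new('RGB', (224, 224), (128, 128, 128))  # synthetic input
    augmented = policy(img)                              # one sub-policy applied at random
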
megatron/data/bert_dataset.py ADDED
@@ -0,0 +1,234 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """BERT Style dataset."""
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from megatron import (
22
+ get_args,
23
+ get_tokenizer,
24
+ mpu,
25
+ print_rank_0
26
+ )
27
+ from megatron.data.dataset_utils import (
28
+ get_samples_mapping,
29
+ get_a_and_b_segments,
30
+ truncate_segments,
31
+ create_tokens_and_tokentypes,
32
+ create_masked_lm_predictions
33
+ )
34
+
35
+ class DummyBertDataset(torch.utils.data.Dataset):
36
+ def __init__(self, name, num_samples, max_seq_length):
37
+ self.name = name
38
+ self.num_samples = num_samples
39
+ self.max_seq_length = max_seq_length
40
+ self.np_rng = np.random.RandomState(seed=0)
41
+ # self.token_nps = np_rng.randint(1000, 2000, (self.num_samples, 512))
42
+ # Vocab stuff.
43
+ tokenizer = get_tokenizer()
44
+ self.vocab_id_list = list(tokenizer.inv_vocab.keys())
45
+ self.vocab_id_to_token_dict = tokenizer.inv_vocab
46
+ self.cls_id = tokenizer.cls
47
+ self.sep_id = tokenizer.sep
48
+ self.mask_id = tokenizer.mask
49
+ self.pad_id = tokenizer.pad
50
+
51
+ def __len__(self):
52
+ return self.num_samples
53
+
54
+ def __getitem__(self, idx):
55
+ tokens = self.np_rng.randint(1000, 2000, self.max_seq_length)
56
+ masked_position = np.arange(int(tokens.shape[0] * 0.15))
57
+ tokens = tokens.astype(np.int64)
58
+ labels = tokens[masked_position]
59
+ label_np = np.full_like(tokens, -1)
60
+ label_np[masked_position] = labels
61
+ tokens[masked_position] = self.mask_id
62
+ loss_mask_np = np.zeros_like(tokens)
63
+ loss_mask_np[masked_position] = 1
64
+ train_sample = {
65
+ 'text': tokens,
66
+ 'types': np.zeros_like(tokens),
67
+ 'labels': label_np,
68
+ 'is_random': 0,
69
+ 'loss_mask': loss_mask_np,
70
+ 'padding_mask': np.ones_like(tokens),
71
+ 'truncated': 0
72
+ }
73
+ return train_sample
74
+
75
+ class BertDataset(torch.utils.data.Dataset):
76
+
77
+ def __init__(self, name, indexed_dataset, data_prefix,
78
+ num_epochs, max_num_samples, masked_lm_prob,
79
+ max_seq_length, short_seq_prob, seed, binary_head):
80
+
81
+ # Params to store.
82
+ self.name = name
83
+ self.seed = seed
84
+ self.masked_lm_prob = masked_lm_prob
85
+ self.max_seq_length = max_seq_length
86
+ self.binary_head = binary_head
87
+
88
+ # Dataset.
89
+ self.indexed_dataset = indexed_dataset
90
+
91
+ # Build the samples mapping.
92
+ self.samples_mapping = get_samples_mapping(self.indexed_dataset,
93
+ data_prefix,
94
+ num_epochs,
95
+ max_num_samples,
96
+ self.max_seq_length - 3, # account for added tokens
97
+ short_seq_prob,
98
+ self.seed,
99
+ self.name,
100
+ self.binary_head)
101
+
102
+ # Vocab stuff.
103
+ tokenizer = get_tokenizer()
104
+ self.vocab_id_list = list(tokenizer.inv_vocab.keys())
105
+ self.vocab_id_to_token_dict = tokenizer.inv_vocab
106
+ self.cls_id = tokenizer.cls
107
+ self.sep_id = tokenizer.sep
108
+ self.mask_id = tokenizer.mask
109
+ self.pad_id = tokenizer.pad
110
+
111
+ def __len__(self):
112
+ return self.samples_mapping.shape[0]
113
+
114
+ def __getitem__(self, idx):
115
+ start_idx, end_idx, seq_length = self.samples_mapping[idx]
116
+ sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
117
+ # Note that this rng state should be numpy and not python since
118
+ # python randint is inclusive whereas the numpy one is exclusive.
119
+ # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
120
+ np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
121
+ return build_training_sample(sample, seq_length,
122
+ self.max_seq_length, # needed for padding
123
+ self.vocab_id_list,
124
+ self.vocab_id_to_token_dict,
125
+ self.cls_id, self.sep_id,
126
+ self.mask_id, self.pad_id,
127
+ self.masked_lm_prob, np_rng,
128
+ self.binary_head)
129
+
130
+
131
+
132
+
133
+ def build_training_sample(sample,
134
+ target_seq_length, max_seq_length,
135
+ vocab_id_list, vocab_id_to_token_dict,
136
+ cls_id, sep_id, mask_id, pad_id,
137
+ masked_lm_prob, np_rng, binary_head):
138
+ """Build training sample.
139
+
140
+ Arguments:
141
+ sample: A list of sentences in which each sentence is a list of token ids.
142
+ target_seq_length: Desired sequence length.
143
+ max_seq_length: Maximum length of the sequence. All values are padded to
144
+ this length.
145
+ vocab_id_list: List of vocabulary ids. Used to pick a random id.
146
+ vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
147
+ cls_id: Start of example id.
148
+ sep_id: Separator id.
149
+ mask_id: Mask token id.
150
+ pad_id: Padding token id.
151
+ masked_lm_prob: Probability to mask tokens.
152
+ np_rng: Random number generator. Note that this rng state should be
153
+ numpy and not python since python randint is inclusive for
154
+ the upper bound whereas the numpy one is exclusive.
155
+ """
156
+
157
+ if binary_head:
158
+ # We assume that we have at least two sentences in the sample
159
+ assert len(sample) > 1
160
+ assert target_seq_length <= max_seq_length
161
+
162
+ # Divide sample into two segments (A and B).
163
+ if binary_head:
164
+ tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
165
+ np_rng)
166
+ else:
167
+ tokens_a = []
168
+ for j in range(len(sample)):
169
+ tokens_a.extend(sample[j])
170
+ tokens_b = []
171
+ is_next_random = False
172
+
173
+ # Truncate to `target_sequence_length`.
174
+ max_num_tokens = target_seq_length
175
+ truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
176
+ len(tokens_b), max_num_tokens, np_rng)
177
+
178
+ # Build tokens and tokentypes.
179
+ tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
180
+ cls_id, sep_id)
181
+
182
+ # Masking.
183
+ max_predictions_per_seq = masked_lm_prob * max_num_tokens
184
+ (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
185
+ tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
186
+ cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
187
+
188
+ # Padding.
189
+ tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
190
+ = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
191
+ masked_labels, pad_id, max_seq_length)
192
+
193
+ train_sample = {
194
+ 'text': tokens_np,
195
+ 'types': tokentypes_np,
196
+ 'labels': labels_np,
197
+ 'is_random': int(is_next_random),
198
+ 'loss_mask': loss_mask_np,
199
+ 'padding_mask': padding_mask_np,
200
+ 'truncated': int(truncated)}
201
+ return train_sample
202
+
203
+
204
+ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
205
+ masked_labels, pad_id, max_seq_length):
206
+ """Pad sequences and convert them to numpy."""
207
+
208
+ # Some checks.
209
+ num_tokens = len(tokens)
210
+ padding_length = max_seq_length - num_tokens
211
+ assert padding_length >= 0
212
+ assert len(tokentypes) == num_tokens
213
+ assert len(masked_positions) == len(masked_labels)
214
+
215
+ # Tokens and token types.
216
+ filler = [pad_id] * padding_length
217
+ tokens_np = np.array(tokens + filler, dtype=np.int64)
218
+ tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
219
+
220
+ # Padding mask.
221
+ padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
222
+ dtype=np.int64)
223
+
224
+ # Labels and loss mask.
225
+ labels = [-1] * max_seq_length
226
+ loss_mask = [0] * max_seq_length
227
+ for i in range(len(masked_positions)):
228
+ assert masked_positions[i] < num_tokens
229
+ labels[masked_positions[i]] = masked_labels[i]
230
+ loss_mask[masked_positions[i]] = 1
231
+ labels_np = np.array(labels, dtype=np.int64)
232
+ loss_mask_np = np.array(loss_mask, dtype=np.int64)
233
+
234
+ return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
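
pad_and_convert_to_numpy above is self-contained, so a toy call illustrates the output layout (the token ids are made up; assumes the megatron package from this repo is importable):

    from megatron.data.bert_dataset import pad_and_convert_to_numpy

    tokens = [101, 7, 8, 9, 102]      # made-up ids, e.g. [CLS] a b c [SEP]
    tokentypes = [0] * len(tokens)
    masked_positions = [2]            # index of the token that was masked
    masked_labels = [8]               # original id at that index
    out = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id=0, max_seq_length=8)
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = out
    # All arrays have length 8; labels_np is -1 everywhere except index 2,
    # and loss_mask_np is 1 only at index 2.
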
megatron/data/biencoder_dataset_utils.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ import time
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from megatron import get_args, get_tokenizer, mpu, print_rank_0
8
+ from megatron.data.dataset_utils import create_masked_lm_predictions, \
9
+ pad_and_convert_to_numpy
10
+ from megatron.data.data_samplers import MegatronPretrainingSampler
11
+
12
+ def make_attention_mask(source_block, target_block):
13
+ """
14
+ Returns a 2-dimensional (2-D) attention mask
15
+ :param source_block: 1-D array
16
+ :param target_block: 1-D array
17
+ """
18
+ mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
19
+ mask = mask.astype(np.int64)
20
+ # (source_length, target_length)
21
+ return mask
22
+
23
+ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
24
+ """Specifically one epoch to be used in an indexing job."""
25
+ args = get_args()
26
+
27
+ if micro_batch_size is None:
28
+ micro_batch_size = args.micro_batch_size
29
+ num_workers = args.num_workers
30
+
31
+ # Use megatron's sampler with consumed_samples set to 0, since this is
32
+ # only for evaluation and we don't intend to resume halfway.
33
+ # Also, set drop_last to False, since we don't intend to drop
34
+ # the last batch.
35
+ batch_sampler = MegatronPretrainingSampler(
36
+ total_samples=len(dataset),
37
+ consumed_samples=0,
38
+ micro_batch_size=args.micro_batch_size,
39
+ data_parallel_rank=mpu.get_data_parallel_rank(),
40
+ data_parallel_size=mpu.get_data_parallel_world_size(),
41
+ drop_last=False)
42
+
43
+ return torch.utils.data.DataLoader(dataset,
44
+ batch_sampler=batch_sampler,
45
+ num_workers=num_workers,
46
+ pin_memory=True)
47
+
48
+
49
+ def get_ict_batch(data_iterator):
50
+ # Items and their type.
51
+ keys = ['query_tokens', 'query_mask',
52
+ 'context_tokens', 'context_mask', 'block_data']
53
+ datatype = torch.int64
54
+
55
+ # Broadcast data.
56
+ if data_iterator is None:
57
+ data = None
58
+ else:
59
+ data = next(data_iterator)
60
+ data_b = mpu.broadcast_data(keys, data, datatype)
61
+
62
+ # Unpack.
63
+ query_tokens = data_b['query_tokens'].long()
64
+ query_mask = data_b['query_mask'] < 0.5
65
+ context_tokens = data_b['context_tokens'].long()
66
+ context_mask = data_b['context_mask'] < 0.5
67
+ block_indices = data_b['block_data'].long()
68
+
69
+ return query_tokens, query_mask,\
70
+ context_tokens, context_mask, block_indices
71
+
72
+
73
+ def join_str_list(str_list):
74
+ """Join a list of strings, handling spaces appropriately"""
75
+ result = ""
76
+ for s in str_list:
77
+ if s.startswith("##"):
78
+ result += s[2:]
79
+ else:
80
+ result += " " + s
81
+ return result
82
+
83
+
84
+ class BlockSampleData(object):
85
+ """A struct for fully describing a fixed-size block of data as used in REALM
86
+
87
+ :param start_idx: for first sentence of the block
88
+ :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
89
+ :param doc_idx: the index of the document from which the block comes in the original indexed dataset
90
+ :param block_idx: a unique integer identifier given to every block.
91
+ """
92
+ def __init__(self, start_idx, end_idx, doc_idx, block_idx):
93
+ self.start_idx = start_idx
94
+ self.end_idx = end_idx
95
+ self.doc_idx = doc_idx
96
+ self.block_idx = block_idx
97
+
98
+ def as_array(self):
99
+ return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
100
+
101
+ def as_tuple(self):
102
+ return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
103
+
104
+
105
+ class BlockSamplesMapping(object):
106
+ def __init__(self, mapping_array):
107
+ # make sure that the array is compatible with BlockSampleData
108
+ assert mapping_array.shape[1] == 4
109
+ self.mapping_array = mapping_array
110
+
111
+ def __len__(self):
112
+ return self.mapping_array.shape[0]
113
+
114
+ def __getitem__(self, idx):
115
+ """Get the data associated with an indexed sample."""
116
+ sample_data = BlockSampleData(*self.mapping_array[idx])
117
+ return sample_data
118
+
119
+
120
+ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
121
+ max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
122
+ """Get samples mapping for a dataset over fixed size blocks. This function also requires
123
+ a dataset of the titles for the source documents since their lengths must be taken into account.
124
+
125
+ :return: samples_mapping (BlockSamplesMapping)
126
+ """
127
+
128
+ if not num_epochs:
129
+ if not max_num_samples:
130
+ raise ValueError("Need to specify either max_num_samples "
131
+ "or num_epochs")
132
+ num_epochs = np.iinfo(np.int32).max - 1
133
+ if not max_num_samples:
134
+ max_num_samples = np.iinfo(np.int64).max - 1
135
+
136
+ # Filename of the index mapping
137
+ indexmap_filename = data_prefix
138
+ indexmap_filename += '_{}_indexmap'.format(name)
139
+ if num_epochs != (np.iinfo(np.int32).max - 1):
140
+ indexmap_filename += '_{}ep'.format(num_epochs)
141
+ if max_num_samples != (np.iinfo(np.int64).max - 1):
142
+ indexmap_filename += '_{}mns'.format(max_num_samples)
143
+ indexmap_filename += '_{}msl'.format(max_seq_length)
144
+ indexmap_filename += '_{}s'.format(seed)
145
+ if use_one_sent_docs:
146
+ indexmap_filename += '_1sentok'
147
+ indexmap_filename += '.npy'
148
+
149
+ # Build the indexed mapping if not exist.
150
+ if mpu.get_data_parallel_rank() == 0 and \
151
+ not os.path.isfile(indexmap_filename):
152
+ print(' > WARNING: could not find index map file {}, building '
153
+ 'the indices on rank 0 ...'.format(indexmap_filename))
154
+
155
+ # Make sure the types match the helpers input types.
156
+ assert block_dataset.doc_idx.dtype == np.int64
157
+ assert block_dataset.sizes.dtype == np.int32
158
+
159
+ # Build samples mapping
160
+ verbose = torch.distributed.get_rank() == 0
161
+ start_time = time.time()
162
+ print_rank_0(' > building samples index mapping for {} ...'.format(
163
+ name))
164
+
165
+ from megatron.data import helpers
166
+ mapping_array = helpers.build_blocks_mapping(
167
+ block_dataset.doc_idx,
168
+ block_dataset.sizes,
169
+ title_dataset.sizes,
170
+ num_epochs,
171
+ max_num_samples,
172
+ max_seq_length - 3, # account for added tokens
173
+ seed,
174
+ verbose,
175
+ use_one_sent_docs)
176
+
177
+
178
+ print_rank_0(' > done building samples index mapping')
179
+ np.save(indexmap_filename, mapping_array, allow_pickle=True)
180
+ print_rank_0(' > saved the index mapping in {}'.format(
181
+ indexmap_filename))
182
+ # Make sure all the ranks have built the mapping
183
+ print_rank_0(' > elapsed time to build and save samples mapping '
184
+ '(seconds): {:4f}'.format(
185
+ time.time() - start_time))
186
+
187
+ # This should be a barrier but nccl barrier assumes
188
+ # device_index=rank which is not the case for model
189
+ # parallel case
190
+ counts = torch.cuda.LongTensor([1])
191
+ torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
192
+ assert counts[0].item() == torch.distributed.get_world_size(
193
+ group=mpu.get_data_parallel_group())
194
+
195
+ # Load indexed dataset.
196
+ print_rank_0(' > loading indexed mapping from {}'.format(
197
+ indexmap_filename))
198
+ start_time = time.time()
199
+
200
+ mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
201
+ samples_mapping = BlockSamplesMapping(mapping_array)
202
+
203
+ print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
204
+ time.time() - start_time))
205
+ print_rank_0(' total number of samples: {}'.format(
206
+ mapping_array.shape[0]))
207
+
208
+ return samples_mapping
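
For reference, make_attention_mask at the top of this file needs only two 1-D id arrays, where ids below 1 are treated as padding (assumes the megatron package from this repo is importable):

    import numpy as np
    from megatron.data.biencoder_dataset_utils import make_attention_mask

    source = np.array([5, 9, 0])      # toy ids; 0 is padding
    target = np.array([7, 0])
    mask = make_attention_mask(source, target)
    # mask.shape == (3, 2); rows/columns at padded positions are zero:
    # [[1, 0], [1, 0], [0, 0]]
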
megatron/data/blendable_dataset.py ADDED
@@ -0,0 +1,68 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Blendable dataset."""
17
+
18
+ import time
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from megatron import print_rank_0
24
+ from megatron import mpu
25
+
26
+
27
+ class BlendableDataset(torch.utils.data.Dataset):
28
+
29
+
30
+ def __init__(self, datasets, weights):
31
+
32
+ self.datasets = datasets
33
+ num_datasets = len(datasets)
34
+ assert num_datasets == len(weights)
35
+
36
+ self.size = 0
37
+ for dataset in self.datasets:
38
+ self.size += len(dataset)
39
+
40
+ # Normalize weights.
41
+ weights = np.array(weights, dtype=np.float64)
42
+ sum_weights = np.sum(weights)
43
+ assert sum_weights > 0.0
44
+ weights /= sum_weights
45
+
46
+ # Build indices.
47
+ start_time = time.time()
48
+ assert num_datasets < 255
49
+ self.dataset_index = np.zeros(self.size, dtype=np.uint8)
50
+ self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
51
+
52
+ from megatron.data import helpers
53
+ helpers.build_blending_indices(self.dataset_index,
54
+ self.dataset_sample_index,
55
+ weights, num_datasets, self.size,
56
+ torch.distributed.get_rank() == 0)
57
+ print_rank_0('> elapsed time for building blendable dataset indices: '
58
+ '{:.2f} (sec)'.format(time.time() - start_time))
59
+
60
+
61
+ def __len__(self):
62
+ return self.size
63
+
64
+
65
+ def __getitem__(self, idx):
66
+ dataset_idx = self.dataset_index[idx]
67
+ sample_idx = self.dataset_sample_index[idx]
68
+ return self.datasets[dataset_idx][sample_idx]
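helpers.build_blending_indices is a compiled C++ routine, but the idea behind it is simple: walk the global sample indices and hand each one to the dataset that currently lags its weighted share the most. A rough pure-Python sketch of that greedy assignment (for illustration only; the class above always calls the compiled helper):

    import numpy as np

    def build_blending_indices_py(weights, size):
        # weights: per-dataset weights (normalized to sum to 1); size: total samples.
        weights = np.asarray(weights, dtype=np.float64)
        num_datasets = len(weights)
        dataset_index = np.zeros(size, dtype=np.uint8)
        dataset_sample_index = np.zeros(size, dtype=np.int64)
        current_samples = np.zeros(num_datasets, dtype=np.int64)
        for i in range(size):
            # Pick the dataset whose consumed count is furthest below its target share.
            errors = weights * (i + 1) - current_samples
            d = int(np.argmax(errors))
            dataset_index[i] = d
            dataset_sample_index[i] = current_samples[d]
            current_samples[d] += 1
        return dataset_index, dataset_sample_index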
megatron/data/data_samplers.py ADDED
@@ -0,0 +1,199 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Dataloaders."""
17
+
18
+
19
+ import random
20
+ import torch
21
+ import numpy as np
22
+ from torch.utils.data import Dataset
23
+ from megatron import get_args
24
+ from megatron import mpu
25
+
26
+
27
+ def build_pretraining_data_loader(dataset, consumed_samples):
28
+ """Build dataloader given an input dataset."""
29
+
30
+ if dataset is None:
31
+ return None
32
+ args = get_args()
33
+
34
+ # Megatron sampler
35
+ if args.dataloader_type == 'single':
36
+ batch_sampler = MegatronPretrainingSampler(
37
+ total_samples=len(dataset),
38
+ consumed_samples=consumed_samples,
39
+ micro_batch_size=args.micro_batch_size,
40
+ data_parallel_rank=mpu.get_data_parallel_rank(),
41
+ data_parallel_size=mpu.get_data_parallel_world_size())
42
+ elif args.dataloader_type == 'cyclic':
43
+ batch_sampler = MegatronPretrainingRandomSampler(
44
+ dataset,
45
+ total_samples=len(dataset),
46
+ consumed_samples=consumed_samples,
47
+ micro_batch_size=args.micro_batch_size,
48
+ data_parallel_rank=mpu.get_data_parallel_rank(),
49
+ data_parallel_size=mpu.get_data_parallel_world_size(),
50
+ data_sharding=args.data_sharding)
51
+ else:
52
+ raise Exception('{} dataloader type is not supported.'.format(
53
+ args.dataloader_type))
54
+
55
+ # Torch dataloader.
56
+ return torch.utils.data.DataLoader(dataset,
57
+ batch_sampler=batch_sampler,
58
+ num_workers=args.num_workers,
59
+ pin_memory=True)
60
+
61
+ class MegatronPretrainingSampler:
62
+
63
+ def __init__(self, total_samples, consumed_samples, micro_batch_size,
64
+ data_parallel_rank, data_parallel_size, drop_last=True):
65
+ # Keep a copy of input params for later use.
66
+ self.total_samples = total_samples
67
+ self.consumed_samples = consumed_samples
68
+ self.micro_batch_size = micro_batch_size
69
+ self.data_parallel_rank = data_parallel_rank
70
+ self.micro_batch_times_data_parallel_size = \
71
+ self.micro_batch_size * data_parallel_size
72
+ self.drop_last = drop_last
73
+
74
+ # Sanity checks.
75
+ assert self.total_samples > 0, \
76
+ 'no sample to consume: {}'.format(self.total_samples)
77
+ assert self.consumed_samples < self.total_samples, \
78
+ 'no samples left to consume: {}, {}'.format(self.consumed_samples,
79
+ self.total_samples)
80
+ assert self.micro_batch_size > 0
81
+ assert data_parallel_size > 0
82
+ assert self.data_parallel_rank < data_parallel_size, \
83
+ 'data_parallel_rank should be smaller than data size: {}, ' \
84
+ '{}'.format(self.data_parallel_rank, data_parallel_size)
85
+
86
+ def __len__(self):
87
+ return self.total_samples
88
+
89
+ def get_start_end_idx(self):
90
+ start_idx = self.data_parallel_rank * self.micro_batch_size
91
+ end_idx = start_idx + self.micro_batch_size
92
+ return start_idx, end_idx
93
+
94
+ def __iter__(self):
95
+ batch = []
96
+ # Last batch will be dropped if drop_last is not set to False
97
+ for idx in range(self.consumed_samples, self.total_samples):
98
+ batch.append(idx)
99
+ if len(batch) == self.micro_batch_times_data_parallel_size:
100
+ start_idx, end_idx = self.get_start_end_idx()
101
+ yield batch[start_idx:end_idx]
102
+ batch = []
103
+
104
+ # Yield the last partial batch if drop_last is not set
105
+ if len(batch) > 0 and not self.drop_last:
106
+ start_idx, end_idx = self.get_start_end_idx()
107
+ yield batch[start_idx:end_idx]
108
+
109
+
110
+ class RandomSeedDataset(Dataset):
111
+
112
+ def __init__(self, dataset):
113
+ args = get_args()
114
+ self.base_seed = args.seed
115
+ self.curr_seed = args.seed
116
+ self.dataset = dataset
117
+
118
+ def __len__(self):
119
+ return len(self.dataset)
120
+
121
+ def set_epoch(self, epoch):
122
+ self.curr_seed = self.base_seed + epoch
123
+
124
+ def __getitem__(self, idx):
125
+ seed = idx + self.curr_seed
126
+ torch.manual_seed(seed)
127
+ random.seed(seed)
128
+ np.random.seed(seed)
129
+ return self.dataset[idx]
130
+
131
+
132
+ class MegatronPretrainingRandomSampler:
133
+
134
+ def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
135
+ data_parallel_rank, data_parallel_size, data_sharding):
136
+ # Keep a copy of input params for later use.
137
+ self.dataset = dataset
138
+ self.total_samples = total_samples
139
+ self.consumed_samples = consumed_samples
140
+ self.micro_batch_size = micro_batch_size
141
+ self.data_parallel_rank = data_parallel_rank
142
+ self.data_parallel_size = data_parallel_size
143
+ self.data_sharding = data_sharding
144
+ self.micro_batch_times_data_parallel_size = \
145
+ self.micro_batch_size * data_parallel_size
146
+ self.last_batch_size = \
147
+ self.total_samples % self.micro_batch_times_data_parallel_size
148
+
149
+ # Sanity checks.
150
+ assert self.total_samples > 0, \
151
+ 'no sample to consume: {}'.format(self.total_samples)
152
+ assert self.micro_batch_size > 0
153
+ assert data_parallel_size > 0
154
+ assert self.data_parallel_rank < data_parallel_size, \
155
+ 'data_parallel_rank should be smaller than data size: {}, ' \
156
+ '{}'.format(self.data_parallel_rank, data_parallel_size)
157
+
158
+ def __len__(self):
159
+ return self.total_samples
160
+
161
+ def __iter__(self):
162
+ active_total_samples = self.total_samples - self.last_batch_size
163
+ self.epoch = self.consumed_samples // active_total_samples
164
+ current_epoch_samples = self.consumed_samples % active_total_samples
165
+ assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
166
+
167
+ if isinstance(self.dataset, RandomSeedDataset):
168
+ self.dataset.set_epoch(self.epoch)
169
+
170
+ # data sharding and random sampling
171
+ if self.data_sharding:
172
+ bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
173
+ * self.micro_batch_size
174
+ bucket_offset = current_epoch_samples // self.data_parallel_size
175
+ start_idx = self.data_parallel_rank * bucket_size
176
+
177
+ g = torch.Generator()
178
+ g.manual_seed(self.epoch)
179
+ random_idx = torch.randperm(bucket_size, generator=g).tolist()
180
+ idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
181
+ else:
182
+ full_bucket_size = (self.total_samples // self.micro_batch_size) \
183
+ * self.micro_batch_size
184
+ full_bucket_offset = current_epoch_samples
185
+ g = torch.Generator()
186
+ g.manual_seed(self.epoch)
187
+ idx_range_total = \
188
+ torch.randperm(full_bucket_size, generator=g).tolist()
189
+ idx_range_active = idx_range_total[full_bucket_offset:]
190
+ idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
191
+
192
+ batch = []
193
+ # Last batch if not complete will be dropped.
194
+ for idx in idx_range:
195
+ batch.append(idx)
196
+ if len(batch) == self.micro_batch_size:
197
+ self.consumed_samples += self.micro_batch_times_data_parallel_size
198
+ yield batch
199
+ batch = []
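Both samplers assemble a global batch of micro_batch_size * data_parallel_size indices and each data-parallel rank yields only its own micro-batch slice of it. A small illustration of MegatronPretrainingSampler's slicing, assuming the megatron package imports cleanly in your environment:

    from megatron.data.data_samplers import MegatronPretrainingSampler

    # 8-sample dataset, micro_batch_size=2, two data-parallel ranks.
    for rank in range(2):
        sampler = MegatronPretrainingSampler(total_samples=8, consumed_samples=0,
                                             micro_batch_size=2,
                                             data_parallel_rank=rank,
                                             data_parallel_size=2)
        print(rank, list(sampler))
    # rank 0 -> [[0, 1], [4, 5]]
    # rank 1 -> [[2, 3], [6, 7]]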
megatron/data/dataset_utils.py ADDED
@@ -0,0 +1,938 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ # Most of the code here has been copied from:
18
+ # https://github.com/google-research/albert/blob/master/create_pretraining_data.py
19
+ # with some modifications.
20
+
21
+ import math
22
+ import os
23
+ import time
24
+ import collections
25
+
26
+ import numpy as np
27
+ import torch
28
+ import random
29
+
30
+ from megatron import (
31
+ get_tokenizer,
32
+ get_args,
33
+ mpu,
34
+ print_rank_0
35
+ )
36
+ from megatron.data.blendable_dataset import BlendableDataset
37
+ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
38
+
39
+ DSET_TYPE_BERT = 'standard_bert'
40
+ DSET_TYPE_ICT = 'ict'
41
+ DSET_TYPE_T5 = 't5'
42
+ DSET_TYPE_GLM = 'glm'
43
+
44
+ DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_GLM]
45
+
46
+
47
+ def get_datasets_weights_and_num_samples(data_prefix,
48
+ train_valid_test_num_samples):
49
+
50
+ # The data prefix should be in the format of:
51
+ # weight-1, data-prefix-1, weight-2, data-prefix-2, ..
52
+ assert len(data_prefix) % 2 == 0
53
+ num_datasets = len(data_prefix) // 2
54
+ weights = [0]*num_datasets
55
+ prefixes = [0]*num_datasets
56
+ for i in range(num_datasets):
57
+ weights[i] = float(data_prefix[2*i])
58
+ prefixes[i] = (data_prefix[2*i+1]).strip()
59
+ # Normalize weights
60
+ weight_sum = 0.0
61
+ for weight in weights:
62
+ weight_sum += weight
63
+ assert weight_sum > 0.0
64
+ weights = [weight / weight_sum for weight in weights]
65
+
66
+ # Add 0.5% (the 1.005 factor) so in case the blending dataset does
67
+ # not uniformly distribute the number of samples, we still have
68
+ # samples left to feed to the network.
69
+ datasets_train_valid_test_num_samples = []
70
+ for weight in weights:
71
+ datasets_train_valid_test_num_samples.append(
72
+ [int(math.ceil(val * weight * 1.005))
73
+ for val in train_valid_test_num_samples])
74
+
75
+
76
+ return prefixes, weights, datasets_train_valid_test_num_samples
77
+
78
+
79
+ def compile_helper():
80
+ """Compile helper function at runtime. Make sure this
81
+ is invoked on a single process."""
82
+ import os
83
+ import subprocess
84
+ path = os.path.abspath(os.path.dirname(__file__))
85
+ ret = subprocess.run(['make', '-C', path])
86
+ if ret.returncode != 0:
87
+ print("Making C++ dataset helpers module failed, exiting.")
88
+ import sys
89
+ sys.exit(1)
90
+
91
+
92
+ def get_a_and_b_segments(sample, np_rng):
93
+ """Divide sample into a and b segments."""
94
+
95
+ # Number of sentences in the sample.
96
+ n_sentences = len(sample)
97
+ # Make sure we always have two sentences.
98
+ assert n_sentences > 1, 'make sure each sample has at least two sentences.'
99
+
100
+ # First part:
101
+ # `a_end` is how many sentences go into the `A`.
102
+ a_end = 1
103
+ if n_sentences >= 3:
104
+ # Note that randint in numpy is exclusive.
105
+ a_end = np_rng.randint(1, n_sentences)
106
+ tokens_a = []
107
+ for j in range(a_end):
108
+ tokens_a.extend(sample[j])
109
+
110
+ # Second part:
111
+ tokens_b = []
112
+ for j in range(a_end, n_sentences):
113
+ tokens_b.extend(sample[j])
114
+
115
+ # Random next:
116
+ is_next_random = False
117
+ if np_rng.random() < 0.5:
118
+ is_next_random = True
119
+ tokens_a, tokens_b = tokens_b, tokens_a
120
+
121
+ return tokens_a, tokens_b, is_next_random
122
+
123
+
124
+ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
125
+ """Truncates a pair of sequences to a maximum sequence length."""
126
+ #print(len_a, len_b, max_num_tokens)
127
+ assert len_a > 0
128
+ if len_a + len_b <= max_num_tokens:
129
+ return False
130
+ while len_a + len_b > max_num_tokens:
131
+ if len_a > len_b:
132
+ len_a -= 1
133
+ tokens = tokens_a
134
+ else:
135
+ len_b -= 1
136
+ tokens = tokens_b
137
+ if np_rng.random() < 0.5:
138
+ del tokens[0]
139
+ else:
140
+ tokens.pop()
141
+ return True
142
+
143
+
144
+ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
145
+ """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
146
+
147
+ tokens = []
148
+ tokentypes = []
149
+ # [CLS].
150
+ tokens.append(cls_id)
151
+ tokentypes.append(0)
152
+ # Segment A.
153
+ for token in tokens_a:
154
+ tokens.append(token)
155
+ tokentypes.append(0)
156
+ # [SEP].
157
+ tokens.append(sep_id)
158
+ tokentypes.append(0)
159
+ # Segment B.
160
+ for token in tokens_b:
161
+ tokens.append(token)
162
+ tokentypes.append(1)
163
+ if tokens_b:
164
+ # [SEP].
165
+ tokens.append(sep_id)
166
+ tokentypes.append(1)
167
+
168
+ return tokens, tokentypes
169
+
170
+ def create_tokens(tokens_a, tokens_b, cls_id, sep_id):
171
+ """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
172
+
173
+ tokens = []
174
+ # [CLS].
175
+ tokens.append(cls_id)
176
+ # Segment A.
177
+ for token in tokens_a:
178
+ tokens.append(token)
179
+ # [SEP].
180
+ tokens.append(sep_id)
181
+ # Segment B.
182
+ for token in tokens_b:
183
+ tokens.append(token)
184
+ if tokens_b:
185
+ # [SEP].
186
+ tokens.append(sep_id)
187
+
188
+ return tokens
189
+
190
+ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
191
+ ["index", "label"])
192
+
193
+
194
+ def is_start_piece(piece):
195
+ """Check if the current word piece is the starting piece (BERT)."""
196
+ # When a word has been split into
197
+ # WordPieces, the first token does not have any marker and any subsequent
198
+ # tokens are prefixed with ##. So whenever we see the ## token, we
199
+ # append it to the previous set of word indexes.
200
+ return not piece.startswith("##")
201
+
202
+
203
+ def create_masked_lm_predictions(tokens,
204
+ vocab_id_list, vocab_id_to_token_dict,
205
+ masked_lm_prob,
206
+ cls_id, sep_id, mask_id,
207
+ max_predictions_per_seq,
208
+ np_rng,
209
+ max_ngrams=3,
210
+ do_whole_word_mask=True,
211
+ favor_longer_ngram=False,
212
+ do_permutation=False,
213
+ geometric_dist=False,
214
+ masking_style="bert"):
215
+ """Creates the predictions for the masked LM objective.
216
+ Note: Tokens here are vocab ids and not text tokens."""
217
+
218
+ cand_indexes = []
219
+ # Note(mingdachen): We create a list for recording if the piece is
220
+ # the starting piece of current token, where 1 means true, so that
221
+ # on-the-fly whole word masking is possible.
222
+ token_boundary = [0] * len(tokens)
223
+
224
+ for (i, token) in enumerate(tokens):
225
+ if token == cls_id or token == sep_id:
226
+ token_boundary[i] = 1
227
+ continue
228
+ # Whole Word Masking means that we mask all of the wordpieces
229
+ # corresponding to an original word.
230
+ #
231
+ # Note that Whole Word Masking does *not* change the training code
232
+ # at all -- we still predict each WordPiece independently, softmaxed
233
+ # over the entire vocabulary.
234
+ if (do_whole_word_mask and len(cand_indexes) >= 1 and
235
+ not is_start_piece(vocab_id_to_token_dict[token])):
236
+ cand_indexes[-1].append(i)
237
+ else:
238
+ cand_indexes.append([i])
239
+ if is_start_piece(vocab_id_to_token_dict[token]):
240
+ token_boundary[i] = 1
241
+
242
+ output_tokens = list(tokens)
243
+
244
+ masked_lm_positions = []
245
+ masked_lm_labels = []
246
+
247
+ if masked_lm_prob == 0:
248
+ return (output_tokens, masked_lm_positions,
249
+ masked_lm_labels, token_boundary)
250
+
251
+ num_to_predict = min(max_predictions_per_seq,
252
+ max(1, int(round(len(tokens) * masked_lm_prob))))
253
+
254
+ ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
255
+ if not geometric_dist:
256
+ # Note(mingdachen):
257
+ # By default, we set the probabilities to favor shorter ngram sequences.
258
+ pvals = 1. / np.arange(1, max_ngrams + 1)
259
+ pvals /= pvals.sum(keepdims=True)
260
+ if favor_longer_ngram:
261
+ pvals = pvals[::-1]
262
+
263
+ ngram_indexes = []
264
+ for idx in range(len(cand_indexes)):
265
+ ngram_index = []
266
+ for n in ngrams:
267
+ ngram_index.append(cand_indexes[idx:idx + n])
268
+ ngram_indexes.append(ngram_index)
269
+
270
+ np_rng.shuffle(ngram_indexes)
271
+
272
+ (masked_lms, masked_spans) = ([], [])
273
+ covered_indexes = set()
274
+ for cand_index_set in ngram_indexes:
275
+ if len(masked_lms) >= num_to_predict:
276
+ break
277
+ if not cand_index_set:
278
+ continue
279
+ # Note(mingdachen):
280
+ # Skip the current piece if it is covered by LM masking or previous ngrams.
281
+ for index_set in cand_index_set[0]:
282
+ for index in index_set:
283
+ if index in covered_indexes:
284
+ continue
285
+
286
+ if not geometric_dist:
287
+ n = np_rng.choice(ngrams[:len(cand_index_set)],
288
+ p=pvals[:len(cand_index_set)] /
289
+ pvals[:len(cand_index_set)].sum(keepdims=True))
290
+ else:
291
+ # Sampling "n" from the geometric distribution and clipping it to
292
+ # the max_ngrams. Using p=0.2 default from the SpanBERT paper
293
+ # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
294
+ n = min(np_rng.geometric(0.2), max_ngrams)
295
+
296
+ index_set = sum(cand_index_set[n - 1], [])
297
+ n -= 1
298
+ # Note(mingdachen):
299
+ # Repeatedly looking for a candidate that does not exceed the
300
+ # maximum number of predictions by trying shorter ngrams.
301
+ while len(masked_lms) + len(index_set) > num_to_predict:
302
+ if n == 0:
303
+ break
304
+ index_set = sum(cand_index_set[n - 1], [])
305
+ n -= 1
306
+ # If adding a whole-word mask would exceed the maximum number of
307
+ # predictions, then just skip this candidate.
308
+ if len(masked_lms) + len(index_set) > num_to_predict:
309
+ continue
310
+ is_any_index_covered = False
311
+ for index in index_set:
312
+ if index in covered_indexes:
313
+ is_any_index_covered = True
314
+ break
315
+ if is_any_index_covered:
316
+ continue
317
+ for index in index_set:
318
+ covered_indexes.add(index)
319
+ masked_token = None
320
+ if masking_style == "bert":
321
+ # 80% of the time, replace with [MASK]
322
+ if np_rng.random() < 0.8:
323
+ masked_token = mask_id
324
+ else:
325
+ # 10% of the time, keep original
326
+ if np_rng.random() < 0.5:
327
+ masked_token = tokens[index]
328
+ # 10% of the time, replace with random word
329
+ else:
330
+ masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
331
+ elif masking_style == "t5":
332
+ masked_token = mask_id
333
+ else:
334
+ raise ValueError("invalid value of masking style")
335
+
336
+ output_tokens[index] = masked_token
337
+ masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
338
+
339
+ masked_spans.append(MaskedLmInstance(
340
+ index=index_set,
341
+ label=[tokens[index] for index in index_set]))
342
+
343
+ assert len(masked_lms) <= num_to_predict
344
+ np_rng.shuffle(ngram_indexes)
345
+
346
+ select_indexes = set()
347
+ if do_permutation:
348
+ for cand_index_set in ngram_indexes:
349
+ if len(select_indexes) >= num_to_predict:
350
+ break
351
+ if not cand_index_set:
352
+ continue
353
+ # Note(mingdachen):
354
+ # Skip the current piece if it is covered by LM masking or previous ngrams.
355
+ for index_set in cand_index_set[0]:
356
+ for index in index_set:
357
+ if index in covered_indexes or index in select_indexes:
358
+ continue
359
+
360
+ n = np.random.choice(ngrams[:len(cand_index_set)],
361
+ p=pvals[:len(cand_index_set)] /
362
+ pvals[:len(cand_index_set)].sum(keepdims=True))
363
+ index_set = sum(cand_index_set[n - 1], [])
364
+ n -= 1
365
+
366
+ while len(select_indexes) + len(index_set) > num_to_predict:
367
+ if n == 0:
368
+ break
369
+ index_set = sum(cand_index_set[n - 1], [])
370
+ n -= 1
371
+ # If adding a whole-word mask would exceed the maximum number of
372
+ # predictions, then just skip this candidate.
373
+ if len(select_indexes) + len(index_set) > num_to_predict:
374
+ continue
375
+ is_any_index_covered = False
376
+ for index in index_set:
377
+ if index in covered_indexes or index in select_indexes:
378
+ is_any_index_covered = True
379
+ break
380
+ if is_any_index_covered:
381
+ continue
382
+ for index in index_set:
383
+ select_indexes.add(index)
384
+ assert len(select_indexes) <= num_to_predict
385
+
386
+ select_indexes = sorted(select_indexes)
387
+ permute_indexes = list(select_indexes)
388
+ np_rng.shuffle(permute_indexes)
389
+ orig_token = list(output_tokens)
390
+
391
+ for src_i, tgt_i in zip(select_indexes, permute_indexes):
392
+ output_tokens[src_i] = orig_token[tgt_i]
393
+ masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
394
+
395
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
396
+ # Sort the spans by the index of the first span
397
+ masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
398
+
399
+ for p in masked_lms:
400
+ masked_lm_positions.append(p.index)
401
+ masked_lm_labels.append(p.label)
402
+ return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
403
+
404
+
405
+ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
406
+ masked_labels, pad_id, max_seq_length):
407
+ """Pad sequences and convert them to numpy."""
408
+
409
+ # Some checks.
410
+ num_tokens = len(tokens)
411
+ padding_length = max_seq_length - num_tokens
412
+ assert padding_length >= 0
413
+ assert len(tokentypes) == num_tokens
414
+ assert len(masked_positions) == len(masked_labels)
415
+
416
+ # Tokens and token types.
417
+ filler = [pad_id] * padding_length
418
+ tokens_np = np.array(tokens + filler, dtype=np.int64)
419
+ tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
420
+
421
+ # Padding mask.
422
+ padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
423
+ dtype=np.int64)
424
+
425
+ # Labels and loss mask.
426
+ labels = [-1] * max_seq_length
427
+ loss_mask = [0] * max_seq_length
428
+ for i in range(len(masked_positions)):
429
+ assert masked_positions[i] < num_tokens
430
+ labels[masked_positions[i]] = masked_labels[i]
431
+ loss_mask[masked_positions[i]] = 1
432
+ labels_np = np.array(labels, dtype=np.int64)
433
+ loss_mask_np = np.array(loss_mask, dtype=np.int64)
434
+
435
+ return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
436
+
437
+
438
+ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
439
+ train_valid_test_num_samples,
440
+ max_seq_length,
441
+ masked_lm_prob, short_seq_prob, seed,
442
+ skip_warmup, binary_head=False,
443
+ max_seq_length_dec=None,
444
+ dataset_type='standard_bert'):
445
+ if len(data_prefix) == 1:
446
+ return _build_train_valid_test_datasets(data_prefix[0],
447
+ data_impl, splits_string,
448
+ train_valid_test_num_samples,
449
+ max_seq_length, masked_lm_prob,
450
+ short_seq_prob, seed,
451
+ skip_warmup,
452
+ binary_head,
453
+ max_seq_length_dec,
454
+ dataset_type=dataset_type)
455
+ # Blending dataset.
456
+ # Parse the values.
457
+ output = get_datasets_weights_and_num_samples(data_prefix,
458
+ train_valid_test_num_samples)
459
+ prefixes, weights, datasets_train_valid_test_num_samples = output
460
+
461
+ # Build individual datasets.
462
+ train_datasets = []
463
+ valid_datasets = []
464
+ test_datasets = []
465
+ for i in range(len(prefixes)):
466
+ train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
467
+ prefixes[i], data_impl, splits_string,
468
+ datasets_train_valid_test_num_samples[i],
469
+ max_seq_length, masked_lm_prob, short_seq_prob,
470
+ seed, skip_warmup, binary_head, max_seq_length_dec, dataset_type=dataset_type)
471
+ if train_ds:
472
+ train_datasets.append(train_ds)
473
+ if valid_ds:
474
+ valid_datasets.append(valid_ds)
475
+ if test_ds:
476
+ test_datasets.append(test_ds)
477
+
478
+ # Blend.
479
+ blending_train_dataset = None
480
+ if train_datasets:
481
+ blending_train_dataset = BlendableDataset(train_datasets, weights)
482
+ blending_valid_dataset = None
483
+ if valid_datasets:
484
+ blending_valid_dataset = BlendableDataset(valid_datasets, weights)
485
+ blending_test_dataset = None
486
+ if test_datasets:
487
+ blending_test_dataset = BlendableDataset(test_datasets, weights)
488
+
489
+ return (blending_train_dataset, blending_valid_dataset,
490
+ blending_test_dataset)
491
+
492
+
493
+ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
494
+ train_valid_test_num_samples,
495
+ max_seq_length,
496
+ masked_lm_prob, short_seq_prob, seed,
497
+ skip_warmup, binary_head,
498
+ max_seq_length_dec,
499
+ dataset_type='standard_bert'):
500
+
501
+ if dataset_type not in DSET_TYPES:
502
+ raise ValueError("Invalid dataset_type: ", dataset_type)
503
+
504
+ # Indexed dataset.
505
+ indexed_dataset = get_indexed_dataset_(data_prefix,
506
+ data_impl,
507
+ skip_warmup)
508
+
509
+ if dataset_type == DSET_TYPE_ICT:
510
+ args = get_args()
511
+ title_dataset = get_indexed_dataset_(args.titles_data_path,
512
+ data_impl,
513
+ skip_warmup)
514
+
515
+ # Get start and end indices of train/valid/test into doc-idx
516
+ # Note that doc-idx is designed to be num-docs + 1 so we can
517
+ # easily iterate over it.
518
+ total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
519
+ splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
520
+
521
+ # Print stats about the splits.
522
+ print_rank_0(' > dataset split:')
523
+
524
+ def print_split_stats(name, index):
525
+ print_rank_0(' {}:'.format(name))
526
+ print_rank_0(' document indices in [{}, {}) total of {} '
527
+ 'documents'.format(splits[index], splits[index + 1],
528
+ splits[index + 1] - splits[index]))
529
+ start_index = indexed_dataset.doc_idx[splits[index]]
530
+ end_index = indexed_dataset.doc_idx[splits[index + 1]]
531
+ print_rank_0(' sentence indices in [{}, {}) total of {} '
532
+ 'sentences'.format(start_index, end_index,
533
+ end_index - start_index))
534
+ print_split_stats('train', 0)
535
+ print_split_stats('validation', 1)
536
+ print_split_stats('test', 2)
537
+
538
+ def build_dataset(index, name):
539
+ from megatron.data.bert_dataset import BertDataset
540
+ from megatron.data.ict_dataset import ICTDataset
541
+ from megatron.data.t5_dataset import T5Dataset
542
+ from megatron.data.glm_dataset import GlmDataset
543
+ dataset = None
544
+ if splits[index + 1] > splits[index]:
545
+ # Get the pointer to the original doc-idx so we can set it later.
546
+ doc_idx_ptr = indexed_dataset.get_doc_idx()
547
+ # Slice the doc-idx
548
+ start_index = splits[index]
549
+ # Add +1 so we can index into the dataset to get the upper bound.
550
+ end_index = splits[index + 1] + 1
551
+ # New doc_idx view.
552
+ indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
553
+ # Build the dataset accordingly.
554
+ kwargs = dict(
555
+ name=name,
556
+ data_prefix=data_prefix,
557
+ num_epochs=None,
558
+ max_num_samples=train_valid_test_num_samples[index],
559
+ max_seq_length=max_seq_length,
560
+ seed=seed,
561
+ )
562
+
563
+ if dataset_type == DSET_TYPE_ICT:
564
+ args = get_args()
565
+ dataset = ICTDataset(
566
+ block_dataset=indexed_dataset,
567
+ title_dataset=title_dataset,
568
+ query_in_block_prob=args.query_in_block_prob,
569
+ use_one_sent_docs=args.use_one_sent_docs,
570
+ binary_head=binary_head,
571
+ **kwargs
572
+ )
573
+ elif dataset_type == DSET_TYPE_T5:
574
+ dataset = T5Dataset(
575
+ indexed_dataset=indexed_dataset,
576
+ masked_lm_prob=masked_lm_prob,
577
+ max_seq_length_dec=max_seq_length_dec,
578
+ short_seq_prob=short_seq_prob,
579
+ **kwargs
580
+ )
581
+ elif dataset_type == DSET_TYPE_BERT:
582
+ dataset = BertDataset(
583
+ indexed_dataset=indexed_dataset,
584
+ masked_lm_prob=masked_lm_prob,
585
+ short_seq_prob=short_seq_prob,
586
+ binary_head=binary_head,
587
+ **kwargs
588
+ )
589
+ elif dataset_type == DSET_TYPE_GLM:
590
+ dataset = GlmDataset(
591
+ indexed_dataset=indexed_dataset,
592
+ masked_lm_prob=masked_lm_prob,
593
+ short_seq_prob=short_seq_prob,
594
+ binary_head=binary_head,
595
+ **kwargs
596
+ )
597
+ else:
598
+ raise NotImplementedError("Dataset type not fully implemented.")
599
+
600
+ # Set the original pointer so dataset remains the main dataset.
601
+ indexed_dataset.set_doc_idx(doc_idx_ptr)
602
+ # Checks.
603
+ assert indexed_dataset.doc_idx[0] == 0
604
+ assert indexed_dataset.doc_idx.shape[0] == \
605
+ (total_num_of_documents + 1)
606
+ return dataset
607
+
608
+ train_dataset = build_dataset(0, 'train')
609
+ valid_dataset = build_dataset(1, 'valid')
610
+ test_dataset = build_dataset(2, 'test')
611
+
612
+ return (train_dataset, valid_dataset, test_dataset)
613
+
614
+
615
+ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
616
+
617
+ print_rank_0(' > building dataset index ...')
618
+
619
+ start_time = time.time()
620
+ indexed_dataset = make_indexed_dataset(data_prefix,
621
+ data_impl,
622
+ skip_warmup)
623
+ assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
624
+ print_rank_0(' > finished creating indexed dataset in {:4f} '
625
+ 'seconds'.format(time.time() - start_time))
626
+
627
+ print_rank_0(' > indexed dataset stats:')
628
+ print_rank_0(' number of documents: {}'.format(
629
+ indexed_dataset.doc_idx.shape[0] - 1))
630
+ print_rank_0(' number of sentences: {}'.format(
631
+ indexed_dataset.sizes.shape[0]))
632
+
633
+ return indexed_dataset
634
+
635
+
636
+ def get_train_valid_test_split_(splits_string, size):
637
+ """ Get dataset splits from comma or '/' separated string list."""
638
+
639
+ splits = []
640
+ if splits_string.find(',') != -1:
641
+ splits = [float(s) for s in splits_string.split(',')]
642
+ elif splits_string.find('/') != -1:
643
+ splits = [float(s) for s in splits_string.split('/')]
644
+ else:
645
+ splits = [float(splits_string)]
646
+ while len(splits) < 3:
647
+ splits.append(0.)
648
+ splits = splits[:3]
649
+ splits_sum = sum(splits)
650
+ assert splits_sum > 0.0
651
+ splits = [split / splits_sum for split in splits]
652
+ splits_index = [0]
653
+ for index, split in enumerate(splits):
654
+ splits_index.append(splits_index[index] +
655
+ int(round(split * float(size))))
656
+ diff = splits_index[-1] - size
657
+ for index in range(1, len(splits_index)):
658
+ splits_index[index] -= diff
659
+ assert len(splits_index) == 4
660
+ assert splits_index[-1] == size
661
+ return splits_index
662
+
663
+ def get_samples_mapping(indexed_dataset,
664
+ data_prefix,
665
+ num_epochs,
666
+ max_num_samples,
667
+ max_seq_length,
668
+ short_seq_prob,
669
+ seed,
670
+ name,
671
+ binary_head):
672
+ """Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
673
+
674
+ if not num_epochs:
675
+ if not max_num_samples:
676
+ raise ValueError("Need to specify either max_num_samples "
677
+ "or num_epochs")
678
+ num_epochs = np.iinfo(np.int32).max - 1
679
+ if not max_num_samples:
680
+ max_num_samples = np.iinfo(np.int64).max - 1
681
+
682
+ # Filename of the index mapping
683
+ indexmap_filename = data_prefix
684
+ indexmap_filename += '_{}_indexmap'.format(name)
685
+ if num_epochs != (np.iinfo(np.int32).max - 1):
686
+ indexmap_filename += '_{}ep'.format(num_epochs)
687
+ if max_num_samples != (np.iinfo(np.int64).max - 1):
688
+ indexmap_filename += '_{}mns'.format(max_num_samples)
689
+ indexmap_filename += '_{}msl'.format(max_seq_length)
690
+ indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
691
+ indexmap_filename += '_{}s'.format(seed)
692
+ indexmap_filename += '.npy'
693
+
694
+ # Build the indexed mapping if not exist.
695
+ if torch.distributed.get_rank() == 0 and \
696
+ not os.path.isfile(indexmap_filename):
697
+ print(' > WARNING: could not find index map file {}, building '
698
+ 'the indices on rank 0 ...'.format(indexmap_filename))
699
+
700
+ # Make sure the types match the helpers input types.
701
+ assert indexed_dataset.doc_idx.dtype == np.int64
702
+ assert indexed_dataset.sizes.dtype == np.int32
703
+
704
+ # Build samples mapping
705
+ verbose = torch.distributed.get_rank() == 0
706
+ start_time = time.time()
707
+ print_rank_0(' > building samples index mapping for {} ...'.format(
708
+ name))
709
+ # First compile and then import.
710
+ from megatron.data import helpers
711
+ samples_mapping = helpers.build_mapping(
712
+ indexed_dataset.doc_idx,
713
+ indexed_dataset.sizes,
714
+ num_epochs,
715
+ max_num_samples,
716
+ max_seq_length,
717
+ short_seq_prob,
718
+ seed,
719
+ verbose,
720
+ 2 if binary_head else 1)
721
+ print_rank_0(' > done building samples index mapping')
722
+ np.save(indexmap_filename, samples_mapping, allow_pickle=True)
723
+ print_rank_0(' > saved the index mapping in {}'.format(
724
+ indexmap_filename))
725
+ # Make sure all the ranks have built the mapping
726
+ print_rank_0(' > elapsed time to build and save samples mapping '
727
+ '(seconds): {:4f}'.format(
728
+ time.time() - start_time))
729
+ # This should be a barrier but nccl barrier assumes
730
+ # device_index=rank which is not the case for model
731
+ # parallel case
732
+ counts = torch.cuda.LongTensor([1])
733
+ torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
734
+ torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
735
+ assert counts[0].item() == (
736
+ torch.distributed.get_world_size() //
737
+ torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
738
+
739
+ # Load indexed dataset.
740
+ print_rank_0(' > loading indexed mapping from {}'.format(
741
+ indexmap_filename))
742
+ start_time = time.time()
743
+ samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
744
+ print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
745
+ time.time() - start_time))
746
+ print_rank_0(' total number of samples: {}'.format(
747
+ samples_mapping.shape[0]))
748
+
749
+ return samples_mapping
750
+
751
+
752
+ class MaskEncoder(object):
753
+ def __init__(self):
754
+ tokenizer = get_tokenizer()
755
+ self.vocab_size = tokenizer.vocab_size
756
+ self.vocab_id_list = list(tokenizer.inv_vocab.keys())
757
+ self.vocab_id_to_token_dict = tokenizer.inv_vocab
758
+ self.cls_id = tokenizer.cls
759
+ self.sep_id = tokenizer.sep
760
+ self.mask_id = tokenizer.mask
761
+ self.pad_id = tokenizer.pad
762
+
763
+ import jieba_fast
764
+ self.zh_tokenizer = jieba_fast.lcut
765
+ self.random_ratio = 0
766
+
767
+
768
+ def word_starts(self, source):
769
+ raw_tokens = [self.vocab_id_to_token_dict[i] for i in source.tolist()]
770
+ words = [raw_tokens[0]] + self.zh_tokenizer(''.join(raw_tokens[1:-1]), HMM=True) + [raw_tokens[-1]]
771
+
772
+ def _is_chinese_char(c):
773
+ """Checks whether CP is the codepoint of a CJK character."""
774
+ # This defines a "chinese character" as anything in the CJK Unicode block:
775
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
776
+ #
777
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
778
+ # despite its name. The modern Korean Hangul alphabet is a different block,
779
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
780
+ # space-separated words, so they are not treated specially and handled
781
+ # like all of the other languages.
782
+ if len(c) > 1:
783
+ return all([_is_chinese_char(c_i) for c_i in c])
784
+ cp = ord(c)
785
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
786
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
787
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
788
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
789
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
790
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
791
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
792
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
793
+ return True
794
+
795
+ return False
796
+
797
+ def align_linear(atokens, btokens):
798
+ a2c = []
799
+ c2b = []
800
+ a2b = []
801
+ length = 0
802
+ for tok in atokens:
803
+ a2c.append([length + i for i in range(len(tok))])
804
+ length += len(tok)
805
+ for i, tok in enumerate(btokens):
806
+ c2b.extend([i for _ in range(len(tok))])
807
+
808
+ for i, amap in enumerate(a2c):
809
+ bmap = [c2b[ci] for ci in amap]
810
+ a2b.append(list(set(bmap)))
811
+ return a2b
812
+
813
+ raw_to_word_align = align_linear(raw_tokens, words)
814
+ is_word_start = torch.zeros(source.size())
815
+ word_starts = []
816
+ skip_cur_word = True
817
+ for i in range(1, len(raw_to_word_align)):
818
+ if raw_to_word_align[i-1] == raw_to_word_align[i]:
819
+ # not a word start, as they align to the same word
820
+ if not skip_cur_word and not _is_chinese_char(raw_tokens[i]):
821
+ word_starts.pop(-1)
822
+ skip_cur_word = True
823
+ continue
824
+ else:
825
+ is_word_start[i] = 1
826
+ if _is_chinese_char(raw_tokens[i]):
827
+ word_starts.append(i)
828
+ skip_cur_word = False
829
+ is_word_start[0] = 0
830
+ is_word_start[-1] = 0
831
+ word_starts = torch.tensor(word_starts).long().view(-1, 1)
832
+ return is_word_start, word_starts
833
+
834
+ def add_whole_word_mask(self, source, p, replace_length=1):
835
+ is_word_start, word_starts = self.word_starts(source)
836
+ num_to_mask_word = int(math.ceil(word_starts.size(0) * p))
837
+ num_to_mask_char = int(math.ceil(word_starts.size(0) * p * 0.1))
838
+ num_to_mask = num_to_mask_word + num_to_mask_char
839
+ if num_to_mask > word_starts.size(0):
840
+ word_starts = is_word_start.nonzero(as_tuple=False)
841
+ num_inserts = 0
842
+ if num_to_mask == 0:
843
+ return source
844
+
845
+ lengths = torch.ones((num_to_mask,)).long()
846
+ assert is_word_start[-1] == 0
847
+ indices = word_starts[
848
+ torch.randperm(word_starts.size(0))[:num_to_mask]
849
+ ].squeeze(1)
850
+ if len(indices) < num_to_mask:
851
+ num_to_mask = len(indices)
852
+
853
+ mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
854
+ source_length = source.size(0)
855
+ assert source_length - 1 not in indices
856
+ to_keep = torch.ones(source_length, dtype=torch.bool)
857
+ is_word_start[
858
+ -1
859
+ ] = 255 # acts as a long length, so spans don't go over the end of doc
860
+ if replace_length == 0:
861
+ to_keep[indices] = 0
862
+ else:
863
+ # keep index, but replace it with [MASK]
864
+ # print(source.size(), word_starts.size(), indices.size(), mask_random.size())
865
+ # try:
866
+ source[indices] = self.mask_id
867
+ source[indices[mask_random]] = torch.randint(
868
+ 1, self.vocab_size, size=(mask_random.sum(),)
869
+ )
870
+ # except:
871
+ # print(source)
872
+ # print(indices)
873
+ # print(mask_random)
874
+ # print()
875
+ # sorted_indices = torch.sort(indices)[0]
876
+ # continue_mask_pos = ((sorted_indices + 1)[:-1] == sorted_indices[1:])
877
+ # continue_mask_indices = sorted_indices[1:][continue_mask_pos]
878
+ # to_keep[continue_mask_indices] = 0
879
+
880
+ # for char indices, we already masked, the following loop handles word mask
881
+ indices = indices[:num_to_mask_word]
882
+ mask_random = mask_random[:num_to_mask_word]
883
+ while indices.size(0) > 0:
884
+ uncompleted = is_word_start[indices + 1] == 0
885
+ indices = indices[uncompleted] + 1
886
+ mask_random = mask_random[uncompleted]
887
+ if replace_length != -1:
888
+ # delete token
889
+ to_keep[indices] = 0
890
+ else:
891
+ # keep index, but replace it with [MASK]
892
+ source[indices] = self.mask_id
893
+ source[indices[mask_random]] = torch.randint(
894
+ 1, self.vocab_size, size=(mask_random.sum(),)
895
+ )
896
+
897
+ assert source_length - 1 not in indices
898
+ source = source[to_keep]
899
+
900
+ return source
901
+
902
+ def shif_chinese_word(self, tokens, tokens_bf_mask):
903
+ assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function'
904
+ buff_list = []
905
+ buff_list_index = []
906
+ for i in range(len(tokens)):
907
+ if tokens[i] == tokens_bf_mask[i]:
908
+ if len(buff_list) == 0:
909
+ continue
910
+ else:
911
+ if len(buff_list) != 1:
912
+ random.shuffle(buff_list)
913
+ tokens[buff_list_index[0] : buff_list_index[-1]+1] = buff_list
914
+ buff_list = []
915
+ buff_list_index = []
916
+ else:
917
+ buff_list.append(tokens_bf_mask[i])
918
+ buff_list_index.append(i)
919
+
920
+ return tokens
921
+
922
+ def mass_style_mask(self, tokens):
923
+ tokens = tokens[:]
924
+ p = random.uniform(0.3, 0.5)
925
+ num_to_mask = int(len(tokens) * p)
926
+ start_index = int((1 - p) / 2 * len(tokens))
927
+ tokens[start_index : start_index + num_to_mask] = [self.mask_id] * num_to_mask
928
+
929
+ return tokens
930
+
931
+ def delete_chinese_word(self, tokens, tokens_bf_mask):
932
+ return_tokens = []
933
+ assert len(tokens) == len(tokens_bf_mask), 'length must be equal in this function'
934
+ for i in range(len(tokens)):
935
+ if tokens[i] == tokens_bf_mask[i]:
936
+ return_tokens.append(tokens[i])
937
+
938
+ return return_tokens
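get_train_valid_test_split_ turns a comma- or slash-separated weight string into cumulative document boundaries, with the last boundary pinned to the total document count. A quick check, assuming megatron and its dependencies import in your environment:

    from megatron.data.dataset_utils import get_train_valid_test_split_

    # 1000 documents split with the common "949,50,1" weights (weights are
    # normalized, so any scale works; missing entries default to 0).
    print(get_train_valid_test_split_('949,50,1', 1000))  # [0, 949, 999, 1000]
    print(get_train_valid_test_split_('80/10/10', 50))    # [0, 40, 45, 50]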
megatron/data/glm_dataset.py ADDED
@@ -0,0 +1,377 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """GLM style dataset."""
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from megatron import (
22
+ get_args,
23
+ get_tokenizer,
24
+ mpu,
25
+ print_rank_0
26
+ )
27
+ from megatron.data.dataset_utils import (
28
+ get_samples_mapping,
29
+ get_a_and_b_segments,
30
+ truncate_segments,
31
+ create_tokens_and_tokentypes,
32
+ create_tokens,
33
+ create_masked_lm_predictions,
34
+ MaskEncoder
35
+ )
36
+
37
+ class DummyBertDataset(torch.utils.data.Dataset):
38
+ def __init__(self, name, num_samples, max_seq_length):
39
+ self.name = name
40
+ self.num_samples = num_samples
41
+ self.max_seq_length = max_seq_length
42
+ self.np_rng = np.random.RandomState(seed=0)
43
+ # self.token_nps = np_rng.randint(1000, 2000, (self.num_samples, 512))
44
+ # Vocab stuff.
45
+ tokenizer = get_tokenizer()
46
+ self.vocab_id_list = list(tokenizer.inv_vocab.keys())
47
+ self.vocab_id_to_token_dict = tokenizer.inv_vocab
48
+ self.cls_id = tokenizer.cls
49
+ self.sep_id = tokenizer.sep
50
+ self.mask_id = tokenizer.mask
51
+ self.pad_id = tokenizer.pad
52
+
53
+ def __len__(self):
54
+ return self.num_samples
55
+
56
+ def __getitem__(self, idx):
57
+ tokens = self.np_rng.randint(1000, 2000, self.max_seq_length)
58
+ masked_position = np.arange(int(tokens.shape[0] * 0.15))
59
+ tokens = tokens.astype(np.int64)
60
+ labels = tokens[masked_position]
61
+ label_np = np.full_like(tokens, -1)
62
+ label_np[masked_position] = labels
63
+ tokens[masked_position] = self.mask_id
64
+ loss_mask_np = np.zeros_like(tokens)
65
+ loss_mask_np[masked_position] = 1
66
+ train_sample = {
67
+ 'text': tokens,
68
+ 'types': np.zeros_like(tokens),
69
+ 'labels': label_np,
70
+ 'is_random': 0,
71
+ 'loss_mask': loss_mask_np,
72
+ 'padding_mask': np.ones_like(tokens),
73
+ 'truncated': 0
74
+ }
75
+ return train_sample
76
+
77
+ class GlmDataset(torch.utils.data.Dataset):
78
+
79
+ def __init__(self, name, indexed_dataset, data_prefix,
80
+ num_epochs, max_num_samples, masked_lm_prob,
81
+ max_seq_length, short_seq_prob, seed, binary_head):
82
+
83
+ # Params to store.
84
+ self.name = name
85
+ self.seed = seed
86
+ self.masked_lm_prob = masked_lm_prob
87
+ self.max_seq_length = max_seq_length
88
+ self.binary_head = binary_head
89
+
90
+ # Dataset.
91
+ self.indexed_dataset = indexed_dataset
92
+
93
+ # Build the samples mapping.
94
+ self.samples_mapping = get_samples_mapping(self.indexed_dataset,
95
+ data_prefix,
96
+ num_epochs,
97
+ max_num_samples,
98
+ self.max_seq_length - 3, # account for added tokens
99
+ short_seq_prob,
100
+ self.seed,
101
+ self.name,
102
+ self.binary_head)
103
+
104
+ # Vocab stuff.
105
+ tokenizer = get_tokenizer()
106
+ self.vocab_id_list = list(tokenizer.inv_vocab.keys())
107
+ self.vocab_id_to_token_dict = tokenizer.inv_vocab
108
+ self.cls_id = tokenizer.cls
109
+ self.sep_id = tokenizer.sep
110
+ self.mask_id = tokenizer.mask
111
+ self.pad_id = tokenizer.pad
112
+
113
+ def __len__(self):
114
+ return self.samples_mapping.shape[0]
115
+
116
+ def __getitem__(self, idx):
117
+ start_idx, end_idx, seq_length = self.samples_mapping[idx]
118
+ sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
119
+ # Note that this rng state should be numpy and not python since
120
+ # python randint is inclusive whereas the numpy one is exclusive.
121
+ # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
122
+ np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
123
+ return build_training_sample(sample, seq_length,
124
+ self.max_seq_length, # needed for padding
125
+ self.vocab_id_list,
126
+ self.vocab_id_to_token_dict,
127
+ self.cls_id, self.sep_id,
128
+ self.mask_id, self.pad_id,
129
+ self.masked_lm_prob, np_rng,
130
+ self.binary_head)
131
+
132
+ def sent_level_task(binary_head, sample, target_seq_length, max_seq_length, np_rng):
133
+ if binary_head:
134
+ # We assume that we have at least two sentences in the sample
135
+ assert len(sample) > 1
136
+ assert target_seq_length <= max_seq_length
137
+ # Divide sample into two segments (A and B).
138
+ if binary_head:
139
+ tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
140
+ else:
141
+ tokens_a = []
142
+ for j in range(len(sample)):
143
+ tokens_a.extend(sample[j])
144
+ tokens_b = []
145
+ is_next_random = False
146
+ # Truncate to `target_sequence_length`.
147
+
148
+ max_num_tokens = target_seq_length
149
+ truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
150
+ len(tokens_b), max_num_tokens, np_rng)
151
+ return is_next_random, truncated, max_num_tokens, tokens_a, tokens_b
152
+
153
+ def generate_decoder_input_and_output(tokens, pad_id, sep_id):
154
+ """
155
+ decoder input [SEP] [CLS] A B C D
156
+ decoder output [CLS] A B C D E
157
+ """
158
+
159
+ decoder_output = tokens[:]
160
+ decoder_input = [0] * len(decoder_output)
161
+ decoder_input[0] = sep_id # match the preprocessing in fairseq
162
+ # decoder_input[0] = sep_id # match the preprocessing in fairseq
163
+ decoder_input[1:] = decoder_output[:-1]
164
+
165
+ """
166
+ decoder input [CLS] A B C D [SEP]
167
+ decoder output A B C D [SEP] [PAD]
168
+ """
169
+
170
+ # decoder_input = tokens[:]
171
+ # decoder_output = [0] * len(decoder_input)
172
+ # decoder_output[:-1] = decoder_input[1:]
173
+ # decoder_output[-1] = pad_id
174
+
175
+ return decoder_input, decoder_output
176
+
177
+
178
+
179
+ def build_training_sample(sample,
180
+ target_seq_length, max_seq_length,
181
+ vocab_id_list, vocab_id_to_token_dict,
182
+ cls_id, sep_id, mask_id, pad_id,
183
+ masked_lm_prob, np_rng, binary_head):
184
+
185
+ """
186
+ sent-level task
187
+ """
188
+ is_next_random, truncated, max_num_tokens, tokens_a, tokens_b = sent_level_task(
189
+ binary_head, sample, target_seq_length, max_seq_length, np_rng)
190
+ tokens_bf_mask = create_tokens(tokens_a, tokens_b, cls_id, sep_id)
191
+ if is_next_random:
192
+ raw_tokens = create_tokens(tokens_b, tokens_a, cls_id, sep_id)
193
+ else:
194
+ raw_tokens = tokens_bf_mask[:]
195
+
196
+ """
197
+ decoder-input and output
198
+ """
199
+ decoder_input, decoder_output = generate_decoder_input_and_output(raw_tokens, pad_id, sep_id)
200
+
201
+ # important part
202
+
203
+ encoder_loss_flag = 0
204
+ decoder_loss_flag = 0
205
+ sent_loss_flag = 1
206
+ encoder_rng = torch.rand(1).item()
207
+ me = MaskEncoder()
208
+ if encoder_rng < 1.1:
209
+ # only train with encoder and decoder
210
+ # Masking.
211
+ if 0:
212
+ max_predictions_per_seq = masked_lm_prob * max_num_tokens
213
+ (tokens, _, _, _, _) = create_masked_lm_predictions(
214
+ tokens_bf_mask, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
215
+ cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, masking_style="t5")
216
+ if 1 :
217
+ tokens = torch.LongTensor(tokens_bf_mask)
218
+ tokens = me.add_whole_word_mask(tokens, 0.15, -1)
219
+ tokens = tokens.tolist()
220
+ shift_rng = torch.rand(1).item()
221
+ if shift_rng < 0.0:
222
+ tokens = me.shif_chinese_word(tokens, tokens_bf_mask)
223
+ encoder_loss_flag = 1
224
+ decoder_loss_flag = 1
225
+ else:
226
+ # train only with decoder
227
+ tokens = torch.LongTensor(tokens_bf_mask)
228
+ decoder_rng = torch.rand(1).item()
229
+ if decoder_rng < 0.4:
230
+ # WWM mask 30% word
231
+ tokens = me.add_whole_word_mask(tokens, 0.3, -1)
232
+ tokens = tokens.tolist()
233
+ if decoder_rng >= 0.4 and decoder_rng < 0.6:
234
+ # MASS mask style
235
+ tokens = me.mass_style_mask(tokens_bf_mask)
236
+ if decoder_rng > 0.6:
237
+ # delete tokens
238
+ tokens = me.add_whole_word_mask(tokens, 0.3, -1)
239
+ tokens = tokens.tolist()
240
+ tokens = me.delete_chinese_word(tokens, tokens_bf_mask)
241
+ tmp_tt = get_tokenizer()
242
+ # print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask)))
243
+ # print("encoder input ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens)))
244
+ # print("------\n\n")
245
+
246
+
247
+ decoder_loss_flag = 1
248
+
249
+ # tmp_tt = get_tokenizer()
250
+ # print("encoder ori input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens_bf_mask)))
251
+ # print("encoder input ", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(tokens)))
252
+ # print("decoder input", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_input)))
253
+ # print("decoder output", "".join(tmp_tt.tokenizer.convert_ids_to_tokens(decoder_output)))
254
+
255
+ tokentypes = []
256
+ encoder_labels = []
257
+ encoder_labels_mask = []
258
+ padding_mask = []
259
+ apppend_type_id = 0
260
+
261
+ if len(tokens) == len(tokens_bf_mask):
262
+ # encoder and decoder can train together
263
+ for index in range(len(tokens)):
264
+ padding_mask.append(1)
265
+ # generate tokens type
266
+ if tokens[index] == sep_id:
267
+ apppend_type_id = 1
268
+ tokentypes.append(apppend_type_id)
269
+
270
+ if tokens[index] == tokens_bf_mask[index]:
271
+ encoder_labels.append(-1)
272
+ encoder_labels_mask.append(0)
273
+ else:
274
+ encoder_labels.append(tokens_bf_mask[index])
275
+ encoder_labels_mask.append(1)
276
+ else:
277
+ # only train decoder
278
+ for index in range(len(tokens)):
279
+ padding_mask.append(1)
280
+ if tokens[index] == sep_id:
281
+ apppend_type_id = 1
282
+ tokentypes.append(apppend_type_id)
283
+ encoder_labels.append(-1)
284
+ encoder_labels_mask.append(0)
285
+
286
+ tokens_np = pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id)
287
+ tokentypes_np = pad_and_convert_to_numpy_light(tokentypes, max_seq_length, pad_id)
288
+ padding_mask_np = pad_and_convert_to_numpy_light(padding_mask, max_seq_length, pad_id)
289
+ encoder_labels_np = pad_and_convert_to_numpy_light(encoder_labels, max_seq_length, -1)
290
+ encoder_labels_mask_np = pad_and_convert_to_numpy_light(encoder_labels_mask, max_seq_length, pad_id)
291
+ decoder_input_np = pad_and_convert_to_numpy_light(decoder_input, max_seq_length, pad_id)
292
+ decoder_output_np = pad_and_convert_to_numpy_light(decoder_output, max_seq_length, pad_id)
293
+
294
+ # print(tokens_np)
295
+ # print(encoder_labels_np)
296
+ # print(padding_mask_np)
297
+ # print(encoder_labels_mask_np)
298
+
299
+ # generate tokentypes
300
+ train_sample = {
301
+ 'text': tokens_np, # encoder_input
302
+ 'types': tokentypes_np, # token_type
303
+ 'is_random': int(is_next_random), #sop_labels
304
+ 'truncated': int(truncated), # if truncated
305
+ 'labels': encoder_labels_np, #encoder_labels
306
+ 'loss_mask': encoder_labels_mask_np, # mlm_mask
307
+ 'padding_mask': padding_mask_np, # padding_mask
308
+ 'decoder_input': decoder_input_np, # decoder_input
309
+ 'decoder_output': decoder_output_np, #decoder_output
310
+ 'encoder_loss_flag': int(encoder_loss_flag),
311
+ 'decoder_loss_flag': int(decoder_loss_flag),
312
+ 'sent_loss_flag': int(sent_loss_flag),
313
+ }
314
+
315
+ # print(tokens_np.shape)
316
+ # print(tokens_np)
317
+
318
+ # print(tokentypes_np.shape)
319
+ # print(tokentypes_np)
320
+
321
+ # print(encoder_labels_np.shape)
322
+ # print(encoder_labels_np)
323
+
324
+ # print(encoder_labels_mask_np.shape)
325
+ # print(encoder_labels_mask_np)
326
+
327
+ # print(padding_mask_np.shape)
328
+ # print(padding_mask_np)
329
+
330
+ # print(decoder_input_np.shape)
331
+ # print(decoder_input_np)
332
+
333
+ # print(decoder_output_np.shape)
334
+ # print(decoder_output_np)
335
+
336
+ # print("=====\n\n\n")
337
+ # import sys;sys.exit(0)
338
+ return train_sample
339
+
340
+ def pad_and_convert_to_numpy_light(tokens, max_seq_length, pad_id):
341
+ padding_length = max_seq_length - len(tokens)
342
+ assert padding_length >= 0
343
+ filler = [pad_id] * padding_length
344
+ tokens_np = np.array(tokens + filler, dtype=np.int64)
345
+ return tokens_np
346
+
347
+ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
348
+ masked_labels, pad_id, max_seq_length):
349
+ """Pad sequences and convert them to numpy."""
350
+
351
+ # Some checks.
352
+ num_tokens = len(tokens)
353
+ padding_length = max_seq_length - num_tokens
354
+ assert padding_length >= 0
355
+ assert len(tokentypes) == num_tokens
356
+ assert len(masked_positions) == len(masked_labels)
357
+
358
+ # Tokens and token types.
359
+ filler = [pad_id] * padding_length
360
+ tokens_np = np.array(tokens + filler, dtype=np.int64)
361
+ tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
362
+
363
+ # Padding mask.
364
+ padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
365
+ dtype=np.int64)
366
+
367
+ # Labels and loss mask.
368
+ labels = [-1] * max_seq_length
369
+ loss_mask = [0] * max_seq_length
370
+ for i in range(len(masked_positions)):
371
+ assert masked_positions[i] < num_tokens
372
+ labels[masked_positions[i]] = masked_labels[i]
373
+ loss_mask[masked_positions[i]] = 1
374
+ labels_np = np.array(labels, dtype=np.int64)
375
+ loss_mask_np = np.array(loss_mask, dtype=np.int64)
376
+
377
+ return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
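For reference, a minimal sketch of how pad_and_convert_to_numpy behaves on toy inputs (the ids and lengths below are made up for illustration, and the import path is assumed from this upload):

    import numpy as np
    from megatron.data.glm_dataset import pad_and_convert_to_numpy  # path assumed

    tokens = [101, 7, 8, 102]        # 4 real tokens
    tokentypes = [0, 0, 1, 1]
    masked_positions = [2]           # position that was masked out
    masked_labels = [9]              # original token id at that position

    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = \
        pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                 masked_labels, pad_id=0, max_seq_length=6)
    # tokens_np       -> [101, 7, 8, 102, 0, 0]
    # padding_mask_np -> [1, 1, 1, 1, 0, 0]
    # labels_np       -> [-1, -1, 9, -1, -1, -1]  (loss only at the masked slot)
    # loss_mask_np    -> [0, 0, 1, 0, 0, 0]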
megatron/data/gpt_dataset.py ADDED
@@ -0,0 +1,430 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """GPT style dataset."""
17
+
18
+ import os
19
+ import time
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from megatron import mpu, print_rank_0
25
+ from megatron.data.blendable_dataset import BlendableDataset
26
+ from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
27
+ from megatron.data.dataset_utils import get_train_valid_test_split_
28
+ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
29
+
30
+
31
+ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
32
+ train_valid_test_num_samples,
33
+ seq_length, seed, skip_warmup):
34
+ """Build train, valid, and test datasets."""
35
+
36
+ # Single dataset.
37
+ if len(data_prefix) == 1:
38
+ return _build_train_valid_test_datasets(data_prefix[0],
39
+ data_impl, splits_string,
40
+ train_valid_test_num_samples,
41
+ seq_length, seed, skip_warmup)
42
+
43
+ # Blending dataset.
44
+ # Parse the values.
45
+ output = get_datasets_weights_and_num_samples(data_prefix,
46
+ train_valid_test_num_samples)
47
+ prefixes, weights, datasets_train_valid_test_num_samples = output
48
+
49
+ # Build individual datasets.
50
+ train_datasets = []
51
+ valid_datasets = []
52
+ test_datasets = []
53
+ for i in range(len(prefixes)):
54
+ train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
55
+ prefixes[i], data_impl, splits_string,
56
+ datasets_train_valid_test_num_samples[i],
57
+ seq_length, seed, skip_warmup)
58
+ if train_ds:
59
+ train_datasets.append(train_ds)
60
+ if valid_ds:
61
+ valid_datasets.append(valid_ds)
62
+ if test_ds:
63
+ test_datasets.append(test_ds)
64
+
65
+ # Blend.
66
+ blending_train_dataset = None
67
+ if train_datasets:
68
+ blending_train_dataset = BlendableDataset(train_datasets, weights)
69
+ blending_valid_dataset = None
70
+ if valid_datasets:
71
+ blending_valid_dataset = BlendableDataset(valid_datasets, weights)
72
+ blending_test_dataset = None
73
+ if test_datasets:
74
+ blending_test_dataset = BlendableDataset(test_datasets, weights)
75
+
76
+ return (blending_train_dataset, blending_valid_dataset,
77
+ blending_test_dataset)
78
+
79
+
80
+ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
81
+ train_valid_test_num_samples,
82
+ seq_length, seed, skip_warmup):
83
+ """Build train, valid, and test datasets."""
84
+
85
+ # Indexed dataset.
86
+ indexed_dataset = get_indexed_dataset_(data_prefix,
87
+ data_impl,
88
+ skip_warmup)
89
+
90
+ total_num_of_documents = indexed_dataset.sizes.shape[0]
91
+ splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
92
+
93
+ # Print stats about the splits.
94
+ print_rank_0(' > dataset split:')
95
+
96
+ def print_split_stats(name, index):
97
+ print_rank_0(' {}:'.format(name))
98
+ print_rank_0(' document indices in [{}, {}) total of {} '
99
+ 'documents'.format(splits[index], splits[index + 1],
100
+ splits[index + 1] - splits[index]))
101
+ print_split_stats('train', 0)
102
+ print_split_stats('validation', 1)
103
+ print_split_stats('test', 2)
104
+
105
+ def build_dataset(index, name):
106
+ dataset = None
107
+ if splits[index + 1] > splits[index]:
108
+ documents = np.arange(start=splits[index], stop=splits[index + 1],
109
+ step=1, dtype=np.int32)
110
+ dataset = GPTDataset(name, data_prefix,
111
+ documents, indexed_dataset,
112
+ train_valid_test_num_samples[index],
113
+ seq_length, seed)
114
+ return dataset
115
+
116
+ train_dataset = build_dataset(0, 'train')
117
+ valid_dataset = build_dataset(1, 'valid')
118
+ test_dataset = build_dataset(2, 'test')
119
+
120
+ return (train_dataset, valid_dataset, test_dataset)
121
+
122
+
123
+ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
124
+ """Build indexed dataset."""
125
+ print_rank_0(' > building dataset index ...')
126
+
127
+ start_time = time.time()
128
+ indexed_dataset = make_indexed_dataset(data_prefix,
129
+ data_impl,
130
+ skip_warmup)
131
+ print_rank_0(' > finished creating indexed dataset in {:4f} '
132
+ 'seconds'.format(time.time() - start_time))
133
+ print_rank_0(' number of documents: {}'.format(
134
+ indexed_dataset.sizes.shape[0]))
135
+
136
+ return indexed_dataset
137
+
138
+
139
+ class GPTDataset(torch.utils.data.Dataset):
140
+
141
+ def __init__(self, name, data_prefix, documents, indexed_dataset,
142
+ num_samples, seq_length, seed):
143
+
144
+ self.name = name
145
+ self.indexed_dataset = indexed_dataset
146
+
147
+ # Checks
148
+ assert np.min(documents) >= 0
149
+ assert np.max(documents) < indexed_dataset.sizes.shape[0]
150
+
151
+ # Build index mappings.
152
+ self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
153
+ self.name, data_prefix, documents, self.indexed_dataset.sizes,
154
+ num_samples, seq_length, seed)
155
+
156
+ def __len__(self):
157
+ # -1 is due to data structure used to retrieve the index:
158
+ # sample i --> [sample_idx[i], sample_idx[i+1])
159
+ return self.sample_idx.shape[0] - 1
160
+
161
+ def __getitem__(self, idx):
162
+ # Get the shuffled index.
163
+ idx = self.shuffle_idx[idx]
164
+ # Start and end documents and offsets.
165
+ doc_index_f = self.sample_idx[idx][0]
166
+ doc_index_l = self.sample_idx[idx + 1][0]
167
+ offset_f = self.sample_idx[idx][1]
168
+ offset_l = self.sample_idx[idx + 1][1]
169
+ # If we are within the same document, just extract the chunk.
170
+ if doc_index_f == doc_index_l:
171
+ sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
172
+ offset=offset_f,
173
+ length=offset_l - offset_f + 1)
174
+ else:
175
+ # Otherwise, get the rest of the initial document.
176
+ sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
177
+ offset=offset_f)]
178
+ # Loop over all in between documents and add the entire document.
179
+ for i in range(doc_index_f + 1, doc_index_l):
180
+ sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
181
+ # And finally add the relevant portion of last document.
182
+ sample_list.append(self.indexed_dataset.get(
183
+ self.doc_idx[doc_index_l],
184
+ length=offset_l + 1))
185
+ sample = np.concatenate(sample_list)
186
+
187
+ return {'text': np.array(sample, dtype=np.int64)}
188
+
189
+
190
+ def _build_index_mappings(name, data_prefix, documents, sizes,
191
+ num_samples, seq_length, seed):
192
+ """Build doc-idx, sample-idx, and shuffle-idx.
193
+ doc-idx: is an array (ordered) of documents to be used in training.
194
+ sample-idx: is the start document index and document offset for each
195
+ training sample.
196
+ shuffle-idx: maps the sample index into a random index into sample-idx.
197
+ """
198
+ # Number of tokens in each epoch and number of required epochs.
199
+ tokens_per_epoch = _num_tokens(documents, sizes)
200
+ num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
201
+ # rng state
202
+ np_rng = np.random.RandomState(seed=seed)
203
+
204
+ # Filename of the index mappings.
205
+ _filename = data_prefix
206
+ _filename += '_{}_indexmap'.format(name)
207
+ _filename += '_{}ns'.format(num_samples)
208
+ _filename += '_{}sl'.format(seq_length)
209
+ _filename += '_{}s'.format(seed)
210
+ doc_idx_filename = _filename + '_doc_idx.npy'
211
+ sample_idx_filename = _filename + '_sample_idx.npy'
212
+ shuffle_idx_filename = _filename + '_shuffle_idx.npy'
213
+
214
+ # Build the indexed mapping if not exist.
215
+ if torch.distributed.get_rank() == 0:
216
+ if (not os.path.isfile(doc_idx_filename)) or \
217
+ (not os.path.isfile(sample_idx_filename)) or \
218
+ (not os.path.isfile(shuffle_idx_filename)):
219
+
220
+ print_rank_0(' > WARNING: could not find index map files, building '
221
+ 'the indices on rank 0 ...')
222
+
223
+ # For the last epoch, decide whether to include the entire epoch
224
+ # in the global shuffle or not.
225
+
226
+ # If we need only one epoch, then separating last epoch does
227
+ # not mean anything.
228
+ if num_epochs == 1:
229
+ separate_last_epoch = False
230
+ print(' > only one epoch required, setting '
231
+ 'separate_last_epoch to False', flush=True)
232
+
233
+ else:
234
+ # Get the number of samples for the last epoch
235
+ num_samples_from_epochs_minus_one = (
236
+ (num_epochs - 1) * tokens_per_epoch - 1) // seq_length
237
+ last_epoch_num_samples = num_samples - \
238
+ num_samples_from_epochs_minus_one
239
+ assert last_epoch_num_samples >= 0, \
240
+ 'last epoch number of samples should be non-negative.'
241
+ num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
242
+ assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
243
+ 'last epoch number of samples exceeded max value.'
244
+ # If we have less than 80% of the samples for the last epoch,
245
+ # separate out the epoch and treat it differently.
246
+ # Note: the 80% number is just based on common sense and can
247
+ # be adjusted if needed.
248
+ separate_last_epoch = (last_epoch_num_samples <
249
+ int(0.80 * num_samples_per_epoch))
250
+ if separate_last_epoch:
251
+ string = ' > last epoch number of samples ({}) is smaller '\
252
+ 'than 80% of number of samples per epoch ({}), '\
253
+ 'setting separate_last_epoch to True'
254
+ else:
255
+ string = ' > last epoch number of samples ({}) is larger '\
256
+ 'than 80% of number of samples per epoch ({}), '\
257
+ 'setting separate_last_epoch to False'
258
+ print(string.format(last_epoch_num_samples,
259
+ num_samples_per_epoch), flush=True)
260
+
261
+ # doc-idx.
262
+ start_time = time.time()
263
+ doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
264
+ separate_last_epoch)
265
+ np.save(doc_idx_filename, doc_idx, allow_pickle=True)
266
+ print_rank_0(' > elapsed time to build and save doc-idx mapping '
267
+ '(seconds): {:4f}'.format(time.time() - start_time))
268
+ # sample-idx.
269
+ start_time = time.time()
270
+ # Use C++ implementation for speed.
271
+ # First compile and then import.
272
+ from megatron.data import helpers
273
+ assert doc_idx.dtype == np.int32
274
+ assert sizes.dtype == np.int32
275
+ sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
276
+ num_epochs, tokens_per_epoch)
277
+ # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
278
+ # num_epochs, tokens_per_epoch)
279
+ np.save(sample_idx_filename, sample_idx, allow_pickle=True)
280
+ print_rank_0(' > elapsed time to build and save sample-idx mapping '
281
+ '(seconds): {:4f}'.format(time.time() - start_time))
282
+ # shuffle-idx.
283
+ start_time = time.time()
284
+ # -1 is due to data structure used to retrieve the index:
285
+ # sample i --> [sample_idx[i], sample_idx[i+1])
286
+ if separate_last_epoch:
287
+ num_samples_ = num_samples_from_epochs_minus_one
288
+ else:
289
+ num_samples_ = sample_idx.shape[0] - 1
290
+ shuffle_idx = _build_shuffle_idx(num_samples_,
291
+ sample_idx.shape[0] - 1, np_rng)
292
+ np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
293
+ print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
294
+ ' (seconds): {:4f}'.format(time.time() - start_time))
295
+
296
+ # This should be a barrier but nccl barrier assumes
297
+ # device_index=rank which is not the case for model
298
+ # parallel case
299
+ counts = torch.cuda.LongTensor([1])
300
+ torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
301
+ torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
302
+ assert counts[0].item() == (
303
+ torch.distributed.get_world_size() //
304
+ torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
305
+
306
+ # Load mappings.
307
+ start_time = time.time()
308
+ print_rank_0(' > loading doc-idx mapping from {}'.format(
309
+ doc_idx_filename))
310
+ doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
311
+ print_rank_0(' > loading sample-idx mapping from {}'.format(
312
+ sample_idx_filename))
313
+ sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
314
+ print_rank_0(' > loading shuffle-idx mapping from {}'.format(
315
+ shuffle_idx_filename))
316
+ shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
317
+ print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
318
+ time.time() - start_time))
319
+ print_rank_0(' total number of samples: {}'.format(
320
+ sample_idx.shape[0]))
321
+ print_rank_0(' total number of epochs: {}'.format(num_epochs))
322
+
323
+ return doc_idx, sample_idx, shuffle_idx
324
+
325
+
326
+ def _num_tokens(documents, sizes):
327
+ """Total number of tokens in the dataset."""
328
+ return np.sum(sizes[documents])
329
+
330
+
331
+ def _num_epochs(tokens_per_epoch, seq_length, num_samples):
332
+ """Based on number of samples and sequence lenght, calculate how many
333
+ epochs will be needed."""
334
+ num_epochs = 0
335
+ total_tokens = 0
336
+ while True:
337
+ num_epochs += 1
338
+ total_tokens += tokens_per_epoch
339
+ # -1 is because we need to retrieve seq_length + 1 tokens each time
340
+ # but the last token will overlap with the first token of the next
341
+ # sample except for the last sample.
342
+ if ((total_tokens - 1) // seq_length) >= num_samples:
343
+ return num_epochs
344
+
345
+
346
+ def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
347
+ """Build an array with length = number-of-epochs * number-of-dcuments.
348
+ Each index is mapped to a corresponding document."""
349
+ if not separate_last_epoch or num_epochs == 1:
350
+ doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
351
+ doc_idx[:] = documents
352
+ doc_idx = doc_idx.reshape(-1)
353
+ doc_idx = doc_idx.astype(np.int32)
354
+ np_rng.shuffle(doc_idx)
355
+ return doc_idx
356
+
357
+ doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False)
358
+ doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
359
+ return np.concatenate((doc_idx_first, doc_idx_last))
360
+
361
+
362
+ def _build_sample_idx(sizes, doc_idx, seq_length,
363
+ num_epochs, tokens_per_epoch):
364
+ """Sample index mapping is a 2D array with sizes
365
+ [number-of-samples + 1, 2] where [..., 0] contains
366
+ the index into `doc_idx` and [..., 1] is the
367
+ starting offset in that document."""
368
+
369
+ # Total number of samples. For -1 see comments in `_num_epochs`.
370
+ num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
371
+ sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
372
+
373
+ # Index into sample_idx.
374
+ sample_index = 0
375
+ # Index into doc_idx.
376
+ doc_idx_index = 0
377
+ # Beginning offset for each document.
378
+ doc_offset = 0
379
+ # Start with first document and no offset.
380
+ sample_idx[sample_index][0] = doc_idx_index
381
+ sample_idx[sample_index][1] = doc_offset
382
+ sample_index += 1
383
+ while sample_index <= num_samples:
384
+ # Start with a fresh sequence.
385
+ remaining_seq_length = seq_length + 1
386
+ while remaining_seq_length != 0:
387
+ # Get the document length.
388
+ doc_id = doc_idx[doc_idx_index]
389
+ doc_length = sizes[doc_id] - doc_offset
390
+ # And add it to the current sequence.
391
+ remaining_seq_length -= doc_length
392
+ # If we have more than a full sequence, adjust offset and set
393
+ # remaining length to zero so we return from the while loop.
394
+ # Note that -1 here is for the same reason we have -1 in
395
+ # `_num_epochs` calculations.
396
+ if remaining_seq_length <= 0:
397
+ doc_offset += (remaining_seq_length + doc_length - 1)
398
+ remaining_seq_length = 0
399
+ else:
400
+ # Otherwise, start from the beginning of the next document.
401
+ doc_idx_index += 1
402
+ doc_offset = 0
403
+ # Record the sequence.
404
+ sample_idx[sample_index][0] = doc_idx_index
405
+ sample_idx[sample_index][1] = doc_offset
406
+ sample_index += 1
407
+
408
+ return sample_idx
409
+
410
+
411
+ def _build_shuffle_idx(num_samples, total_size, np_rng):
412
+ """Build the range [0, size) and shuffle."""
413
+ print(' > building shuffle index with split [0, {}) and [{}, {}) '
414
+ '...'.format(num_samples, num_samples, total_size), flush=True)
415
+
416
+ dtype_ = np.uint32
417
+ if total_size >= (np.iinfo(np.uint32).max - 1):
418
+ dtype_ = np.int64
419
+
420
+ shuffle_idx_first = np.arange(start=0, stop=num_samples,
421
+ step=1, dtype=dtype_)
422
+ np_rng.shuffle(shuffle_idx_first)
423
+ if num_samples == total_size:
424
+ return shuffle_idx_first
425
+
426
+ shuffle_idx_last = np.arange(start=num_samples, stop=total_size,
427
+ step=1, dtype=dtype_)
428
+ np_rng.shuffle(shuffle_idx_last)
429
+
430
+ return np.concatenate((shuffle_idx_first, shuffle_idx_last))
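For intuition, a small worked example (toy sizes, made up for illustration) of the pure-Python _build_sample_idx reference above; the compiled helpers.build_sample_idx used at runtime follows the same logic:

    import numpy as np
    from megatron.data.gpt_dataset import _build_sample_idx  # path assumed from this upload

    # Two documents of 5 and 4 tokens, one epoch, seq_length 3: each sample
    # consumes seq_length + 1 = 4 tokens (the last token overlaps with the
    # first token of the next sample).
    sizes = np.array([5, 4], dtype=np.int32)
    doc_idx = np.array([0, 1], dtype=np.int32)
    sample_idx = _build_sample_idx(sizes, doc_idx, seq_length=3,
                                   num_epochs=1, tokens_per_epoch=9)
    # sample_idx -> [[0, 0],   sample 0 starts at doc_idx[0], offset 0
    #                [0, 3],   sample 1 starts at doc_idx[0], offset 3 and
    #                [1, 1]]   ends at doc_idx[1], offset 1 (the end-marker row)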
megatron/data/helpers.cpp ADDED
@@ -0,0 +1,717 @@
1
+ /*
2
+ coding=utf-8
3
+ Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ */
17
+
18
+
19
+ /* Helper methods for fast index mapping builds */
20
+
21
+ #include <algorithm>
22
+ #include <iostream>
23
+ #include <limits>
24
+ #include <math.h>
25
+ #include <stdexcept>
26
+ #include <pybind11/pybind11.h>
27
+ #include <pybind11/numpy.h>
28
+ #include <random>
29
+
30
+ namespace py = pybind11;
31
+ using namespace std;
32
+
33
+ const int32_t LONG_SENTENCE_LEN = 512;
34
+
35
+
36
+ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
37
+ py::array_t<int64_t>& dataset_sample_index,
38
+ const py::array_t<double>& weights,
39
+ const int32_t num_datasets,
40
+ const int64_t size, const bool verbose) {
41
+ /* Given multiple datasets and a weighting array, build samples
42
+ such that they follow those weights.*/
43
+
44
+ if (verbose) {
45
+ std::cout << "> building indices for blendable datasets ..." << std::endl;
46
+ }
47
+
48
+ // Get the pointer access without the checks.
49
+ auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
50
+ auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
51
+ auto weights_ptr = weights.unchecked<1>();
52
+
53
+ // Initialize buffer for number of samples used for each dataset.
54
+ int64_t current_samples[num_datasets];
55
+ for(int64_t i = 0; i < num_datasets; ++i) {
56
+ current_samples[i] = 0;
57
+ }
58
+
59
+ // For each sample:
60
+ for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
61
+
62
+ // Determine where the max error in sampling is happening.
63
+ auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
64
+ int64_t max_error_index = 0;
65
+ double max_error = weights_ptr[0] * sample_idx_double -
66
+ static_cast<double>(current_samples[0]);
67
+ for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
68
+ double error = weights_ptr[dataset_idx] * sample_idx_double -
69
+ static_cast<double>(current_samples[dataset_idx]);
70
+ if (error > max_error) {
71
+ max_error = error;
72
+ max_error_index = dataset_idx;
73
+ }
74
+ }
75
+
76
+ // Populate the indices.
77
+ dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
78
+ dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
79
+
80
+ // Update the total samples.
81
+ current_samples[max_error_index] += 1;
82
+
83
+ }
84
+
85
+ // print info
86
+ if (verbose) {
87
+ std::cout << " > sample ratios:" << std::endl;
88
+ for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
89
+ auto ratio = static_cast<double>(current_samples[dataset_idx]) /
90
+ static_cast<double>(size);
91
+ std::cout << " dataset " << dataset_idx << ", input: " <<
92
+ weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
93
+ }
94
+ }
95
+
96
+ }
97
+
98
+
99
+ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
100
+ const py::array_t<int32_t>& doc_idx_,
101
+ const int32_t seq_length,
102
+ const int32_t num_epochs,
103
+ const int64_t tokens_per_epoch) {
104
+ /* Sample index (sample_idx) is used for gpt2 like dataset for which
105
+ the documents are flattened and the samples are built based on this
106
+ 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
107
+ where [..., 0] contains the index into `doc_idx` and [..., 1] is the
108
+ starting offset in that document.*/
109
+
110
+ // Consistency checks.
111
+ assert(seq_length > 1);
112
+ assert(num_epochs > 0);
113
+ assert(tokens_per_epoch > 1);
114
+
115
+ // Remove bound checks.
116
+ auto sizes = sizes_.unchecked<1>();
117
+ auto doc_idx = doc_idx_.unchecked<1>();
118
+
119
+ // Mapping and its length (1D).
120
+ int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
121
+ int32_t* sample_idx = new int32_t[2*(num_samples+1)];
122
+
123
+ cout << " using:" << endl << std::flush;
124
+ cout << " number of documents: " <<
125
+ doc_idx_.shape(0) / num_epochs << endl << std::flush;
126
+ cout << " number of epochs: " << num_epochs <<
127
+ endl << std::flush;
128
+ cout << " sequence length: " << seq_length <<
129
+ endl << std::flush;
130
+ cout << " total number of samples: " << num_samples <<
131
+ endl << std::flush;
132
+
133
+ // Index into sample_idx.
134
+ int64_t sample_index = 0;
135
+ // Index into doc_idx.
136
+ int64_t doc_idx_index = 0;
137
+ // Begining offset for each document.
138
+ int32_t doc_offset = 0;
139
+ // Start with first document and no offset.
140
+ sample_idx[2 * sample_index] = doc_idx_index;
141
+ sample_idx[2 * sample_index + 1] = doc_offset;
142
+ ++sample_index;
143
+
144
+ while (sample_index <= num_samples) {
145
+ // Start with a fresh sequence.
146
+ int32_t remaining_seq_length = seq_length + 1;
147
+ while (remaining_seq_length != 0) {
148
+ // Get the document length.
149
+ auto doc_id = doc_idx[doc_idx_index];
150
+ auto doc_length = sizes[doc_id] - doc_offset;
151
+ // And add it to the current sequence.
152
+ remaining_seq_length -= doc_length;
153
+ // If we have more than a full sequence, adjust offset and set
154
+ // remaining length to zero so we return from the while loop.
155
+ // Note that -1 here is for the same reason we have -1 in
156
+ // `_num_epochs` calculations.
157
+ if (remaining_seq_length <= 0) {
158
+ doc_offset += (remaining_seq_length + doc_length - 1);
159
+ remaining_seq_length = 0;
160
+ } else {
161
+ // Otherwise, start from the beginning of the next document.
162
+ ++doc_idx_index;
163
+ doc_offset = 0;
164
+ }
165
+ }
166
+ // Record the sequence.
167
+ sample_idx[2 * sample_index] = doc_idx_index;
168
+ sample_idx[2 * sample_index + 1] = doc_offset;
169
+ ++sample_index;
170
+ }
171
+
172
+ // Method to deallocate memory.
173
+ py::capsule free_when_done(sample_idx, [](void *mem_) {
174
+ int32_t *mem = reinterpret_cast<int32_t*>(mem_);
175
+ delete[] mem;
176
+ });
177
+
178
+ // Return the numpy array.
179
+ const auto byte_size = sizeof(int32_t);
180
+ return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
181
+ {2*byte_size, byte_size}, // C-style contiguous strides
182
+ sample_idx, // the data pointer
183
+ free_when_done); // numpy array references
184
+
185
+ }
186
+
187
+
188
+ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
189
+ const int32_t max_length,
190
+ std::mt19937& rand32_gen) {
191
+ /* Training sample length. */
192
+ if (short_seq_ratio == 0) {
193
+ return max_length;
194
+ }
195
+ const auto random_number = rand32_gen();
196
+ if ((random_number % short_seq_ratio) == 0) {
197
+ return 2 + random_number % (max_length - 1);
198
+ }
199
+ return max_length;
200
+ }
201
+
202
+
203
+ template<typename DocIdx>
204
+ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
205
+ const py::array_t<int32_t>& sizes_,
206
+ const int32_t num_epochs,
207
+ const uint64_t max_num_samples,
208
+ const int32_t max_seq_length,
209
+ const double short_seq_prob,
210
+ const int32_t seed,
211
+ const bool verbose,
212
+ const int32_t min_num_sent) {
213
+ /* Build a mapping of (start-index, end-index, sequence-length) where
214
+ start and end index are the indices of the sentences in the sample
215
+ and sequence-length is the target sequence length.
216
+ */
217
+
218
+ // Consistency checks.
219
+ assert(num_epochs > 0);
220
+ assert(max_seq_length > 1);
221
+ assert(short_seq_prob >= 0.0);
222
+ assert(short_seq_prob <= 1.0);
223
+ assert(seed > 0);
224
+
225
+ // Remove bound checks.
226
+ auto docs = docs_.unchecked<1>();
227
+ auto sizes = sizes_.unchecked<1>();
228
+
229
+ // For efficiency, convert probability to ratio. Note: rand() generates int.
230
+ int32_t short_seq_ratio = 0;
231
+ if (short_seq_prob > 0) {
232
+ short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
233
+ }
234
+
235
+ if (verbose) {
236
+ const auto sent_start_index = docs[0];
237
+ const auto sent_end_index = docs[docs_.shape(0) - 1];
238
+ const auto num_sentences = sent_end_index - sent_start_index;
239
+ cout << " using:" << endl << std::flush;
240
+ cout << " number of documents: " << docs_.shape(0) - 1 <<
241
+ endl << std::flush;
242
+ cout << " sentences range: [" << sent_start_index <<
243
+ ", " << sent_end_index << ")" << endl << std::flush;
244
+ cout << " total number of sentences: " << num_sentences <<
245
+ endl << std::flush;
246
+ cout << " number of epochs: " << num_epochs <<
247
+ endl << std::flush;
248
+ cout << " maximum number of samples: " << max_num_samples <<
249
+ endl << std::flush;
250
+ cout << " maximum sequence length: " << max_seq_length <<
251
+ endl << std::flush;
252
+ cout << " short sequence probability: " << short_seq_prob <<
253
+ endl << std::flush;
254
+ cout << " short sequence ration (1/prob): " << short_seq_ratio <<
255
+ endl << std::flush;
256
+ cout << " seed: " << seed << endl <<
257
+ std::flush;
258
+ }
259
+
260
+ // Mapping and its length (1D).
261
+ int64_t num_samples = -1;
262
+ DocIdx* maps = NULL;
263
+
264
+ // Perform two iterations, in the first iteration get the size
265
+ // and allocate memory and in the second iteration populate the map.
266
+ bool second = false;
267
+ for (int32_t iteration=0; iteration<2; ++iteration) {
268
+
269
+ // Set the seed so both iterations produce the same results.
270
+ std::mt19937 rand32_gen(seed);
271
+
272
+ // Set the flag on second iteration.
273
+ second = (iteration == 1);
274
+
275
+ // Counters:
276
+ uint64_t empty_docs = 0;
277
+ uint64_t one_sent_docs = 0;
278
+ uint64_t long_sent_docs = 0;
279
+
280
+ // Current map index.
281
+ uint64_t map_index = 0;
282
+
283
+ // For each epoch:
284
+ for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
285
+ if (map_index >= max_num_samples) {
286
+ if (verbose && (!second)) {
287
+ cout << " reached " << max_num_samples << " samples after "
288
+ << epoch << " epochs ..." << endl << std::flush;
289
+ }
290
+ break;
291
+ }
292
+ // For each document:
293
+ for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
294
+
295
+ // Document sentences are in [sent_index_first, sent_index_last)
296
+ const auto sent_index_first = docs[doc];
297
+ const auto sent_index_last = docs[doc + 1];
298
+
299
+ // At the beginning of the document the previous index is the
300
+ // start index.
301
+ auto prev_start_index = sent_index_first;
302
+
303
+ // Remaining documents.
304
+ auto num_remain_sent = sent_index_last - sent_index_first;
305
+
306
+ // Some bookkeeping
307
+ if ((epoch == 0) && (!second)) {
308
+ if (num_remain_sent == 0) {
309
+ ++empty_docs;
310
+ }
311
+ if (num_remain_sent == 1) {
312
+ ++one_sent_docs;
313
+ }
314
+ }
315
+
316
+ // Detect documents with long sentences.
317
+ bool contains_long_sentence = false;
318
+ if (num_remain_sent > 1) {
319
+ for (auto sent_index=sent_index_first;
320
+ sent_index < sent_index_last; ++sent_index) {
321
+ if (sizes[sent_index] > LONG_SENTENCE_LEN){
322
+ if ((epoch == 0) && (!second)) {
323
+ ++long_sent_docs;
324
+ }
325
+ contains_long_sentence = true;
326
+ break;
327
+ }
328
+ }
329
+ }
330
+
331
+ // If we have more than two sentences.
332
+ if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
333
+
334
+ // Set values.
335
+ auto seq_len = int32_t{0};
336
+ auto num_sent = int32_t{0};
337
+ auto target_seq_len = get_target_sample_len(short_seq_ratio,
338
+ max_seq_length,
339
+ rand32_gen);
340
+
341
+ // Loop through sentences.
342
+ for (auto sent_index=sent_index_first;
343
+ sent_index < sent_index_last; ++sent_index) {
344
+
345
+ // Add the size and number of sentences.
346
+ seq_len += sizes[sent_index];
347
+ ++num_sent;
348
+ --num_remain_sent;
349
+
350
+ // If we have reached the target length.
351
+ // and if not only one sentence is left in the document.
352
+ // and if we have at least two sentences.
353
+ // or if we have reached the end of the document.
354
+ if (((seq_len >= target_seq_len) &&
355
+ (num_remain_sent > 1) &&
356
+ (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
357
+
358
+ // Check for overflow.
359
+ if ((3 * map_index + 2) >
360
+ std::numeric_limits<int64_t>::max()) {
361
+ cout << "number of samples exceeded maximum "
362
+ << "allowed by type int64: "
363
+ << std::numeric_limits<int64_t>::max()
364
+ << endl;
365
+ throw std::overflow_error("Number of samples");
366
+ }
367
+
368
+ // Populate the map.
369
+ if (second) {
370
+ const auto map_index_0 = 3 * map_index;
371
+ maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
372
+ maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
373
+ maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
374
+ }
375
+
376
+ // Update indices / counters.
377
+ ++map_index;
378
+ prev_start_index = sent_index + 1;
379
+ target_seq_len = get_target_sample_len(short_seq_ratio,
380
+ max_seq_length,
381
+ rand32_gen);
382
+ seq_len = 0;
383
+ num_sent = 0;
384
+ }
385
+
386
+ } // for (auto sent_index=sent_index_first; ...
387
+ } // if (num_remain_sent > 1) {
388
+ } // for (int doc=0; doc < num_docs; ++doc) {
389
+ } // for (int epoch=0; epoch < num_epochs; ++epoch) {
390
+
391
+ if (!second) {
392
+ if (verbose) {
393
+ cout << " number of empty documents: " << empty_docs <<
394
+ endl << std::flush;
395
+ cout << " number of documents with one sentence: " <<
396
+ one_sent_docs << endl << std::flush;
397
+ cout << " number of documents with long sentences: " <<
398
+ long_sent_docs << endl << std::flush;
399
+ cout << " will create mapping for " << map_index <<
400
+ " samples" << endl << std::flush;
401
+ }
402
+ assert(maps == NULL);
403
+ assert(num_samples < 0);
404
+ maps = new DocIdx[3*map_index];
405
+ num_samples = static_cast<int64_t>(map_index);
406
+ }
407
+
408
+ } // for (int iteration=0; iteration < 2; ++iteration) {
409
+
410
+ // Shuffle.
411
+ // We need a 64 bit random number generator as we might have more
412
+ // than 2 billion samples.
413
+ std::mt19937_64 rand64_gen(seed + 1);
414
+ for (auto i=(num_samples - 1); i > 0; --i) {
415
+ const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
416
+ const auto i0 = 3 * i;
417
+ const auto j0 = 3 * j;
418
+ // Swap values.
419
+ swap(maps[i0], maps[j0]);
420
+ swap(maps[i0 + 1], maps[j0 + 1]);
421
+ swap(maps[i0 + 2], maps[j0 + 2]);
422
+ }
423
+
424
+ // Method to deallocate memory.
425
+ py::capsule free_when_done(maps, [](void *mem_) {
426
+ DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
427
+ delete[] mem;
428
+ });
429
+
430
+ // Return the numpy array.
431
+ const auto byte_size = sizeof(DocIdx);
432
+ return py::array(std::vector<int64_t>{num_samples, 3}, // shape
433
+ {3*byte_size, byte_size}, // C-style contiguous strides
434
+ maps, // the data pointer
435
+ free_when_done); // numpy array references
436
+
437
+ }
438
+
439
+
440
+ py::array build_mapping(const py::array_t<int64_t>& docs_,
441
+ const py::array_t<int>& sizes_,
442
+ const int num_epochs,
443
+ const uint64_t max_num_samples,
444
+ const int max_seq_length,
445
+ const double short_seq_prob,
446
+ const int seed,
447
+ const bool verbose,
448
+ const int32_t min_num_sent) {
449
+
450
+ if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
451
+ if (verbose) {
452
+ cout << " using uint64 for data mapping..." << endl << std::flush;
453
+ }
454
+ return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
455
+ max_num_samples, max_seq_length,
456
+ short_seq_prob, seed, verbose,
457
+ min_num_sent);
458
+ } else {
459
+ if (verbose) {
460
+ cout << " using uint32 for data mapping..." << endl << std::flush;
461
+ }
462
+ return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
463
+ max_num_samples, max_seq_length,
464
+ short_seq_prob, seed, verbose,
465
+ min_num_sent);
466
+ }
467
+ }
468
+
469
+ template<typename DocIdx>
470
+ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
471
+ const py::array_t<int32_t>& sizes_,
472
+ const py::array_t<int32_t>& titles_sizes_,
473
+ const int32_t num_epochs,
474
+ const uint64_t max_num_samples,
475
+ const int32_t max_seq_length,
476
+ const int32_t seed,
477
+ const bool verbose,
478
+ const bool use_one_sent_blocks) {
479
+ /* Build a mapping of (start-index, end-index, sequence-length) where
480
+ start and end index are the indices of the sentences in the sample
481
+ and sequence-length is the target sequence length.
482
+ */
483
+
484
+ // Consistency checks.
485
+ assert(num_epochs > 0);
486
+ assert(max_seq_length > 1);
487
+ assert(seed > 0);
488
+
489
+ // Remove bound checks.
490
+ auto docs = docs_.unchecked<1>();
491
+ auto sizes = sizes_.unchecked<1>();
492
+ auto titles_sizes = titles_sizes_.unchecked<1>();
493
+
494
+ if (verbose) {
495
+ const auto sent_start_index = docs[0];
496
+ const auto sent_end_index = docs[docs_.shape(0) - 1];
497
+ const auto num_sentences = sent_end_index - sent_start_index;
498
+ cout << " using:" << endl << std::flush;
499
+ cout << " number of documents: " << docs_.shape(0) - 1 <<
500
+ endl << std::flush;
501
+ cout << " sentences range: [" << sent_start_index <<
502
+ ", " << sent_end_index << ")" << endl << std::flush;
503
+ cout << " total number of sentences: " << num_sentences <<
504
+ endl << std::flush;
505
+ cout << " number of epochs: " << num_epochs <<
506
+ endl << std::flush;
507
+ cout << " maximum number of samples: " << max_num_samples <<
508
+ endl << std::flush;
509
+ cout << " maximum sequence length: " << max_seq_length <<
510
+ endl << std::flush;
511
+ cout << " seed: " << seed << endl <<
512
+ std::flush;
513
+ }
514
+
515
+ // Mapping and its length (1D).
516
+ int64_t num_samples = -1;
517
+ DocIdx* maps = NULL;
518
+
519
+ // Acceptable number of sentences per block.
520
+ int min_num_sent = 2;
521
+ if (use_one_sent_blocks) {
522
+ min_num_sent = 1;
523
+ }
524
+
525
+ // Perform two iterations, in the first iteration get the size
526
+ // and allocate memory and in the second iteration populate the map.
527
+ bool second = false;
528
+ for (int32_t iteration=0; iteration<2; ++iteration) {
529
+
530
+ // Set the flag on second iteration.
531
+ second = (iteration == 1);
532
+
533
+ // Current map index.
534
+ uint64_t map_index = 0;
535
+
536
+ uint64_t empty_docs = 0;
537
+ uint64_t one_sent_docs = 0;
538
+ uint64_t long_sent_docs = 0;
539
+ // For each epoch:
540
+ for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
541
+ // assign every block a unique id
542
+ int32_t block_id = 0;
543
+
544
+ if (map_index >= max_num_samples) {
545
+ if (verbose && (!second)) {
546
+ cout << " reached " << max_num_samples << " samples after "
547
+ << epoch << " epochs ..." << endl << std::flush;
548
+ }
549
+ break;
550
+ }
551
+ // For each document:
552
+ for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
553
+
554
+ // Document sentences are in [sent_index_first, sent_index_last)
555
+ const auto sent_index_first = docs[doc];
556
+ const auto sent_index_last = docs[doc + 1];
557
+ const auto target_seq_len = max_seq_length - titles_sizes[doc];
558
+
559
+ // At the beginning of the document the previous index is the
560
+ // start index.
561
+ auto prev_start_index = sent_index_first;
562
+
563
+ // Remaining documents.
564
+ auto num_remain_sent = sent_index_last - sent_index_first;
565
+
566
+ // Some bookkeeping
567
+ if ((epoch == 0) && (!second)) {
568
+ if (num_remain_sent == 0) {
569
+ ++empty_docs;
570
+ }
571
+ if (num_remain_sent == 1) {
572
+ ++one_sent_docs;
573
+ }
574
+ }
575
+ // Detect documents with long sentences.
576
+ bool contains_long_sentence = false;
577
+ if (num_remain_sent >= min_num_sent) {
578
+ for (auto sent_index=sent_index_first;
579
+ sent_index < sent_index_last; ++sent_index) {
580
+ if (sizes[sent_index] > LONG_SENTENCE_LEN){
581
+ if ((epoch == 0) && (!second)) {
582
+ ++long_sent_docs;
583
+ }
584
+ contains_long_sentence = true;
585
+ break;
586
+ }
587
+ }
588
+ }
589
+ // If we have enough sentences and no long sentences.
590
+ if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
591
+
592
+ // Set values.
593
+ auto seq_len = int32_t{0};
594
+ auto num_sent = int32_t{0};
595
+
596
+ // Loop through sentences.
597
+ for (auto sent_index=sent_index_first;
598
+ sent_index < sent_index_last; ++sent_index) {
599
+
600
+ // Add the size and number of sentences.
601
+ seq_len += sizes[sent_index];
602
+ ++num_sent;
603
+ --num_remain_sent;
604
+
605
+ // If we have reached the target length.
606
+ // and there are an acceptable number of sentences left
607
+ // and if we have at least the minimum number of sentences.
608
+ // or if we have reached end of the document.
609
+ if (((seq_len >= target_seq_len) &&
610
+ (num_remain_sent >= min_num_sent) &&
611
+ (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
612
+
613
+ // Populate the map.
614
+ if (second) {
615
+ const auto map_index_0 = 4 * map_index;
616
+ // Each sample has 4 items: the starting sentence index, ending sentence index,
617
+ // the index of the document from which the block comes (used for fetching titles)
618
+ // and the unique id of the block (used for creating block indexes)
619
+
620
+ maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
621
+ maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
622
+ maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
623
+ maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
624
+ }
625
+
626
+ // Update indices / counters.
627
+ ++map_index;
628
+ ++block_id;
629
+ prev_start_index = sent_index + 1;
630
+ seq_len = 0;
631
+ num_sent = 0;
632
+ }
633
+ } // for (auto sent_index=sent_index_first; ...
634
+ } // if (num_remain_sent > 1) {
635
+ } // for (int doc=0; doc < num_docs; ++doc) {
636
+ } // for (int epoch=0; epoch < num_epochs; ++epoch) {
637
+
638
+ if (!second) {
639
+ if (verbose) {
640
+ cout << " number of empty documents: " << empty_docs <<
641
+ endl << std::flush;
642
+ cout << " number of documents with one sentence: " <<
643
+ one_sent_docs << endl << std::flush;
644
+ cout << " number of documents with long sentences: " <<
645
+ long_sent_docs << endl << std::flush;
646
+ cout << " will create mapping for " << map_index <<
647
+ " samples" << endl << std::flush;
648
+ }
649
+ assert(maps == NULL);
650
+ assert(num_samples < 0);
651
+ maps = new DocIdx[4*map_index];
652
+ num_samples = static_cast<int64_t>(map_index);
653
+ }
654
+
655
+ } // for (int iteration=0; iteration < 2; ++iteration) {
656
+
657
+ // Shuffle.
658
+ // We need a 64 bit random number generator as we might have more
659
+ // than 2 billion samples.
660
+ std::mt19937_64 rand64_gen(seed + 1);
661
+ for (auto i=(num_samples - 1); i > 0; --i) {
662
+ const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
663
+ const auto i0 = 4 * i;
664
+ const auto j0 = 4 * j;
665
+ // Swap values.
666
+ swap(maps[i0], maps[j0]);
667
+ swap(maps[i0 + 1], maps[j0 + 1]);
668
+ swap(maps[i0 + 2], maps[j0 + 2]);
669
+ swap(maps[i0 + 3], maps[j0 + 3]);
670
+ }
671
+
672
+ // Method to deallocate memory.
673
+ py::capsule free_when_done(maps, [](void *mem_) {
674
+ DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
675
+ delete[] mem;
676
+ });
677
+
678
+ // Return the numpy array.
679
+ const auto byte_size = sizeof(DocIdx);
680
+ return py::array(std::vector<int64_t>{num_samples, 4}, // shape
681
+ {4*byte_size, byte_size}, // C-style contiguous strides
682
+ maps, // the data pointer
683
+ free_when_done); // numpy array references
684
+
685
+ }
686
+
687
+ py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
688
+ const py::array_t<int>& sizes_,
689
+ const py::array_t<int>& titles_sizes_,
690
+ const int num_epochs,
691
+ const uint64_t max_num_samples,
692
+ const int max_seq_length,
693
+ const int seed,
694
+ const bool verbose,
695
+ const bool use_one_sent_blocks) {
696
+
697
+ if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
698
+ if (verbose) {
699
+ cout << " using uint64 for data mapping..." << endl << std::flush;
700
+ }
701
+ return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
702
+ num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
703
+ } else {
704
+ if (verbose) {
705
+ cout << " using uint32 for data mapping..." << endl << std::flush;
706
+ }
707
+ return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
708
+ num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
709
+ }
710
+ }
711
+
712
+ PYBIND11_MODULE(helpers, m) {
713
+ m.def("build_mapping", &build_mapping);
714
+ m.def("build_blocks_mapping", &build_blocks_mapping);
715
+ m.def("build_sample_idx", &build_sample_idx);
716
+ m.def("build_blending_indices", &build_blending_indices);
717
+ }
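As a rough sketch of how this extension is driven from Python (toy values; assumes the module has been compiled and is importable as megatron.data.helpers), build_blending_indices fills pre-allocated output arrays:

    import numpy as np
    from megatron.data import helpers  # pybind11 extension built from helpers.cpp

    size = 10                                    # number of blended samples to draw
    weights = np.array([0.7, 0.3], dtype=np.float64)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)

    helpers.build_blending_indices(dataset_index, dataset_sample_index,
                                   weights, len(weights), size, True)
    # dataset_index[i]        -> which dataset sample i is drawn from
    # dataset_sample_index[i] -> the running sample count within that dataset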
megatron/data/helpers.cpython-38-x86_64-linux-gnu.so ADDED
Binary file (192 kB). View file
 
megatron/data/helpers.cpython-39-x86_64-linux-gnu.so ADDED
Binary file (212 kB). View file
 
megatron/data/ict_dataset.py ADDED
@@ -0,0 +1,156 @@
1
+ import itertools
2
+ import random
3
+
4
+ import numpy as np
5
+ from torch.utils.data import Dataset
6
+
7
+ from megatron import get_tokenizer
8
+ from megatron import get_args
9
+ from megatron.data.dataset_utils import get_indexed_dataset_
10
+ from megatron.data.realm_dataset_utils import get_block_samples_mapping
11
+
12
+ def make_attention_mask(source_block, target_block):
13
+ """
14
+ Returns a 2-dimensional (2-D) attention mask
15
+ :param source_block: 1-D array
16
+ :param target_block: 1-D array
17
+ """
18
+ mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
19
+ mask = mask.astype(np.int64)
20
+ # (source_length, target_length)
21
+ return mask
22
+
23
+ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
24
+ """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
25
+ rather than for training, since it is only built with a single epoch sample mapping.
26
+ """
27
+ args = get_args()
28
+ block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
29
+ titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
30
+
31
+ kwargs = dict(
32
+ name='full',
33
+ block_dataset=block_dataset,
34
+ title_dataset=titles_dataset,
35
+ data_prefix=args.data_path,
36
+ num_epochs=1,
37
+ max_num_samples=None,
38
+ max_seq_length=args.seq_length,
39
+ seed=1,
40
+ query_in_block_prob=query_in_block_prob,
41
+ use_titles=use_titles,
42
+ use_one_sent_docs=args.use_one_sent_docs
43
+ )
44
+ dataset = ICTDataset(**kwargs)
45
+ return dataset
46
+
47
+
48
+ class ICTDataset(Dataset):
49
+ """Dataset containing sentences and their blocks for an inverse cloze task."""
50
+ def __init__(self, name, block_dataset, title_dataset, data_prefix,
51
+ num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
52
+ seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
53
+ self.name = name
54
+ self.seed = seed
55
+ self.max_seq_length = max_seq_length
56
+ self.query_in_block_prob = query_in_block_prob
57
+ self.block_dataset = block_dataset
58
+ self.title_dataset = title_dataset
59
+ self.rng = random.Random(self.seed)
60
+ self.use_titles = use_titles
61
+ self.use_one_sent_docs = use_one_sent_docs
62
+
63
+ self.samples_mapping = get_block_samples_mapping(
64
+ block_dataset, title_dataset, data_prefix, num_epochs,
65
+ max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
66
+ self.tokenizer = get_tokenizer()
67
+ self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
68
+ self.vocab_id_to_token_list = self.tokenizer.inv_vocab
69
+ self.cls_id = self.tokenizer.cls
70
+ self.sep_id = self.tokenizer.sep
71
+ self.mask_id = self.tokenizer.mask
72
+ self.pad_id = self.tokenizer.pad
73
+
74
+ def __len__(self):
75
+ return len(self.samples_mapping)
76
+
77
+ def __getitem__(self, idx):
78
+ """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
79
+ sample_data = self.samples_mapping[idx]
80
+ start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
81
+
82
+ if self.use_titles:
83
+ title = self.title_dataset[int(doc_idx)]
84
+ title_pad_offset = 3 + len(title)
85
+ else:
86
+ title = None
87
+ title_pad_offset = 2
88
+ block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
89
+ assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
90
+
91
+ # randint() is inclusive for Python rng
92
+ rand_sent_idx = self.rng.randint(0, len(block) - 1)
93
+
94
+ # keep the query in the context query_in_block_prob fraction of the time.
95
+ if self.rng.random() < self.query_in_block_prob:
96
+ query = block[rand_sent_idx].copy()
97
+ else:
98
+ query = block.pop(rand_sent_idx)
99
+
100
+ # still need to truncate because blocks are concluded when
101
+ # the sentence lengths have exceeded max_seq_length.
102
+ query = query[:self.max_seq_length - 2]
103
+ block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
104
+
105
+ query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
106
+ context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
107
+
108
+ query_mask = make_attention_mask(query_tokens, query_tokens)
109
+ context_mask = make_attention_mask(context_tokens, context_tokens)
110
+
111
+ block_data = sample_data.as_array()
112
+
113
+ sample = {
114
+ 'query_tokens': query_tokens,
115
+ 'query_mask': query_mask,
116
+ 'query_pad_mask': query_pad_mask,
117
+ 'context_tokens': context_tokens,
118
+ 'context_mask': context_mask,
119
+ 'context_pad_mask': context_pad_mask,
120
+ 'block_data': block_data,
121
+ }
122
+
123
+ return sample
124
+
125
+ def get_block(self, start_idx, end_idx, doc_idx):
126
+ """Get the IDs for an evidence block plus the title of the corresponding document"""
127
+ block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
128
+ title = self.title_dataset[int(doc_idx)]
129
+
130
+ block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
131
+ block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
132
+
133
+ return block_tokens, block_pad_mask
134
+
135
+ def get_null_block(self):
136
+ """Get empty block and title - used in REALM pretraining"""
137
+ block, title = [], []
138
+ block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
139
+
140
+ return block_tokens, block_pad_mask
141
+
142
+ def concat_and_pad_tokens(self, tokens, title=None):
143
+ """Concat with special tokens and pad sequence to self.max_seq_length"""
144
+ tokens = list(tokens)
145
+ if title is None:
146
+ tokens = [self.cls_id] + tokens + [self.sep_id]
147
+ else:
148
+ title = list(title)
149
+ tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
150
+ assert len(tokens) <= self.max_seq_length
151
+
152
+ num_pad = self.max_seq_length - len(tokens)
153
+ pad_mask = [1] * len(tokens) + [0] * num_pad
154
+ tokens += [self.pad_id] * num_pad
155
+
156
+ return np.array(tokens), np.array(pad_mask)
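For illustration (toy ids; the import path is assumed from this upload), make_attention_mask yields a 2-D mask that is 1 only where both positions hold real tokens:

    import numpy as np
    from megatron.data.ict_dataset import make_attention_mask  # path assumed

    source = np.array([101, 7, 102, 0, 0])   # 3 real tokens, 2 pads (id 0)
    target = np.array([101, 9, 0])           # 2 real tokens, 1 pad

    mask = make_attention_mask(source, target)
    # mask.shape == (5, 3); mask[i, j] == 1 only when source[i] >= 1 and target[j] >= 1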
megatron/data/image_folder.py ADDED
@@ -0,0 +1,302 @@
1
+ # BSD 3-Clause License
2
+ #
3
+ # Copyright (c) Soumith Chintala 2016,
4
+ # All rights reserved.
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # * Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # * Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # * Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+
31
+ # code taken from
32
+ # https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
33
+ # added support for classes_fraction and data_per_class_fraction
34
+
35
+ from torchvision.datasets import VisionDataset
36
+ from PIL import Image
37
+
38
+ import os
39
+ import os.path
40
+ from typing import Any, Callable, cast, Dict, List, Optional, Tuple
41
+ import numpy as np
42
+
43
+ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
44
+ """Checks if a file is an allowed extension.
45
+ Args:
46
+ filename (string): path to a file
47
+ extensions (tuple of strings): extensions to consider (lowercase)
48
+ Returns:
49
+ bool: True if the filename ends with one of given extensions
50
+ """
51
+ return filename.lower().endswith(extensions)
52
+
53
+
54
+ def is_image_file(filename: str) -> bool:
55
+ """Checks if a file is an allowed image extension.
56
+ Args:
57
+ filename (string): path to a file
58
+ Returns:
59
+ bool: True if the filename ends with a known image extension
60
+ """
61
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
62
+
63
+
64
+ def make_dataset(
65
+ directory: str,
66
+ class_to_idx: Dict[str, int],
67
+ data_per_class_fraction: float,
68
+ extensions: Optional[Tuple[str, ...]] = None,
69
+ is_valid_file: Optional[Callable[[str], bool]] = None,
70
+ ) -> List[Tuple[str, int]]:
71
+ """Generates a list of samples of a form (path_to_sample, class).
72
+ Args:
73
+ directory (str): root dataset directory
74
+ class_to_idx (Dict[str, int]): dictionary mapping class name to class index
75
+ extensions (optional): A list of allowed extensions.
76
+ Either extensions or is_valid_file should be passed. Defaults to None.
77
+ is_valid_file (optional): A function that takes path of a file
78
+ and checks if the file is a valid file
79
+ (used to check of corrupt files) both extensions and
80
+ is_valid_file should not be passed. Defaults to None.
81
+ Raises:
82
+ ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
83
+ Returns:
84
+ List[Tuple[str, int]]: samples of a form (path_to_sample, class)
85
+ """
86
+ instances = []
87
+ directory = os.path.expanduser(directory)
88
+ both_none = extensions is None and is_valid_file is None
89
+ both_something = extensions is not None and is_valid_file is not None
90
+ if both_none or both_something:
91
+ raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
92
+ if extensions is not None:
93
+ def is_valid_file(x: str) -> bool:
94
+ return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
95
+ is_valid_file = cast(Callable[[str], bool], is_valid_file)
96
+ for target_class in sorted(class_to_idx.keys()):
97
+ class_index = class_to_idx[target_class]
98
+ target_dir = os.path.join(directory, target_class)
99
+ if not os.path.isdir(target_dir):
100
+ continue
101
+ local_instances = []
102
+ for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
103
+ for fname in sorted(fnames):
104
+ path = os.path.join(root, fname)
105
+ if is_valid_file(path):
106
+ item = path, class_index
107
+ local_instances.append(item)
108
+
109
+ instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])
110
+
111
+ return instances
112
+
113
+
114
+ class DatasetFolder(VisionDataset):
115
+ """A generic data loader where the samples are arranged in this way: ::
116
+ root/class_x/xxx.ext
117
+ root/class_x/xxy.ext
118
+ root/class_x/[...]/xxz.ext
119
+ root/class_y/123.ext
120
+ root/class_y/nsdf3.ext
121
+ root/class_y/[...]/asd932_.ext
122
+ Args:
123
+ root (string): Root directory path.
124
+ loader (callable): A function to load a sample given its path.
125
+ extensions (tuple[string]): A list of allowed extensions.
126
+ both extensions and is_valid_file should not be passed.
127
+ transform (callable, optional): A function/transform that takes in
128
+ a sample and returns a transformed version.
129
+ E.g, ``transforms.RandomCrop`` for images.
130
+ target_transform (callable, optional): A function/transform that takes
131
+ in the target and transforms it.
132
+ is_valid_file (callable, optional): A function that takes path of a file
133
+ and checks if the file is a valid file (used to check for corrupt files)
134
+ both extensions and is_valid_file should not be passed.
135
+ Attributes:
136
+ classes (list): List of the class names sorted alphabetically.
137
+ class_to_idx (dict): Dict with items (class_name, class_index).
138
+ samples (list): List of (sample path, class_index) tuples
139
+ targets (list): The class_index value for each image in the dataset
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ root: str,
145
+ loader: Callable[[str], Any],
146
+ extensions: Optional[Tuple[str, ...]] = None,
147
+ transform: Optional[Callable] = None,
148
+ target_transform: Optional[Callable] = None,
149
+ classes_fraction=1.0,
150
+ data_per_class_fraction=1.0,
151
+ is_valid_file: Optional[Callable[[str], bool]] = None,
152
+ ) -> None:
153
+ super(DatasetFolder, self).__init__(root, transform=transform,
154
+ target_transform=target_transform)
155
+ self.classes_fraction = classes_fraction
156
+ self.data_per_class_fraction = data_per_class_fraction
157
+ classes, class_to_idx = self._find_classes(self.root)
158
+ samples = self.make_dataset(self.root,
159
+ class_to_idx,
160
+ self.data_per_class_fraction,
161
+ extensions,
162
+ is_valid_file)
163
+ if len(samples) == 0:
164
+ msg = "Found 0 files in subfolders of: {}\n".format(self.root)
165
+ if extensions is not None:
166
+ msg += "Supported extensions are: {}".format(",".join(extensions))
167
+ raise RuntimeError(msg)
168
+
169
+ self.loader = loader
170
+ self.extensions = extensions
171
+ self.total = len(samples)
172
+ self.classes = classes
173
+ self.class_to_idx = class_to_idx
174
+ self.samples = samples
175
+ self.targets = [s[1] for s in samples]
176
+
177
+ @staticmethod
178
+ def make_dataset(
179
+ directory: str,
180
+ class_to_idx: Dict[str, int],
181
+ data_per_class_fraction: float,
182
+ extensions: Optional[Tuple[str, ...]] = None,
183
+ is_valid_file: Optional[Callable[[str], bool]] = None,
184
+ ) -> List[Tuple[str, int]]:
185
+ return make_dataset(directory,
186
+ class_to_idx,
187
+ data_per_class_fraction,
188
+ extensions=extensions,
189
+ is_valid_file=is_valid_file)
190
+
191
+ def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
192
+ """
193
+ Finds the class folders in a dataset.
194
+ Args:
195
+ dir (string): Root directory path.
196
+ Returns:
197
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
198
+ Ensures:
199
+ No class is a subdirectory of another.
200
+ """
201
+ all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
202
+ classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
203
+ classes.sort()
204
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
205
+ return classes, class_to_idx
206
+
207
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
208
+ """
209
+ Args:
210
+ index (int): Index
211
+ Returns:
212
+ tuple: (sample, target) where target is class_index of the target class.
213
+ """
214
+ curr_index = index
215
+ for x in range(self.total):
216
+ try:
217
+ path, target = self.samples[curr_index]
218
+ sample = self.loader(path)
219
+ break
220
+ except Exception as e:
221
+ curr_index = np.random.randint(0, self.total)
222
+
223
+ if self.transform is not None:
224
+ sample = self.transform(sample)
225
+ if self.target_transform is not None:
226
+ target = self.target_transform(target)
227
+
228
+ return sample, target
229
+
230
+ def __len__(self) -> int:
231
+ return len(self.samples)
232
+
233
+
234
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
235
+
236
+
237
+ def pil_loader(path: str) -> Image.Image:
238
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
239
+ with open(path, 'rb') as f:
240
+ img = Image.open(f)
241
+ return img.convert('RGB')
242
+
243
+
244
+ # TODO: specify the return type
245
+ def accimage_loader(path: str) -> Any:
246
+ import accimage
247
+ try:
248
+ return accimage.Image(path)
249
+ except IOError:
250
+ # Potentially a decoding problem, fall back to PIL.Image
251
+ return pil_loader(path)
252
+
253
+
254
+ def default_loader(path: str) -> Any:
255
+ from torchvision import get_image_backend
256
+ if get_image_backend() == 'accimage':
257
+ return accimage_loader(path)
258
+ else:
259
+ return pil_loader(path)
260
+
261
+
262
+ class ImageFolder(DatasetFolder):
263
+ """A generic data loader where the images are arranged in this way: ::
264
+ root/dog/xxx.png
265
+ root/dog/xxy.png
266
+ root/dog/[...]/xxz.png
267
+ root/cat/123.png
268
+ root/cat/nsdf3.png
269
+ root/cat/[...]/asd932_.png
270
+ Args:
271
+ root (string): Root directory path.
272
+ transform (callable, optional): A function/transform that takes in a PIL image
273
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
274
+ target_transform (callable, optional): A function/transform that takes in the
275
+ target and transforms it.
276
+ loader (callable, optional): A function to load an image given its path.
277
+ is_valid_file (callable, optional): A function that takes path of an Image file
278
+ and checks if the file is a valid file (used to check for corrupt files)
279
+ Attributes:
280
+ classes (list): List of the class names sorted alphabetically.
281
+ class_to_idx (dict): Dict with items (class_name, class_index).
282
+ imgs (list): List of (image path, class_index) tuples
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ root: str,
288
+ transform: Optional[Callable] = None,
289
+ target_transform: Optional[Callable] = None,
290
+ classes_fraction=1.0,
291
+ data_per_class_fraction=1.0,
292
+ loader: Callable[[str], Any] = default_loader,
293
+ is_valid_file: Optional[Callable[[str], bool]] = None,
294
+ ):
295
+ super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
296
+ transform=transform,
297
+ target_transform=target_transform,
298
+ classes_fraction=classes_fraction,
299
+ data_per_class_fraction=data_per_class_fraction,
300
+ is_valid_file=is_valid_file)
301
+ self.imgs = self.samples
302
+
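A minimal usage sketch of the ImageFolder variant added above (the path and fractions are illustrative, not part of the upload); it assumes an ImageNet-style layout root/<class_name>/<image files> and that torchvision is installed:

from torchvision import transforms
from megatron.data.image_folder import ImageFolder

# Keep only the first 50% of class folders (in directory-scan order)
# and the first 10% of files inside each kept class.
dataset = ImageFolder(
    root='./images',                 # hypothetical dataset root
    transform=transforms.ToTensor(),
    classes_fraction=0.5,
    data_per_class_fraction=0.1,
)
sample, target = dataset[0]          # tensor image and an int class index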
megatron/data/indexed_dataset.py ADDED
@@ -0,0 +1,576 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # copied from fairseq/fairseq/data/indexed_dataset.py
8
+ # Removed IndexedRawTextDataset since it relied on Fairseq dictionary
9
+ # other slight modifications to remove fairseq dependencies
10
+ # Added document index to index file and made it accessible.
11
+ # An empty sentence no longer separates documents.
12
+
13
+ from functools import lru_cache
14
+ import os
15
+ import shutil
16
+ import struct
17
+ from itertools import accumulate
18
+
19
+ import numpy as np
20
+ import torch
21
+ from megatron import print_rank_0
22
+
23
+
24
+ def __best_fitting_dtype(vocab_size=None):
25
+ if vocab_size is not None and vocab_size < 65500:
26
+ return np.uint16
27
+ else:
28
+ return np.int32
29
+
30
+
31
+ def get_available_dataset_impl():
32
+ return ['lazy', 'cached', 'mmap']
33
+
34
+
35
+ def infer_dataset_impl(path):
36
+ if IndexedDataset.exists(path):
37
+ with open(index_file_path(path), 'rb') as f:
38
+ magic = f.read(8)
39
+ if magic == IndexedDataset._HDR_MAGIC:
40
+ return 'cached'
41
+ elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
42
+ return 'mmap'
43
+ else:
44
+ return None
45
+ else:
46
+ print(f"Dataset does not exist: {path}")
47
+ print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
48
+ return None
49
+
50
+
51
+ def make_builder(out_file, impl, vocab_size=None):
52
+ if impl == 'mmap':
53
+ return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
54
+ else:
55
+ return IndexedDatasetBuilder(out_file)
56
+
57
+
58
+ def make_dataset(path, impl, skip_warmup=False):
59
+ if not IndexedDataset.exists(path):
60
+ print(f"Dataset does not exist: {path}")
61
+ print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
62
+ return None
63
+ if impl == 'infer':
64
+ impl = infer_dataset_impl(path)
65
+ if impl == 'lazy' and IndexedDataset.exists(path):
66
+ return IndexedDataset(path)
67
+ elif impl == 'cached' and IndexedDataset.exists(path):
68
+ return IndexedCachedDataset(path)
69
+ elif impl == 'mmap' and MMapIndexedDataset.exists(path):
70
+ return MMapIndexedDataset(path, skip_warmup)
71
+ print(f"Unknown dataset implementation: {impl}")
72
+ return None
73
+
74
+
75
+ def dataset_exists(path, impl):
76
+ if impl == 'mmap':
77
+ return MMapIndexedDataset.exists(path)
78
+ else:
79
+ return IndexedDataset.exists(path)
80
+
81
+
82
+ def read_longs(f, n):
83
+ a = np.empty(n, dtype=np.int64)
84
+ f.readinto(a)
85
+ return a
86
+
87
+
88
+ def write_longs(f, a):
89
+ f.write(np.array(a, dtype=np.int64))
90
+
91
+
92
+ dtypes = {
93
+ 1: np.uint8,
94
+ 2: np.int8,
95
+ 3: np.int16,
96
+ 4: np.int32,
97
+ 5: np.int64,
98
+ 6: np.float64,  # np.float is a deprecated alias for float64
99
+ 7: np.double,
100
+ 8: np.uint16
101
+ }
102
+
103
+
104
+ def code(dtype):
105
+ for k in dtypes.keys():
106
+ if dtypes[k] == dtype:
107
+ return k
108
+ raise ValueError(dtype)
109
+
110
+
111
+ def index_file_path(prefix_path):
112
+ return prefix_path + '.idx'
113
+
114
+
115
+ def data_file_path(prefix_path):
116
+ return prefix_path + '.bin'
117
+
118
+
119
+ def create_doc_idx(sizes):
120
+ doc_idx = [0]
121
+ for i, s in enumerate(sizes):
122
+ if s == 0:
123
+ doc_idx.append(i + 1)
124
+ return doc_idx
125
+
126
+
127
+ class IndexedDataset(torch.utils.data.Dataset):
128
+ """Loader for IndexedDataset"""
129
+ _HDR_MAGIC = b'TNTIDX\x00\x00'
130
+
131
+ def __init__(self, path):
132
+ super().__init__()
133
+ self.path = path
134
+ self.data_file = None
135
+ self.read_index(path)
136
+
137
+ def read_index(self, path):
138
+ with open(index_file_path(path), 'rb') as f:
139
+ magic = f.read(8)
140
+ assert magic == self._HDR_MAGIC, (
141
+ 'Index file doesn\'t match expected format. '
142
+ 'Make sure that --dataset-impl is configured properly.'
143
+ )
144
+ version = f.read(8)
145
+ assert struct.unpack('<Q', version) == (1,)
146
+ code, self.element_size = struct.unpack('<QQ', f.read(16))
147
+ self.dtype = dtypes[code]
148
+ self._len, self.s = struct.unpack('<QQ', f.read(16))
149
+ self.doc_count = struct.unpack('<Q', f.read(8))
150
+ self.dim_offsets = read_longs(f, self._len + 1)
151
+ self.data_offsets = read_longs(f, self._len + 1)
152
+ self.sizes = read_longs(f, self.s)
153
+ self.doc_idx = read_longs(f, self.doc_count)
154
+
155
+ def read_data(self, path):
156
+ self.data_file = open(data_file_path(path), 'rb', buffering=0)
157
+
158
+ def check_index(self, i):
159
+ if i < 0 or i >= self._len:
160
+ raise IndexError('index out of range')
161
+
162
+ def __del__(self):
163
+ if self.data_file:
164
+ self.data_file.close()
165
+
166
+ # @lru_cache(maxsize=8)
167
+ def __getitem__(self, idx):
168
+ if not self.data_file:
169
+ self.read_data(self.path)
170
+ if isinstance(idx, int):
171
+ i = idx
172
+ self.check_index(i)
173
+ tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
174
+ a = np.empty(tensor_size, dtype=self.dtype)
175
+ self.data_file.seek(self.data_offsets[i] * self.element_size)
176
+ self.data_file.readinto(a)
177
+ return a
178
+ elif isinstance(idx, slice):
179
+ start, stop, step = idx.indices(len(self))
180
+ if step != 1:
181
+ raise ValueError("Slices into indexed_dataset must be contiguous")
182
+ sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
183
+ size = sum(sizes)
184
+ a = np.empty(size, dtype=self.dtype)
185
+ self.data_file.seek(self.data_offsets[start] * self.element_size)
186
+ self.data_file.readinto(a)
187
+ offsets = list(accumulate(sizes))
188
+ sents = np.split(a, offsets[:-1])
189
+ return sents
190
+
191
+ def __len__(self):
192
+ return self._len
193
+
194
+ def num_tokens(self, index):
195
+ return self.sizes[index]
196
+
197
+ def size(self, index):
198
+ return self.sizes[index]
199
+
200
+ @staticmethod
201
+ def exists(path):
202
+ return (
203
+ os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
204
+ )
205
+
206
+ @property
207
+ def supports_prefetch(self):
208
+ return False # avoid prefetching to save memory
209
+
210
+
211
+ class IndexedCachedDataset(IndexedDataset):
212
+
213
+ def __init__(self, path):
214
+ super().__init__(path)
215
+ self.cache = None
216
+ self.cache_index = {}
217
+
218
+ @property
219
+ def supports_prefetch(self):
220
+ return True
221
+
222
+ def prefetch(self, indices):
223
+ if all(i in self.cache_index for i in indices):
224
+ return
225
+ if not self.data_file:
226
+ self.read_data(self.path)
227
+ indices = sorted(set(indices))
228
+ total_size = 0
229
+ for i in indices:
230
+ total_size += self.data_offsets[i + 1] - self.data_offsets[i]
231
+ self.cache = np.empty(total_size, dtype=self.dtype)
232
+ ptx = 0
233
+ self.cache_index.clear()
234
+ for i in indices:
235
+ self.cache_index[i] = ptx
236
+ size = self.data_offsets[i + 1] - self.data_offsets[i]
237
+ a = self.cache[ptx: ptx + size]
238
+ self.data_file.seek(self.data_offsets[i] * self.element_size)
239
+ self.data_file.readinto(a)
240
+ ptx += size
241
+ if self.data_file:
242
+ # close and delete data file after prefetch so we can pickle
243
+ self.data_file.close()
244
+ self.data_file = None
245
+
246
+ # @lru_cache(maxsize=8)
247
+ def __getitem__(self, idx):
248
+ if isinstance(idx, int):
249
+ i = idx
250
+ self.check_index(i)
251
+ tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
252
+ a = np.empty(tensor_size, dtype=self.dtype)
253
+ ptx = self.cache_index[i]
254
+ np.copyto(a, self.cache[ptx: ptx + a.size])
255
+ return a
256
+ elif isinstance(idx, slice):
257
+ # Hack just to make this work, can optimize later if necessary
258
+ sents = []
259
+ for i in range(*idx.indices(len(self))):
260
+ sents.append(self[i])
261
+ return sents
262
+
263
+
264
+ class IndexedDatasetBuilder(object):
265
+ element_sizes = {
266
+ np.uint8: 1,
267
+ np.int8: 1,
268
+ np.int16: 2,
269
+ np.int32: 4,
270
+ np.int64: 8,
271
+ np.float64: 8,  # was np.float: 4; np.float is a deprecated alias for float64, which is 8 bytes
272
+ np.double: 8
273
+ }
274
+
275
+ def __init__(self, out_file, dtype=np.int32):
276
+ self.out_file = open(out_file, 'wb')
277
+ self.dtype = dtype
278
+ self.data_offsets = [0]
279
+ self.dim_offsets = [0]
280
+ self.sizes = []
281
+ self.element_size = self.element_sizes[self.dtype]
282
+ self.doc_idx = [0]
283
+
284
+ def add_item(self, tensor):
285
+ bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
286
+ self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
287
+ for s in tensor.size():
288
+ self.sizes.append(s)
289
+ self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
290
+
291
+ def end_document(self):
292
+ self.doc_idx.append(len(self.sizes))
293
+
294
+ def merge_file_(self, another_file):
295
+ index = IndexedDataset(another_file)
296
+ assert index.dtype == self.dtype
297
+
298
+ doc_offset = len(self.sizes)
299
+
300
+ begin = self.data_offsets[-1]
301
+ for data_offset in index.data_offsets[1:]:
302
+ self.data_offsets.append(begin + data_offset)
303
+ self.sizes.extend(index.sizes)
304
+
305
+ begin = self.dim_offsets[-1]
306
+ for dim_offset in index.dim_offsets[1:]:
307
+ self.dim_offsets.append(begin + dim_offset)
308
+
309
+ self.doc_idx.extend((doc_offset + index.doc_idx)[1:])
310
+
311
+ with open(data_file_path(another_file), 'rb') as f:
312
+ while True:
313
+ data = f.read(1024)
314
+ if data:
315
+ self.out_file.write(data)
316
+ else:
317
+ break
318
+
319
+ def finalize(self, index_file):
320
+ self.out_file.close()
321
+ index = open(index_file, 'wb')
322
+ index.write(b'TNTIDX\x00\x00')
323
+ index.write(struct.pack('<Q', 1))
324
+ index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
325
+ index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
326
+ index.write(struct.pack('<Q', len(self.doc_idx)))
327
+ write_longs(index, self.dim_offsets)
328
+ write_longs(index, self.data_offsets)
329
+ write_longs(index, self.sizes)
330
+ write_longs(index, self.doc_idx)
331
+ index.close()
332
+
333
+
334
+ def _warmup_mmap_file(path):
335
+ with open(path, 'rb') as stream:
336
+ while stream.read(100 * 1024 * 1024):
337
+ pass
338
+
339
+
340
+ class MMapIndexedDataset(torch.utils.data.Dataset):
341
+ class Index(object):
342
+ _HDR_MAGIC = b'MMIDIDX\x00\x00'
343
+
344
+ @classmethod
345
+ def writer(cls, path, dtype):
346
+ class _Writer(object):
347
+ def __enter__(self):
348
+ self._file = open(path, 'wb')
349
+
350
+ self._file.write(cls._HDR_MAGIC)
351
+ self._file.write(struct.pack('<Q', 1))
352
+ self._file.write(struct.pack('<B', code(dtype)))
353
+
354
+ return self
355
+
356
+ @staticmethod
357
+ def _get_pointers(sizes):
358
+ dtype_size = dtype().itemsize
359
+ address = 0
360
+ pointers = []
361
+
362
+ for size in sizes:
363
+ pointers.append(address)
364
+ address += size * dtype_size
365
+
366
+ return pointers
367
+
368
+ def write(self, sizes, doc_idx):
369
+ pointers = self._get_pointers(sizes)
370
+
371
+ self._file.write(struct.pack('<Q', len(sizes)))
372
+ self._file.write(struct.pack('<Q', len(doc_idx)))
373
+
374
+ sizes = np.array(sizes, dtype=np.int32)
375
+ self._file.write(sizes.tobytes(order='C'))
376
+ del sizes
377
+
378
+ pointers = np.array(pointers, dtype=np.int64)
379
+ self._file.write(pointers.tobytes(order='C'))
380
+ del pointers
381
+
382
+ doc_idx = np.array(doc_idx, dtype=np.int64)
383
+ self._file.write(doc_idx.tobytes(order='C'))
384
+
385
+ def __exit__(self, exc_type, exc_val, exc_tb):
386
+ self._file.close()
387
+
388
+ return _Writer()
389
+
390
+ def __init__(self, path, skip_warmup=False):
391
+ with open(path, 'rb') as stream:
392
+ magic_test = stream.read(9)
393
+ assert self._HDR_MAGIC == magic_test, (
394
+ 'Index file doesn\'t match expected format. '
395
+ 'Make sure that --dataset-impl is configured properly.'
396
+ )
397
+ version = struct.unpack('<Q', stream.read(8))
398
+ assert (1,) == version
399
+
400
+ dtype_code, = struct.unpack('<B', stream.read(1))
401
+ self._dtype = dtypes[dtype_code]
402
+ self._dtype_size = self._dtype().itemsize
403
+
404
+ self._len = struct.unpack('<Q', stream.read(8))[0]
405
+ self._doc_count = struct.unpack('<Q', stream.read(8))[0]
406
+ offset = stream.tell()
407
+
408
+ if not skip_warmup:
409
+ print_rank_0(" warming up index mmap file...")
410
+ _warmup_mmap_file(path)
411
+
412
+ self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
413
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
414
+ print_rank_0(" reading sizes...")
415
+ self._sizes = np.frombuffer(
416
+ self._bin_buffer,
417
+ dtype=np.int32,
418
+ count=self._len,
419
+ offset=offset)
420
+ print_rank_0(" reading pointers...")
421
+ self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
422
+ offset=offset + self._sizes.nbytes)
423
+ print_rank_0(" reading document index...")
424
+ self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
425
+ offset=offset + self._sizes.nbytes + self._pointers.nbytes)
426
+
427
+ def __del__(self):
428
+ self._bin_buffer_mmap._mmap.close()
429
+ del self._bin_buffer_mmap
430
+
431
+ @property
432
+ def dtype(self):
433
+ return self._dtype
434
+
435
+ @property
436
+ def sizes(self):
437
+ return self._sizes
438
+
439
+ @property
440
+ def doc_idx(self):
441
+ return self._doc_idx
442
+
443
+ @lru_cache(maxsize=8)
444
+ def __getitem__(self, i):
445
+ return self._pointers[i], self._sizes[i]
446
+
447
+ def __len__(self):
448
+ return self._len
449
+
450
+ def __init__(self, path, skip_warmup=False):
451
+ super().__init__()
452
+
453
+ self._path = None
454
+ self._index = None
455
+ self._bin_buffer = None
456
+
457
+ self._do_init(path, skip_warmup)
458
+
459
+ def __getstate__(self):
460
+ return self._path
461
+
462
+ def __setstate__(self, state):
463
+ self._do_init(state)
464
+
465
+ def _do_init(self, path, skip_warmup):
466
+ self._path = path
467
+ self._index = self.Index(index_file_path(self._path), skip_warmup)
468
+
469
+ if not skip_warmup:
470
+ print_rank_0(" warming up data mmap file...")
471
+ _warmup_mmap_file(data_file_path(self._path))
472
+ print_rank_0(" creating numpy buffer of mmap...")
473
+ self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
474
+ print_rank_0(" creating memory view of numpy buffer...")
475
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
476
+
477
+ def __del__(self):
478
+ self._bin_buffer_mmap._mmap.close()
479
+ del self._bin_buffer_mmap
480
+ del self._index
481
+
482
+ def __len__(self):
483
+ return len(self._index)
484
+
485
+ # @lru_cache(maxsize=8)
486
+ def __getitem__(self, idx):
487
+ if isinstance(idx, int):
488
+ ptr, size = self._index[idx]
489
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
490
+ count=size, offset=ptr)
491
+ return np_array
492
+ elif isinstance(idx, slice):
493
+ start, stop, step = idx.indices(len(self))
494
+ if step != 1:
495
+ raise ValueError("Slices into indexed_dataset must be contiguous")
496
+ ptr = self._index._pointers[start]
497
+ sizes = self._index._sizes[idx]
498
+ offsets = list(accumulate(sizes))
499
+ total_size = sum(sizes)
500
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
501
+ count=total_size, offset=ptr)
502
+ sents = np.split(np_array, offsets[:-1])
503
+ return sents
504
+
505
+ def get(self, idx, offset=0, length=None):
506
+ """ Retrieves a single item from the dataset with the option to only
507
+ return a portion of the item.
508
+
509
+ get(idx) is the same as [idx] but get() does not support slicing.
510
+ """
511
+ ptr, size = self._index[idx]
512
+ if length is None:
513
+ length = size - offset
514
+ ptr += offset * np.dtype(self._index.dtype).itemsize
515
+ np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
516
+ count=length, offset=ptr)
517
+ return np_array
518
+
519
+ @property
520
+ def sizes(self):
521
+ return self._index.sizes
522
+
523
+ @property
524
+ def doc_idx(self):
525
+ return self._index.doc_idx
526
+
527
+ def get_doc_idx(self):
528
+ return self._index._doc_idx
529
+
530
+ def set_doc_idx(self, doc_idx_):
531
+ self._index._doc_idx = doc_idx_
532
+
533
+ @property
534
+ def supports_prefetch(self):
535
+ return False
536
+
537
+ @staticmethod
538
+ def exists(path):
539
+ return (
540
+ os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
541
+ )
542
+
543
+
544
+ class MMapIndexedDatasetBuilder(object):
545
+ def __init__(self, out_file, dtype=np.int64):
546
+ self._data_file = open(out_file, 'wb')
547
+ self._dtype = dtype
548
+ self._sizes = []
549
+ self._doc_idx = [0]
550
+
551
+ def add_item(self, tensor):
552
+ np_array = np.array(tensor.numpy(), dtype=self._dtype)
553
+ self._data_file.write(np_array.tobytes(order='C'))
554
+ self._sizes.append(np_array.size)
555
+
556
+ def end_document(self):
557
+ self._doc_idx.append(len(self._sizes))
558
+
559
+ def merge_file_(self, another_file):
560
+ # Concatenate index
561
+ index = MMapIndexedDataset.Index(index_file_path(another_file))
562
+ assert index.dtype == self._dtype
563
+
564
+ offset = len(self._sizes)
565
+ self._sizes.extend(index.sizes)
566
+ self._doc_idx.extend((offset + index.doc_idx)[1:])
567
+
568
+ # Concatenate data
569
+ with open(data_file_path(another_file), 'rb') as f:
570
+ shutil.copyfileobj(f, self._data_file)
571
+
572
+ def finalize(self, index_file):
573
+ self._data_file.close()
574
+
575
+ with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
576
+ index.write(self._sizes, self._doc_idx)
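A short end-to-end sketch of how the builder and reader classes above fit together (the 'out' prefix and token ids are made up for illustration):

import torch
from megatron.data import indexed_dataset

# Write two one-sentence documents of token ids with the mmap backend.
builder = indexed_dataset.make_builder('out.bin', impl='mmap', vocab_size=30000)
builder.add_item(torch.tensor([101, 2023, 2003, 102]))
builder.end_document()                        # record a document boundary in doc_idx
builder.add_item(torch.tensor([101, 2178, 6251, 102]))
builder.end_document()
builder.finalize('out.idx')

# Read it back; 'out' is the basename to which .idx and .bin are appended.
ds = indexed_dataset.make_dataset('out', impl='mmap', skip_warmup=True)
print(len(ds), ds[0], ds.doc_idx)             # 2 sentences, first id array, [0 1 2]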
megatron/data/orqa_wiki_dataset.py ADDED
@@ -0,0 +1,205 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Wikipedia dataset from DPR code for ORQA."""
17
+
18
+ from abc import ABC
19
+ import csv
20
+ import numpy as np
21
+ import random
22
+ import torch
23
+ from torch.utils.data import Dataset
24
+
25
+ from megatron import print_rank_0, get_args, get_tokenizer, mpu
26
+ from megatron.data.biencoder_dataset_utils import make_attention_mask
27
+
28
+ def get_open_retrieval_wiki_dataset():
29
+ args = get_args()
30
+ tokenizer = get_tokenizer()
31
+
32
+ dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
33
+ 'evidence',
34
+ args.evidence_data_path,
35
+ tokenizer,
36
+ args.retriever_seq_length)
37
+ return dataset
38
+
39
+
40
+ def get_open_retrieval_batch(data_iterator):
41
+ # Items and their type.
42
+ keys = ['row_id', 'context', 'context_mask', 'context_types',
43
+ 'context_pad_mask']
44
+ datatype = torch.int64
45
+
46
+ # Broadcast data.
47
+ data = None if data_iterator is None else next(data_iterator)
48
+ data_b = mpu.broadcast_data(keys, data, datatype)
49
+
50
+ # Unpack.
51
+ row_id = data_b['row_id'].long()
52
+ context = data_b['context'].long()
53
+
54
+ # TODO: make the context mask a binary one
55
+ context_mask = (data_b['context_mask'] < 0.5)
56
+
57
+ context_types = data_b['context_types'].long()
58
+ context_pad_mask = data_b['context_pad_mask'].long()
59
+
60
+ return row_id, context, context_mask, context_types, context_pad_mask
61
+
62
+
63
+ def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
64
+ """Build token types and paddings, trim if needed, and pad if needed."""
65
+
66
+ title_ids = tokenizer.tokenize(row['title'])
67
+ context_ids = tokenizer.tokenize(row['text'])
68
+
69
+ # Appending the title of the context at front
70
+ extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids
71
+
72
+ context_ids, context_types, context_pad_mask = \
73
+ build_tokens_types_paddings_from_ids(extended_context_ids,
74
+ max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
75
+
76
+ return context_ids, context_types, context_pad_mask
77
+
78
+
79
+ # noinspection DuplicatedCode
80
+ def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
81
+ cls_id, sep_id, pad_id):
82
+ """Build token types and paddings, trim if needed, and pad if needed."""
83
+ enc_ids = []
84
+ tokentypes_enc = []
85
+
86
+ # [CLS].
87
+ enc_ids.append(cls_id)
88
+ tokentypes_enc.append(0)
89
+
90
+ # A.
91
+ len_src = len(text_ids)
92
+ enc_ids.extend(text_ids)
93
+ tokentypes_enc.extend([0] * len_src)
94
+
95
+ # Cap the size.
96
+ if len(enc_ids) > max_seq_length - 1:
97
+ enc_ids = enc_ids[0: max_seq_length - 1]
98
+ tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
99
+
100
+ # [SEP].
101
+ enc_ids.append(sep_id)
102
+ tokentypes_enc.append(0)
103
+
104
+ num_tokens_enc = len(enc_ids)
105
+ # Padding.
106
+ padding_length = max_seq_length - len(enc_ids)
107
+ if padding_length > 0:
108
+ enc_ids.extend([pad_id] * padding_length)
109
+ tokentypes_enc.extend([pad_id] * padding_length)
110
+
111
+ pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
112
+ pad_mask = np.array(pad_mask, dtype=np.int64)
113
+
114
+ return enc_ids, tokentypes_enc, pad_mask
115
+
116
+
117
+ def build_sample(row_id, context_ids, context_types, context_pad_mask):
118
+ """Convert to numpy and return a sample consumed by the batch producer."""
119
+
120
+ context_ids = np.array(context_ids, dtype=np.int64)
121
+ context_types = np.array(context_types, dtype=np.int64)
122
+ context_mask = make_attention_mask(context_ids, context_ids)
123
+
124
+ sample = ({
125
+ 'row_id': row_id,
126
+ 'context': context_ids,
127
+ 'context_mask': context_mask,
128
+ 'context_types': context_types,
129
+ 'context_pad_mask': context_pad_mask
130
+ })
131
+ return sample
132
+
133
+
134
+ class OpenRetrievalEvidenceDataset(ABC, Dataset):
135
+ """Open Retrieval Evidence dataset class."""
136
+
137
+ def __init__(self, task_name, dataset_name, datapath, tokenizer,
138
+ max_seq_length):
139
+ # Store inputs.
140
+ self.task_name = task_name
141
+ self.dataset_name = dataset_name
142
+ self.tokenizer = tokenizer
143
+ self.max_seq_length = max_seq_length
144
+ print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
145
+ self.dataset_name))
146
+ # Process the files.
147
+ print_rank_0(datapath)
148
+ self.samples, self.id2text = self.process_samples_from_single_path(
149
+ datapath)
150
+
151
+ args = get_args()
152
+ if args.sample_rate < 1: # subsample
153
+ k = int(len(self.samples) * args.sample_rate)
154
+ self.samples = random.sample(self.samples, k)
155
+
156
+ print_rank_0(' >> total number of samples: {}'.format(
157
+ len(self.samples)))
158
+
159
+ def __len__(self):
160
+ return len(self.samples)
161
+
162
+ def __getitem__(self, idx):
163
+ row = self.samples[idx]
164
+
165
+ context_ids, context_types, context_pad_mask = \
166
+ build_tokens_types_paddings_from_text(row, self.tokenizer,
167
+ self.max_seq_length)
168
+
169
+ sample = build_sample(row['doc_id'],
170
+ context_ids,
171
+ context_types,
172
+ context_pad_mask)
173
+ return sample
174
+
175
+ @staticmethod
176
+ def process_samples_from_single_path(filename):
177
+ print_rank_0(' > Processing {} ...'.format(filename))
178
+ total = 0
179
+
180
+ rows = []
181
+ id2text = {}
182
+
183
+ with open(filename) as tsvfile:
184
+ reader = csv.reader(tsvfile, delimiter='\t')
185
+ next(reader, None) # skip the headers
186
+ for row in reader:
187
+ # file format: doc_id, doc_text, title
188
+ doc_id = int(row[0])
189
+ text = row[1]
190
+ title = row[2]
191
+
192
+ rows.append({'doc_id': doc_id,
193
+ 'text': text,
194
+ 'title': title})
195
+
196
+ assert doc_id not in id2text
197
+ id2text[doc_id] = (text, title)
198
+
199
+ total += 1
200
+ if total % 100000 == 0:
201
+ print_rank_0(' > processed {} rows so far ...'.format(
202
+ total))
203
+
204
+ print_rank_0(' >> processed {} samples.'.format(len(rows)))
205
+ return rows, id2text
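A tiny worked example of the padding helper above; the token ids (101 = [CLS], 102 = [SEP], 0 = pad) are made-up stand-ins for a BERT-style vocabulary:

from megatron.data.orqa_wiki_dataset import build_tokens_types_paddings_from_ids

# A 4-token passage padded out to max_seq_length=8.
ids, types, pad_mask = build_tokens_types_paddings_from_ids(
    text_ids=[2023, 2003, 1037, 3231], max_seq_length=8,
    cls_id=101, sep_id=102, pad_id=0)
print(ids)       # [101, 2023, 2003, 1037, 3231, 102, 0, 0]
print(types)     # [0, 0, 0, 0, 0, 0, 0, 0]
print(pad_mask)  # [1 1 1 1 1 1 0 0]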
megatron/data/realm_dataset_utils.py ADDED
@@ -0,0 +1,198 @@
1
+ import os
2
+ import time
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from megatron import mpu, print_rank_0
8
+ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
9
+ from megatron import get_args, get_tokenizer, print_rank_0, mpu
10
+
11
+
12
+ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
13
+ """Specifically one epoch to be used in an indexing job."""
14
+ args = get_args()
15
+
16
+ world_size = mpu.get_data_parallel_world_size()
17
+ rank = mpu.get_data_parallel_rank()
18
+ if micro_batch_size is None:
19
+ micro_batch_size = args.micro_batch_size
20
+ global_batch_size = micro_batch_size * world_size
21
+ num_workers = args.num_workers
22
+
23
+ sampler = torch.utils.data.SequentialSampler(dataset)
24
+ # importantly, drop_last must be False to get all the data.
25
+ assert False, 'DistributedBatchSampler deprecated, change the implementation'
26
+ from megatron.data.samplers import DistributedBatchSampler
27
+ batch_sampler = DistributedBatchSampler(sampler,
28
+ batch_size=global_batch_size,
29
+ drop_last=False,
30
+ rank=rank,
31
+ world_size=world_size)
32
+
33
+ return torch.utils.data.DataLoader(dataset,
34
+ batch_sampler=batch_sampler,
35
+ num_workers=num_workers,
36
+ pin_memory=True)
37
+
38
+
39
+ def get_ict_batch(data_iterator):
40
+ # Items and their type.
41
+ keys = ['query_tokens', 'query_pad_mask',
42
+ 'block_tokens', 'block_pad_mask', 'block_data']
43
+ datatype = torch.int64
44
+
45
+ # Broadcast data.
46
+ if data_iterator is None:
47
+ data = None
48
+ else:
49
+ data = next(data_iterator)
50
+ data_b = mpu.broadcast_data(keys, data, datatype)
51
+
52
+ # Unpack.
53
+ query_tokens = data_b['query_tokens'].long()
54
+ query_pad_mask = data_b['query_pad_mask'].long()
55
+ block_tokens = data_b['block_tokens'].long()
56
+ block_pad_mask = data_b['block_pad_mask'].long()
57
+ block_indices = data_b['block_data'].long()
58
+
59
+ return query_tokens, query_pad_mask,\
60
+ block_tokens, block_pad_mask, block_indices
61
+
62
+
63
+ def join_str_list(str_list):
64
+ """Join a list of strings, handling spaces appropriately"""
65
+ result = ""
66
+ for s in str_list:
67
+ if s.startswith("##"):
68
+ result += s[2:]
69
+ else:
70
+ result += " " + s
71
+ return result
72
+
73
+
74
+ class BlockSampleData(object):
75
+ """A struct for fully describing a fixed-size block of data as used in REALM
76
+
77
+ :param start_idx: for first sentence of the block
78
+ :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
79
+ :param doc_idx: the index of the document from which the block comes in the original indexed dataset
80
+ :param block_idx: a unique integer identifier given to every block.
81
+ """
82
+ def __init__(self, start_idx, end_idx, doc_idx, block_idx):
83
+ self.start_idx = start_idx
84
+ self.end_idx = end_idx
85
+ self.doc_idx = doc_idx
86
+ self.block_idx = block_idx
87
+
88
+ def as_array(self):
89
+ return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
90
+
91
+ def as_tuple(self):
92
+ return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
93
+
94
+
95
+ class BlockSamplesMapping(object):
96
+ def __init__(self, mapping_array):
97
+ # make sure that the array is compatible with BlockSampleData
98
+ assert mapping_array.shape[1] == 4
99
+ self.mapping_array = mapping_array
100
+
101
+ def __len__(self):
102
+ return self.mapping_array.shape[0]
103
+
104
+ def __getitem__(self, idx):
105
+ """Get the data associated with an indexed sample."""
106
+ sample_data = BlockSampleData(*self.mapping_array[idx])
107
+ return sample_data
108
+
109
+
110
+ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
111
+ max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
112
+ """Get samples mapping for a dataset over fixed size blocks. This function also requires
113
+ a dataset of the titles for the source documents since their lengths must be taken into account.
114
+
115
+ :return: samples_mapping (BlockSamplesMapping)
116
+ """
117
+
118
+ if not num_epochs:
119
+ if not max_num_samples:
120
+ raise ValueError("Need to specify either max_num_samples "
121
+ "or num_epochs")
122
+ num_epochs = np.iinfo(np.int32).max - 1
123
+ if not max_num_samples:
124
+ max_num_samples = np.iinfo(np.int64).max - 1
125
+
126
+ # Filename of the index mapping
127
+ indexmap_filename = data_prefix
128
+ indexmap_filename += '_{}_indexmap'.format(name)
129
+ if num_epochs != (np.iinfo(np.int32).max - 1):
130
+ indexmap_filename += '_{}ep'.format(num_epochs)
131
+ if max_num_samples != (np.iinfo(np.int64).max - 1):
132
+ indexmap_filename += '_{}mns'.format(max_num_samples)
133
+ indexmap_filename += '_{}msl'.format(max_seq_length)
134
+ indexmap_filename += '_{}s'.format(seed)
135
+ if use_one_sent_docs:
136
+ indexmap_filename += '_1sentok'
137
+ indexmap_filename += '.npy'
138
+
139
+ # Build the indexed mapping if not exist.
140
+ if mpu.get_data_parallel_rank() == 0 and \
141
+ not os.path.isfile(indexmap_filename):
142
+ print(' > WARNING: could not find index map file {}, building '
143
+ 'the indices on rank 0 ...'.format(indexmap_filename))
144
+
145
+ # Make sure the types match the helpers input types.
146
+ assert block_dataset.doc_idx.dtype == np.int64
147
+ assert block_dataset.sizes.dtype == np.int32
148
+
149
+ # Build samples mapping
150
+ verbose = torch.distributed.get_rank() == 0
151
+ start_time = time.time()
152
+ print_rank_0(' > building samples index mapping for {} ...'.format(
153
+ name))
154
+
155
+ from megatron.data import helpers
156
+ mapping_array = helpers.build_blocks_mapping(
157
+ block_dataset.doc_idx,
158
+ block_dataset.sizes,
159
+ title_dataset.sizes,
160
+ num_epochs,
161
+ max_num_samples,
162
+ max_seq_length - 3, # account for added tokens
163
+ seed,
164
+ verbose,
165
+ use_one_sent_docs)
166
+
167
+
168
+ print_rank_0(' > done building samples index mapping')
169
+ np.save(indexmap_filename, mapping_array, allow_pickle=True)
170
+ print_rank_0(' > saved the index mapping in {}'.format(
171
+ indexmap_filename))
172
+ # Make sure all the ranks have built the mapping
173
+ print_rank_0(' > elapsed time to build and save samples mapping '
174
+ '(seconds): {:4f}'.format(
175
+ time.time() - start_time))
176
+
177
+ # This should be a barrier but nccl barrier assumes
178
+ # device_index=rank which is not the case for model
179
+ # parallel case
180
+ counts = torch.cuda.LongTensor([1])
181
+ torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
182
+ assert counts[0].item() == torch.distributed.get_world_size(
183
+ group=mpu.get_data_parallel_group())
184
+
185
+ # Load indexed dataset.
186
+ print_rank_0(' > loading indexed mapping from {}'.format(
187
+ indexmap_filename))
188
+ start_time = time.time()
189
+
190
+ mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
191
+ samples_mapping = BlockSamplesMapping(mapping_array)
192
+
193
+ print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
194
+ time.time() - start_time))
195
+ print_rank_0(' total number of samples: {}'.format(
196
+ mapping_array.shape[0]))
197
+
198
+ return samples_mapping
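A small sketch of the mapping-array contract used by the two classes above (the numbers are arbitrary; each row is [start_idx, end_idx, doc_idx, block_idx]):

import numpy as np
from megatron.data.realm_dataset_utils import BlockSamplesMapping

mapping = BlockSamplesMapping(np.array([[0, 3, 0, 0],
                                        [3, 7, 0, 1]], dtype=np.int64))
block = mapping[1]                       # -> BlockSampleData
print(len(mapping), block.as_tuple())    # 2 (3, 7, 0, 1)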
megatron/data/realm_index.py ADDED
@@ -0,0 +1,224 @@
1
+ import itertools
2
+ import os
3
+ import pickle
4
+ import shutil
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from megatron import get_args
10
+ from megatron import mpu
11
+
12
+
13
+ def detach(tensor):
14
+ return tensor.detach().cpu().numpy()
15
+
16
+
17
+ class OpenRetreivalDataStore(object):
18
+ """
19
+ Serializable data structure for holding data for blocks --
20
+ embeddings and necessary metadata for Retriever
21
+ """
22
+ def __init__(self, embedding_path=None, load_from_path=True, rank=None):
23
+ self.embed_data = dict()
24
+ if embedding_path is None:
25
+ args = get_args()
26
+ embedding_path = args.embedding_path
27
+ rank = args.rank
28
+ self.embedding_path = embedding_path
29
+ self.rank = rank
30
+
31
+ if load_from_path:
32
+ self.load_from_file()
33
+
34
+ block_data_name = os.path.splitext(self.embedding_path)[0]
35
+ self.temp_dir_name = block_data_name + '_tmp'
36
+
37
+ def state(self):
38
+ return {
39
+ 'embed_data': self.embed_data,
40
+ }
41
+
42
+ def clear(self):
43
+ """
44
+ Clear the embedding data structures to save memory.
45
+ The metadata ends up getting used, and is also much smaller in
46
+ dimensionality so it isn't really worth clearing.
47
+ """
48
+ self.embed_data = dict()
49
+
50
+ def load_from_file(self):
51
+ """Populate members from instance saved to file"""
52
+
53
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
54
+ print("\n> Unpickling BlockData", flush=True)
55
+ state_dict = pickle.load(open(self.embedding_path, 'rb'))
56
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
57
+ print(">> Finished unpickling BlockData\n", flush=True)
58
+
59
+ self.embed_data = state_dict['embed_data']
60
+
61
+ def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
62
+ """
63
+ Add data for set of blocks
64
+ :param row_id: 1D array of unique int ids for the blocks
65
+ :param block_embeds: 2D array of embeddings of the blocks
66
+ In the case of retriever this will be [start_idx, end_idx, doc_idx]
67
+ """
68
+ for idx, embed in zip(row_id, block_embeds):
69
+ if not allow_overwrite and idx in self.embed_data:
70
+ raise ValueError("Unexpectedly tried to overwrite block data")
71
+
72
+ self.embed_data[idx] = np.float16(embed)
73
+
74
+ def save_shard(self):
75
+ """
76
+ Save the block data that was created this in this process
77
+ """
78
+ if not os.path.isdir(self.temp_dir_name):
79
+ os.makedirs(self.temp_dir_name, exist_ok=True)
80
+
81
+ # save the data for each shard
82
+ with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
83
+ as writer:
84
+ pickle.dump(self.state(), writer)
85
+
86
+ def merge_shards_and_save(self):
87
+ # Combine all the shards made using save_shard
88
+ shard_names = os.listdir(self.temp_dir_name)
89
+ seen_own_shard = False
90
+
91
+ for fname in os.listdir(self.temp_dir_name):
92
+ shard_rank = int(os.path.splitext(fname)[0])
93
+ if shard_rank == self.rank:
94
+ seen_own_shard = True
95
+ continue
96
+
97
+ with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
98
+ data = pickle.load(f)
99
+ old_size = len(self.embed_data)
100
+ shard_size = len(data['embed_data'])
101
+
102
+ # add the shard's data and check to make sure there
103
+ # is no overlap
104
+ self.embed_data.update(data['embed_data'])
105
+ assert len(self.embed_data) == old_size + shard_size
106
+
107
+ assert seen_own_shard
108
+
109
+ # save the consolidated shards and remove temporary directory
110
+ with open(self.embedding_path, 'wb') as final_file:
111
+ pickle.dump(self.state(), final_file)
112
+ shutil.rmtree(self.temp_dir_name, ignore_errors=True)
113
+
114
+ print("Finished merging {} shards for a total of {} embeds".format(
115
+ len(shard_names), len(self.embed_data)), flush=True)
116
+
117
+
118
+ class FaissMIPSIndex(object):
119
+ """
120
+ Wrapper object for a BlockData that does similarity search via FAISS under the hood
121
+ """
122
+ def __init__(self, embed_size, embed_data=None, use_gpu=False):
123
+ self.embed_size = embed_size
124
+ self.embed_data = embed_data
125
+ self.use_gpu = use_gpu
126
+
127
+ self.mips_index = None
128
+ self._set_mips_index()
129
+
130
+ def _set_mips_index(self):
131
+ """
132
+ Create a Faiss Flat index with inner product as the metric
133
+ to search against
134
+ """
135
+ try:
136
+ import faiss
137
+ except ImportError:
138
+ raise Exception("Error: Please install faiss to use FaissMIPSIndex")
139
+
140
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
141
+ print("\n> Building index", flush=True)
142
+
143
+ cpu_index = faiss.IndexFlatIP(self.embed_size)
144
+
145
+ if self.use_gpu:
146
+ # create resources and config for GpuIndex
147
+ config = faiss.GpuMultipleClonerOptions()
148
+ config.shard = True
149
+ config.useFloat16 = True
150
+ gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
151
+ self.mips_index = faiss.IndexIDMap(gpu_index)
152
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
153
+ print(">> Initialized index on GPU", flush=True)
154
+ else:
155
+ # CPU index supports IDs so wrap with IDMap
156
+ self.mips_index = faiss.IndexIDMap(cpu_index)
157
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
158
+ print(">> Initialized index on CPU", flush=True)
159
+
160
+ # if we were constructed with a BlockData, then automatically load it
161
+ # when the FAISS structure is built
162
+ if self.embed_data is not None:
163
+ self.add_embed_data(self.embed_data)
164
+
165
+ def reset_index(self):
166
+ """Delete existing index and create a new"""
167
+ del self.mips_index
168
+
169
+ # reset the block data so that _set_block_index will reload it as well
170
+ if self.embed_data is not None:
171
+ embed_data_path = self.embed_data.embedding_path
172
+ del self.embed_data
173
+ self.embed_data = OpenRetreivalDataStore(embed_data_path)
174
+
175
+ self._set_mips_index()
176
+
177
+ def update_index(self):
178
+ """Delete existing index and create a new"""
179
+ del self.mips_index
180
+
181
+ # reset the block data so that _set_mips_index will reload it as well
182
+ if self.embed_data is not None:
183
+ self.embed_data.load_from_file()
184
+ self._set_mips_index()
185
+
186
+ def add_embed_data(self, all_embed_data):
187
+ """Add the embedding of each block to the underlying FAISS index"""
188
+
189
+ # this assumes the embed_data is a dict : {int: np.array<float>}
190
+ block_indices, block_embeds = zip(*all_embed_data.embed_data.items())
191
+
192
+ # the embeddings have to be entered in as float32 even though the math
193
+ # internally is done with float16.
194
+ embeds_arr = np.float32(np.array(block_embeds))
195
+ indices_arr = np.array(block_indices)
196
+
197
+ # we no longer need the embedding data since it's in the index now
198
+ all_embed_data.clear()
199
+
200
+ self.mips_index.add_with_ids(embeds_arr, indices_arr)
201
+
202
+ if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
203
+ print(">>> Finished adding block data to index", flush=True)
204
+
205
+ def search_mips_index(self, query_embeds, top_k, reconstruct=True):
206
+ """
207
+ Get the top-k blocks by the index distance metric.
208
+
209
+ :param reconstruct: if True: return a [num_queries x k x embed_dim]
210
+ array of blocks
211
+ if False: return [num_queries x k] array of
212
+ distances, and another for indices
213
+ """
214
+ query_embeds = np.float32(detach(query_embeds))
215
+
216
+ if reconstruct:
217
+ # get the vectors themselves
218
+ top_k_block_embeds = self.mips_index.search_and_reconstruct(\
219
+ query_embeds, top_k)
220
+ return top_k_block_embeds
221
+ else:
222
+ # get distances and indices of closest vectors
223
+ distances, block_indices = self.mips_index.search(query_embeds, top_k)
224
+ return distances, block_indices
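A rough sketch of wiring the data store and FAISS wrapper above together, assuming faiss is installed; the embedding path, dimension and random vectors are purely illustrative:

import numpy as np
import torch
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex

# In-memory store of four random 128-d block embeddings with ids 0..3.
store = OpenRetreivalDataStore(embedding_path='embeds.pkl', load_from_path=False, rank=0)
store.add_block_data(np.arange(4), np.random.randn(4, 128).astype(np.float32))

# Flat inner-product index over the store, queried with two random query embeddings.
index = FaissMIPSIndex(embed_size=128, embed_data=store, use_gpu=False)
dists, ids = index.search_mips_index(torch.randn(2, 128), top_k=2, reconstruct=False)
print(dists.shape, ids.shape)            # (2, 2) (2, 2)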
megatron/data/t5_dataset.py ADDED
@@ -0,0 +1,270 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """T5 Style dataset."""
17
+
18
+ import collections
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from megatron import get_tokenizer
24
+ from megatron.data.dataset_utils import (
25
+ create_masked_lm_predictions,
26
+ get_samples_mapping
27
+ )
28
+
29
+ class T5Dataset(torch.utils.data.Dataset):
30
+
31
+ def __init__(self, name, indexed_dataset, data_prefix,
32
+ num_epochs, max_num_samples, masked_lm_prob,
33
+ max_seq_length, max_seq_length_dec,
34
+ short_seq_prob, seed):
35
+
36
+ # Params to store.
37
+ self.name = name
38
+ self.seed = seed
39
+ self.masked_lm_prob = masked_lm_prob
40
+ self.max_seq_length = max_seq_length
41
+ self.max_seq_length_dec = max_seq_length_dec
42
+
43
+ # Dataset.
44
+ self.indexed_dataset = indexed_dataset
45
+
46
+ # Build the samples mapping.
47
+ self.samples_mapping = get_samples_mapping(self.indexed_dataset,
48
+ data_prefix,
49
+ num_epochs,
50
+ max_num_samples,
51
+ self.max_seq_length - 2, # account for added tokens
52
+ short_seq_prob,
53
+ self.seed,
54
+ self.name,
55
+ False)
56
+
57
+ # Vocab stuff.
58
+ tokenizer = get_tokenizer()
59
+ self.vocab_id_list = list(tokenizer.inv_vocab.keys())
60
+ self.vocab_id_to_token_dict = tokenizer.inv_vocab
61
+ self.cls_id = tokenizer.cls
62
+ self.sep_id = tokenizer.sep
63
+ self.mask_id = tokenizer.mask
64
+ self.pad_id = tokenizer.pad
65
+ self.bos_id = tokenizer.bos_token_id
66
+ self.eos_id = tokenizer.eos_token_id
67
+ self.sentinel_tokens = tokenizer.additional_special_tokens_ids
68
+ assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
69
+
70
+ def __len__(self):
71
+ return self.samples_mapping.shape[0]
72
+
73
+ def __getitem__(self, idx):
74
+
75
+ start_index, end_index, seq_length = self.samples_mapping[idx]
76
+ sample = []
77
+ for index in range(start_index, end_index):
78
+ sample.append(self.indexed_dataset[index])
79
+ # Note that this rng state should be numpy and not python since
80
+ # python randint is inclusive whereas the numpy one is exclusive.
81
+ np_rng = np.random.RandomState(seed=(self.seed + idx))
82
+ return build_training_sample(sample, seq_length,
83
+ self.max_seq_length, # needed for padding
84
+ self.max_seq_length_dec,
85
+ self.vocab_id_list,
86
+ self.vocab_id_to_token_dict,
87
+ self.cls_id, self.sep_id,
88
+ self.mask_id, self.pad_id,
89
+ self.masked_lm_prob, np_rng,
90
+ self.bos_id, self.eos_id,
91
+ self.sentinel_tokens)
92
+
93
+
94
+ def build_training_sample(sample, target_seq_length,
95
+ max_seq_length, max_seq_length_dec,
96
+ vocab_id_list, vocab_id_to_token_dict,
97
+ cls_id, sep_id, mask_id, pad_id,
98
+ masked_lm_prob, np_rng, bos_id=None,
99
+ eos_id=None, sentinel_tokens=None):
100
+ """Build training sample.
101
+
102
+ Arguments:
103
+ sample: A list of sentences in which each sentence is a list of token ids.
104
+ target_seq_length: Desired sequence length.
105
+ max_seq_length: Maximum length of the sequence. All values are padded to
106
+ this length.
107
+ vocab_id_list: List of vocabulary ids. Used to pick a random id.
108
+ vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
109
+ cls_id: Start of example id.
110
+ sep_id: Separator id.
111
+ mask_id: Mask token id.
112
+ pad_id: Padding token id.
113
+ masked_lm_prob: Probability to mask tokens.
114
+ np_rng: Random number generator. Note that this rng state should be
115
+ numpy and not python since python randint is inclusive for
116
+ the upper bound whereas the numpy one is exclusive.
117
+ bos_id: start of decoder example id
118
+ eos_id: end of generation id
119
+ sentinel_tokens: unique value to be substituted for every replaced span
120
+ """
121
+
122
+ assert target_seq_length <= max_seq_length
123
+
124
+ # flatten sentences into one list
125
+ tokens = [token for sentence in sample for token in sentence]
126
+
127
+ # Truncate to `target_sequence_length`.
128
+ max_num_tokens = target_seq_length
129
+ truncated = len(tokens) > max_num_tokens
130
+ tokens = tokens[:max_num_tokens]
131
+
132
+ # Masking.
133
+ max_predictions_per_seq = masked_lm_prob * max_num_tokens
134
+ (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
135
+ tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
136
+ cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
137
+ max_ngrams=10, geometric_dist=True, masking_style="t5")
138
+
139
+ # Padding.
140
+ tokens_enc, tokens_dec_in, labels, enc_mask, \
141
+ dec_mask, enc_dec_mask, loss_mask \
142
+ = pad_and_convert_to_numpy(tokens, masked_positions,
143
+ masked_labels, pad_id, max_seq_length,
144
+ max_seq_length_dec, masked_spans,
145
+ bos_id, eos_id, sentinel_tokens)
146
+
147
+ train_sample = {
148
+ 'text_enc': tokens_enc,
149
+ 'text_dec': tokens_dec_in,
150
+ 'labels': labels,
151
+ 'loss_mask': loss_mask,
152
+ 'truncated': int(truncated),
153
+ 'enc_mask': enc_mask,
154
+ 'dec_mask': dec_mask,
155
+ 'enc_dec_mask': enc_dec_mask,
156
+ }
157
+ return train_sample
158
+
159
+
160
+ def pad_and_convert_to_numpy(tokens, masked_positions,
161
+ masked_labels, pad_id,
162
+ max_seq_length, max_seq_length_dec,
163
+ masked_spans=None, bos_id=None,
164
+ eos_id=None, sentinel_tokens=None):
165
+ """Pad sequences and convert them to numpy."""
166
+
167
+ sentinel_tokens = collections.deque(sentinel_tokens)
168
+ t5_input = []
169
+ (t5_decoder_in, t5_decoder_out) = ([bos_id], [])
170
+ (start_index, end_index) = (0, None)
171
+ for span in masked_spans:
172
+ flag = sentinel_tokens.popleft()
173
+
174
+ # Append the same tokens in decoder input and output
175
+ t5_decoder_in.append(flag)
176
+ t5_decoder_in.extend(span.label)
177
+ t5_decoder_out.append(flag)
178
+ t5_decoder_out.extend(span.label)
179
+
180
+ end_index = span.index[0]
181
+ t5_input.extend(tokens[start_index: end_index])
182
+ t5_input.append(flag)
183
+
184
+ # the next start index is the token after the last span token
185
+ start_index = span.index[-1] + 1
186
+
187
+ # Add <eos> token to the t5_decoder_out
188
+ t5_decoder_out.append(eos_id)
189
+
190
+ # Add the remaining tokens to the t5 input
191
+ t5_input.extend(tokens[start_index:])
192
+
193
+ # assert (len(t5_input) - len(masked_spans)) + \
194
+ # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
195
+
196
+ # Some checks.
197
+
198
+ # Encoder-side padding mask.
199
+ num_tokens = len(t5_input)
200
+ padding_length = max_seq_length - num_tokens
201
+ assert padding_length >= 0
202
+ assert len(masked_positions) == len(masked_labels)
203
+
204
+ # Tokens.
205
+ filler = [pad_id] * padding_length
206
+ tokens_enc = np.array(t5_input + filler, dtype=np.int64)
207
+
208
+ # Decoder-side padding mask.
209
+ num_tokens_dec = len(t5_decoder_in)
210
+ padding_length_dec = max_seq_length_dec - num_tokens_dec
211
+ assert padding_length_dec >= 0
212
+ filler_dec = [pad_id] * padding_length_dec
213
+ tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)
214
+
215
+ # Create attention masks
216
+ enc_mask = make_attention_mask(tokens_enc, tokens_enc)
217
+ enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
218
+ dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
219
+ dec_mask = dec_mask * make_history_mask(tokens_dec_in)
220
+
221
+ # Labels mask.
222
+ labels = t5_decoder_out + ([-1] * padding_length_dec)
223
+ labels = np.array(labels, dtype=np.int64)
224
+
225
+ # Loss mask
226
+ loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
227
+ loss_mask = np.array(loss_mask, dtype=np.int64)
228
+
229
+ return tokens_enc, tokens_dec_in, labels, enc_mask, \
230
+ dec_mask, enc_dec_mask, loss_mask
231
+
232
+
233
+ def make_attention_mask(source_block, target_block):
234
+ """
235
+ Returns a 2-dimensional (2-D) attention mask
236
+ :param source_block: 1-D array
237
+ :param target_block: 1-D array
238
+ """
239
+ mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
240
+ mask = mask.astype(np.int64)
241
+ # (source_length, target_length)
242
+ return mask
243
+
244
+
245
+ def make_attention_mask_3d(source_block, target_block):
246
+ """
247
+ Returns a 3-dimensional (3-D) attention mask
248
+ :param source_block: 2-D array (batch, source_length)
249
+ :param target_block: 2-D array (batch, target_length)
250
+ """
251
+ mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
252
+ # (batch, source_length, target_length)
253
+ # mask = mask.astype(np.int64)
254
+ return mask
255
+
256
+
257
+ def make_history_mask(block):
258
+ length = block.shape[0]
259
+ arange = np.arange(length)
260
+ history_mask = (arange[None, ] <= arange[:, None])
261
+ history_mask = history_mask.astype(np.int64)
262
+ return history_mask
263
+
264
+
265
+ def make_history_mask_3d(block):
266
+ batch, length = block.shape
267
+ arange = torch.arange(length, device=block.device)
268
+ history_mask = (arange[None, ] <= arange[:, None])[None, ]
269
+ history_mask = history_mask.expand(batch, length, length)
270
+ return history_mask
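To make the mask construction above concrete, here is a minimal standalone numpy sketch (the token ids are made-up illustrations, not real vocab ids) of how `pad_and_convert_to_numpy` combines the padding mask from `make_attention_mask` with the causal mask from `make_history_mask` to form the decoder self-attention mask:

```python
import numpy as np

# Toy decoder input: <bos>, a sentinel, two label tokens, then padding (pad_id == 0).
tokens_dec_in = np.array([101, 32099, 7, 8, 0, 0], dtype=np.int64)

# Padding mask: (i, j) is 1 only when both positions hold non-padding tokens (id >= 1).
pad_mask = (tokens_dec_in[None, :] >= 1) * (tokens_dec_in[:, None] >= 1)

# Causal (history) mask: (i, j) is 1 only when j <= i.
arange = np.arange(len(tokens_dec_in))
causal = (arange[None, :] <= arange[:, None])

dec_mask = (pad_mask * causal).astype(np.int64)
print(dec_mask)
# Row 3 attends to positions 0-3; rows and columns 4-5 are zeroed out by the padding mask.
```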
megatron/data/test/test_indexed_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ # This file isn't really a formal automated test, it's just a place to
2
+ # put some code used during development and manual testing of
3
+ # indexed_dataset.
4
+
5
+ from megatron.data import indexed_dataset
6
+ from megatron.tokenizer import build_tokenizer
7
+ import argparse
8
+ import os
9
+ import sys
10
+
11
+ import torch
12
+
13
+ script_dir = os.path.dirname(os.path.realpath(__file__))
14
+ sys.path.append(os.path.join(script_dir, "../../../"))
15
+
16
+
17
+ def test_indexed_dataset(args):
18
+ ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
19
+ tokenizer = build_tokenizer(args)
20
+ print(len(ds.doc_idx))
21
+ print(len(ds))
22
+ print(ds.doc_idx[-1])
23
+ if ds.supports_prefetch:
24
+ # just prefetch the whole thing in test (so assume it is small)
25
+ ds.prefetch(range(len(ds)))
26
+ if args.count > len(ds.doc_idx) - 1:
27
+ args.count = len(ds.doc_idx) - 1
28
+
29
+ for i in range(args.count):
30
+ start = ds.doc_idx[i]
31
+ end = ds.doc_idx[i + 1]
32
+ ids = ds[start:end]
33
+ print(f"Document {i}:")
34
+ print("--------------")
35
+ for s in ids:
36
+ assert len(s) > 0
37
+ l = s.data.tolist()
38
+ text = tokenizer.detokenize(l)
39
+ print(text)
40
+ print("---")
41
+
42
+
43
+ def test_indexed_dataset_get(args):
44
+ ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
45
+ tokenizer = build_tokenizer(args)
46
+ size = ds.sizes[0]
47
+ print(f"size: {size}")
48
+ full = ds.get(0)
49
+ print(full)
50
+ # print(tokenizer.detokenize(full.data.tolist()))
51
+ print("---")
52
+ end = ds.get(0, offset=size - 10)
53
+ print(end)
54
+ # print(tokenizer.detokenize(end.data.tolist()))
55
+
56
+ start = ds.get(0, length=10)
57
+ print(start)
58
+ # print(tokenizer.detokenize(start.data.tolist()))
59
+
60
+ part = ds.get(0, offset=2, length=8)
61
+ print(part)
62
+ # print(tokenizer.detokenize(part.data.tolist()))
63
+
64
+ # def test_albert_dataset(args):
65
+ # # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
66
+ # # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
67
+ # # ds = AlbertDataset(idataset, tokenizer)
68
+ # ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
69
+ # args.epochs, args.max_num_samples,
70
+ # args.masked_lm_prob, args.seq_length,
71
+ # args.short_seq_prob, args.seed)
72
+ # truncated = 0
73
+ # total = 0
74
+ # for i, s in enumerate(ds):
75
+ # ids = s['text']
76
+ # tokens = ds.tokenizer.convert_ids_to_tokens(ids)
77
+ # print(tokens)
78
+ # if i >= args.count-1:
79
+ # exit()
80
+
81
+
82
+ def main():
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument('--data', type=str, help='prefix to data files')
85
+ parser.add_argument('--dataset-impl', type=str, default='infer',
86
+ choices=['lazy', 'cached', 'mmap', 'infer'])
87
+ parser.add_argument('--count', type=int, default=10,
88
+ help='Number of samples/documents to print')
89
+
90
+ group = parser.add_argument_group(title='tokenizer')
91
+ group.add_argument('--tokenizer-type', type=str, required=True,
92
+ choices=['BertWordPieceLowerCase',
93
+ 'GPT2BPETokenizer'],
94
+ help='What type of tokenizer to use.')
95
+ group.add_argument('--vocab-file', type=str, default=None,
96
+ help='Path to the vocab file')
97
+ group.add_argument('--merge-file', type=str, default=None,
98
+ help='Path to the BPE merge file (if necessary).')
99
+
100
+ parser.add_argument('--epochs', type=int, default=5,
101
+ help='Number of epochs to plan for')
102
+ parser.add_argument('--max-num-samples', type=int, default=None,
103
+ help='Maximum number of samples to plan for')
104
+ parser.add_argument('--masked-lm-prob', type=float, default=0.15,
105
+ help='probability of masking tokens')
106
+ parser.add_argument('--seq-length', type=int, default=512,
107
+ help='maximum sequence length')
108
+ parser.add_argument('--short-seq-prob', type=float, default=0.1,
109
+ help='probability of creating a short sequence')
110
+ parser.add_argument('--seed', type=int, default=1234,
111
+ help='random seed')
112
+ args = parser.parse_args()
113
+ args.rank = 0
114
+ args.make_vocab_size_divisible_by = 128
115
+ args.tensor_model_parallel_size = 1
116
+
117
+ if args.dataset_impl == "infer":
118
+ args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
119
+
120
+ # test_albert_dataset(args)
121
+ test_indexed_dataset_get(args)
122
+
123
+
124
+ if __name__ == "__main__":
125
+ main()
megatron/data/test/test_preprocess_data.sh ADDED
@@ -0,0 +1,10 @@
1
+ #!/bin/bash
2
+
3
+ IMPL=cached
4
+ python ../preprocess_data.py \
5
+ --input test_samples.json \
6
+ --vocab vocab.txt \
7
+ --dataset-impl ${IMPL} \
8
+ --output-prefix test_samples_${IMPL} \
9
+ --workers 1 \
10
+ --log-interval 2
megatron/data/vit_dataset.py ADDED
@@ -0,0 +1,262 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ import random
17
+ import numpy as np
18
+ import torch
19
+ import torchvision.transforms as T
20
+ from torchvision import datasets
21
+ from megatron import get_args
22
+ from megatron.data.image_folder import ImageFolder
23
+ from megatron.data.autoaugment import ImageNetPolicy
24
+ from megatron.data.data_samplers import RandomSeedDataset
25
+ from PIL import Image, ImageFilter, ImageOps
26
+
27
+
28
+ class GaussianBlur(object):
29
+ """
30
+ Apply Gaussian Blur to the PIL image.
31
+ """
32
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
33
+ self.prob = p
34
+ self.radius_min = radius_min
35
+ self.radius_max = radius_max
36
+
37
+ def __call__(self, img):
38
+ do_it = random.random() <= self.prob
39
+ if not do_it:
40
+ return img
41
+
42
+ return img.filter(
43
+ ImageFilter.GaussianBlur(
44
+ radius=random.uniform(self.radius_min, self.radius_max)
45
+ )
46
+ )
47
+
48
+
49
+ class Solarization(object):
50
+ """
51
+ Apply Solarization to the PIL image.
52
+ """
53
+ def __init__(self, p):
54
+ self.p = p
55
+
56
+ def __call__(self, img):
57
+ if random.random() < self.p:
58
+ return ImageOps.solarize(img)
59
+ else:
60
+ return img
61
+
62
+
63
+ class ClassificationTransform():
64
+ def __init__(self, image_size, train=True):
65
+ args = get_args()
66
+ assert args.fp16 or args.bf16
67
+ self.data_type = torch.half if args.fp16 else torch.bfloat16
68
+ if train:
69
+ self.transform = T.Compose([
70
+ T.RandomResizedCrop(image_size),
71
+ T.RandomHorizontalFlip(),
72
+ T.ColorJitter(0.4, 0.4, 0.4, 0.1),
73
+ ImageNetPolicy(),
74
+ T.ToTensor(),
75
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
76
+ T.ConvertImageDtype(self.data_type)
77
+ ])
78
+ else:
79
+ self.transform = T.Compose([
80
+ T.Resize(image_size),
81
+ T.CenterCrop(image_size),
82
+ T.ToTensor(),
83
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
84
+ T.ConvertImageDtype(self.data_type)
85
+ ])
86
+
87
+ def __call__(self, input):
88
+ output = self.transform(input)
89
+ return output
90
+
91
+
92
+ class InpaintingTransform():
93
+ def __init__(self, image_size, train=True):
94
+
95
+ args = get_args()
96
+ self.mask_factor = args.mask_factor
97
+ self.mask_type = args.mask_type
98
+ self.image_size = image_size
99
+ self.patch_size = args.patch_dim
100
+ self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size))
101
+ self.train = train
102
+ assert args.fp16 or args.bf16
103
+ self.data_type = torch.half if args.fp16 else torch.bfloat16
104
+
105
+ if self.train:
106
+ self.transform = T.Compose([
107
+ T.RandomResizedCrop(self.image_size),
108
+ T.RandomHorizontalFlip(),
109
+ T.ColorJitter(0.4, 0.4, 0.4, 0.1),
110
+ ImageNetPolicy(),
111
+ T.ToTensor(),
112
+ T.ConvertImageDtype(self.data_type)
113
+ ])
114
+ else:
115
+ self.transform = T.Compose([
116
+ T.Resize(self.image_size, interpolation=2),
117
+ T.CenterCrop(self.image_size),
118
+ T.ToTensor(),
119
+ T.ConvertImageDtype(self.data_type)
120
+ ])
121
+
122
+ def gen_mask(self, image_size, mask_size, mask_type, patch_size):
123
+ # output: mask as a list with indices for missing patches
124
+ action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
125
+ assert image_size[0] == image_size[1]
126
+ img_size_patch = image_size[0] // patch_size
127
+
128
+ # drop masked patches
129
+ mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)
130
+
131
+ if mask_type == 'random':
132
+ x = torch.randint(0, img_size_patch, ())
133
+ y = torch.randint(0, img_size_patch, ())
134
+ for i in range(mask_size):
135
+ r = torch.randint(0, len(action_list), ())
136
+ x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
137
+ y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
138
+ x_offset = x * patch_size
139
+ y_offset = y * patch_size
140
+ mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
141
+ else:
142
+ assert mask_type == 'row'
143
+ count = 0
144
+ for x in reversed(range(img_size_patch)):
145
+ for y in reversed(range(img_size_patch)):
146
+ if (count < mask_size):
147
+ count += 1
148
+ x_offset = x * patch_size
149
+ y_offset = y * patch_size
150
+ mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
151
+ return mask
152
+
153
+ def __call__(self, input):
154
+ trans_input = self.transform(input)
155
+ mask = self.gen_mask(self.image_size, self.mask_size,
156
+ self.mask_type, self.patch_size)
157
+ mask = mask.unsqueeze(dim=0)
158
+ return trans_input, mask
159
+
160
+
161
+ class DinoTransform(object):
162
+ def __init__(self, image_size, train=True):
163
+ args = get_args()
164
+ self.data_type = torch.half if args.fp16 else torch.bfloat16
165
+
166
+ flip_and_color_jitter = T.Compose([
167
+ T.RandomHorizontalFlip(p=0.5),
168
+ T.RandomApply(
169
+ [T.ColorJitter(brightness=0.4, contrast=0.4,
170
+ saturation=0.2, hue=0.1)],
171
+ p=0.8
172
+ ),
173
+ T.RandomGrayscale(p=0.2),
174
+ ])
175
+
176
+ if args.fp16 or args.bf16:
177
+ normalize = T.Compose([
178
+ T.ToTensor(),
179
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
180
+ T.ConvertImageDtype(self.data_type)
181
+ ])
182
+ else:
183
+ normalize = T.Compose([
184
+ T.ToTensor(),
185
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
186
+ ])
187
+
188
+ # first global crop
189
+ scale_const = 0.4
190
+ self.global_transform1 = T.Compose([
191
+ T.RandomResizedCrop(image_size,
192
+ scale=(scale_const, 1),
193
+ interpolation=Image.BICUBIC),
194
+ flip_and_color_jitter,
195
+ GaussianBlur(1.0),
196
+ normalize
197
+ ])
198
+ # second global crop
199
+ self.global_transform2 = T.Compose([
200
+ T.RandomResizedCrop(image_size,
201
+ scale=(scale_const, 1),
202
+ interpolation=Image.BICUBIC),
203
+ flip_and_color_jitter,
204
+ GaussianBlur(0.1),
205
+ Solarization(0.2),
206
+ normalize
207
+ ])
208
+ # transformation for the local small crops
209
+ self.local_crops_number = args.dino_local_crops_number
210
+ self.local_transform = T.Compose([
211
+ T.RandomResizedCrop(args.dino_local_img_size,
212
+ scale=(0.05, scale_const),
213
+ interpolation=Image.BICUBIC),
214
+ flip_and_color_jitter,
215
+ GaussianBlur(p=0.5),
216
+ normalize
217
+ ])
218
+
219
+ def __call__(self, image):
220
+ crops = []
221
+ crops.append(self.global_transform1(image))
222
+ crops.append(self.global_transform2(image))
223
+ for _ in range(self.local_crops_number):
224
+ crops.append(self.local_transform(image))
225
+ return crops
226
+
227
+
228
+ def build_train_valid_datasets(data_path, image_size=224):
229
+ args = get_args()
230
+
231
+ if args.vision_pretraining_type == 'classify':
232
+ train_transform = ClassificationTransform(image_size)
233
+ val_transform = ClassificationTransform(image_size, train=False)
234
+ elif args.vision_pretraining_type == 'inpaint':
235
+ train_transform = InpaintingTransform(image_size, train=False)
236
+ val_transform = InpaintingTransform(image_size, train=False)
237
+ elif args.vision_pretraining_type == 'dino':
238
+ train_transform = DinoTransform(image_size, train=True)
239
+ val_transform = ClassificationTransform(image_size, train=False)
240
+ else:
241
+ raise Exception('{} vit pretraining type is not supported.'.format(
242
+ args.vision_pretraining_type))
243
+
244
+ # training dataset
245
+ train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
246
+ train_data = ImageFolder(
247
+ root=train_data_path,
248
+ transform=train_transform,
249
+ classes_fraction=args.classes_fraction,
250
+ data_per_class_fraction=args.data_per_class_fraction
251
+ )
252
+ train_data = RandomSeedDataset(train_data)
253
+
254
+ # validation dataset
255
+ val_data_path = data_path[1]
256
+ val_data = ImageFolder(
257
+ root=val_data_path,
258
+ transform=val_transform
259
+ )
260
+ val_data = RandomSeedDataset(val_data)
261
+
262
+ return train_data, val_data
megatron/dist_signal_handler.py ADDED
@@ -0,0 +1,81 @@
1
+ import signal
2
+
3
+ import torch
4
+
5
+
6
+ def get_world_size():
7
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
8
+ world_size = torch.distributed.get_world_size()
9
+ else:
10
+ world_size = 1
11
+ return world_size
12
+
13
+
14
+ def get_device(local_rank=None):
15
+ backend = torch.distributed.get_backend()
16
+ if backend == 'nccl':
17
+ if local_rank is None:
18
+ device = torch.device('cuda')
19
+ else:
20
+ device = torch.device(f'cuda:{local_rank}')
21
+ elif backend == 'gloo':
22
+ device = torch.device('cpu')
23
+ else:
24
+ raise RuntimeError
25
+ return device
26
+
27
+
28
+ def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
29
+ if not torch.distributed.is_available() or \
30
+ not torch.distributed.is_initialized():
31
+ return [item]
32
+
33
+ device = get_device(local_rank)
34
+
35
+ if group is not None:
36
+ group_size = group.size()
37
+ else:
38
+ group_size = get_world_size()
39
+
40
+ tensor = torch.tensor([item], device=device, dtype=dtype)
41
+ output_tensors = [
42
+ torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
43
+ for _ in range(group_size)
44
+ ]
45
+ torch.distributed.all_gather(output_tensors, tensor, group, async_op)
46
+ output = [elem.item() for elem in output_tensors]
47
+ return output
48
+
49
+
50
+ class DistributedSignalHandler:
51
+ def __init__(self, sig=signal.SIGTERM):
52
+ self.sig = sig
53
+
54
+ def signals_received(self):
55
+ all_received = all_gather_item(
56
+ self._signal_received, dtype=torch.int32
57
+ )
58
+ return all_received
59
+
60
+ def __enter__(self):
61
+ self._signal_received = False
62
+ self.released = False
63
+ self.original_handler = signal.getsignal(self.sig)
64
+
65
+ def handler(signum, frame):
66
+ self._signal_received = True
67
+
68
+ signal.signal(self.sig, handler)
69
+
70
+ return self
71
+
72
+ def __exit__(self, type, value, tb):
73
+ self.release()
74
+
75
+ def release(self):
76
+ if self.released:
77
+ return False
78
+
79
+ signal.signal(self.sig, self.original_handler)
80
+ self.released = True
81
+ return True
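A hedged usage sketch of the handler above in a training loop; `num_steps`, `train_step`, and `save_checkpoint` are illustrative placeholders, not part of this module:

```python
from megatron.dist_signal_handler import DistributedSignalHandler

with DistributedSignalHandler() as handler:      # installs a SIGTERM handler on entry
    for step in range(num_steps):
        train_step()                              # placeholder for one optimizer step
        # Each rank reports whether it has received the signal; acting on any()
        # makes the decision collective, so every rank exits at the same step.
        if any(handler.signals_received()):
            save_checkpoint()                     # placeholder
            break
```

On exit, the original signal handler is restored by `release()`.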
megatron/fp16_deprecated/loss_scaler.py ADDED
@@ -0,0 +1,39 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """For backward compatibility, we keep these class definitions so that old checkpoints can still be deserialized."""
17
+
18
+ class LossScaler:
19
+ def __init__(self, scale=1):
20
+ self.cur_scale = scale
21
+
22
+ class DynamicLossScaler:
23
+ def __init__(self,
24
+ init_scale=2**32,
25
+ scale_factor=2.,
26
+ scale_window=1000,
27
+ min_scale=1,
28
+ delayed_shift=1,
29
+ consecutive_hysteresis=False):
30
+ self.cur_scale = init_scale
31
+ self.cur_iter = 0
32
+ self.last_overflow_iter = -1
33
+ self.scale_factor = scale_factor
34
+ self.scale_window = scale_window
35
+ self.min_scale = min_scale
36
+ self.delayed_shift = delayed_shift
37
+ self.cur_hysteresis = delayed_shift
38
+ self.consecutive_hysteresis = consecutive_hysteresis
39
+
megatron/fused_kernels/__init__.py ADDED
@@ -0,0 +1,125 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import pathlib
18
+ import subprocess
19
+
20
+ from torch.utils import cpp_extension
21
+
22
+ # Setting this param to a list has a problem of generating different
23
+ # compilation commands (with different order of architectures) and
24
+ # leading to recompilation of fused kernels. Set it to empty string
25
+ # to avoid recompilation and assign arch flags explicitly in
26
+ # extra_cuda_cflags below
27
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
28
+
29
+
30
+ def load(args):
31
+
32
+ # Check if cuda 11 is installed for compute capability 8.0
33
+ cc_flag = []
34
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(
35
+ cpp_extension.CUDA_HOME)
36
+ if int(bare_metal_major) >= 11:
37
+ cc_flag.append('-gencode')
38
+ cc_flag.append('arch=compute_80,code=sm_80')
39
+
40
+ # Build path
41
+ srcpath = pathlib.Path(__file__).parent.absolute()
42
+ buildpath = srcpath / 'build'
43
+ _create_build_dir(buildpath)
44
+
45
+ # Helper function to build the kernels.
46
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
47
+ return cpp_extension.load(
48
+ name=name,
49
+ sources=sources,
50
+ build_directory=buildpath,
51
+ extra_cflags=['-O3',],
52
+ extra_cuda_cflags=['-O3',
53
+ '-gencode', 'arch=compute_70,code=sm_70',
54
+ '--use_fast_math'] + extra_cuda_flags + cc_flag,
55
+ verbose=(args.rank == 0)
56
+ )
57
+
58
+ # ==============
59
+ # Fused softmax.
60
+ # ==============
61
+
62
+ if args.masked_softmax_fusion:
63
+ extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
64
+ '-U__CUDA_NO_HALF_CONVERSIONS__',
65
+ '--expt-relaxed-constexpr',
66
+ '--expt-extended-lambda']
67
+
68
+ # Upper triangular softmax.
69
+ sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
70
+ srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
71
+ scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
72
+ "scaled_upper_triang_masked_softmax_cuda",
73
+ sources, extra_cuda_flags)
74
+
75
+ # Masked softmax.
76
+ sources=[srcpath / 'scaled_masked_softmax.cpp',
77
+ srcpath / 'scaled_masked_softmax_cuda.cu']
78
+ scaled_masked_softmax_cuda = _cpp_extention_load_helper(
79
+ "scaled_masked_softmax_cuda", sources, extra_cuda_flags)
80
+
81
+ # Softmax
82
+ sources=[srcpath / 'scaled_softmax.cpp',
83
+ srcpath / 'scaled_softmax_cuda.cu']
84
+ scaled_softmax_cuda = _cpp_extention_load_helper(
85
+ "scaled_softmax_cuda", sources, extra_cuda_flags)
86
+
87
+ # =================================
88
+ # Mixed precision fused layer norm.
89
+ # =================================
90
+
91
+ extra_cuda_flags = ['-maxrregcount=50']
92
+ sources=[srcpath / 'layer_norm_cuda.cpp',
93
+ srcpath / 'layer_norm_cuda_kernel.cu']
94
+ fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
95
+ "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
96
+
97
+ # =================================
98
+ # Fused gradient accumulation to weight gradient computation of linear layer
99
+ # =================================
100
+
101
+ if args.gradient_accumulation_fusion:
102
+ sources=[srcpath / 'fused_weight_gradient_dense.cpp',
103
+ srcpath / 'fused_weight_gradient_dense.cu']
104
+ fused_dense_cuda = _cpp_extention_load_helper(
105
+ "fused_dense_cuda", sources, [])
106
+
107
+
108
+ def _get_cuda_bare_metal_version(cuda_dir):
109
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
110
+ universal_newlines=True)
111
+ output = raw_output.split()
112
+ release_idx = output.index("release") + 1
113
+ release = output[release_idx].split(".")
114
+ bare_metal_major = release[0]
115
+ bare_metal_minor = release[1][0]
116
+
117
+ return raw_output, bare_metal_major, bare_metal_minor
118
+
119
+
120
+ def _create_build_dir(buildpath):
121
+ try:
122
+ os.mkdir(buildpath)
123
+ except OSError:
124
+ if not os.path.isdir(buildpath):
125
+ print(f"Creation of the build directory {buildpath} failed")
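A minimal sketch of driving the JIT build above outside the normal Megatron startup path, assuming a CUDA toolkit and a matching PyTorch are installed; the `SimpleNamespace` fields mirror the attributes `load()` actually reads:

```python
from types import SimpleNamespace
from megatron import fused_kernels

args = SimpleNamespace(
    rank=0,                             # only rank 0 gets verbose ninja output
    masked_softmax_fusion=True,         # build the three fused softmax kernels
    gradient_accumulation_fusion=False, # skip the fused wgrad GEMM extension
)
fused_kernels.load(args)  # compiles the extensions into megatron/fused_kernels/build/
```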
megatron/fused_kernels/build/.ninja_deps ADDED
Binary file (128 kB). View file
 
megatron/fused_kernels/build/.ninja_log ADDED
@@ -0,0 +1,99 @@
1
+ # ninja log v5
2
+ 0 10410 1666322917541195629 scaled_upper_triang_masked_softmax.o 3f7d6908014e2c4b
3
+ 0 44996 1666322952117194853 scaled_upper_triang_masked_softmax_cuda.cuda.o 63e21f5b42e40bd0
4
+ 44999 45365 1666322952493194844 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5
5
+ 1 10116 1666322962741194614 scaled_masked_softmax.o 9dfad674b44501f2
6
+ 1 45749 1666322998365193814 scaled_masked_softmax_cuda.cuda.o aec0e3e7e0fe0af5
7
+ 45749 46118 1666322998741193805 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5
8
+ 0 10316 1666323009189193571 scaled_softmax.o ba78446f9188fae0
9
+ 0 44538 1666323043401192802 scaled_softmax_cuda.cuda.o 77ea362b8652c7e9
10
+ 44538 44904 1666323043777192794 scaled_softmax_cuda.so d8fa0ebfd78e8bd8
11
+ 0 10918 1666323054829192545 layer_norm_cuda.o 7dc5869ac5593422
12
+ 0 11891 1666323055797192524 layer_norm_cuda_kernel.cuda.o 13d0d213fbbb62de
13
+ 11891 12255 1666323056165192515 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401
14
+ 0 10072 1666682710301113263 scaled_upper_triang_masked_softmax.o c11d897e7800befb
15
+ 0 46206 1666682746425112452 scaled_upper_triang_masked_softmax_cuda.cuda.o bc610e36d8dfd435
16
+ 46206 46587 1666682746813112443 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5
17
+ 0 9858 1666682756829112218 scaled_masked_softmax.o fedebad209ed2d21
18
+ 0 46362 1666682793321111399 scaled_masked_softmax_cuda.cuda.o 51814239e7caea9a
19
+ 46362 46747 1666682793717111390 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5
20
+ 0 9870 1666682803741111164 scaled_softmax.o 1d9e3231fe352c0b
21
+ 0 46512 1666682840373110342 scaled_softmax_cuda.cuda.o f9b5a976cff0a5ef
22
+ 46513 46900 1666682840769110333 scaled_softmax_cuda.so d8fa0ebfd78e8bd8
23
+ 0 10615 1666682851533110091 layer_norm_cuda.o 3cacb26d8faa2b99
24
+ 0 11849 1666682852761110063 layer_norm_cuda_kernel.cuda.o 319de99ce0920143
25
+ 11849 12230 1666682853145110055 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401
26
+ 0 12428 1666750089507718000 scaled_upper_triang_masked_softmax.o 8e61e453c7b77ff5
27
+ 0 46226 1666750123300534000 scaled_upper_triang_masked_softmax_cuda.cuda.o 193ac2a539f3f292
28
+ 46228 47144 1666750124224556000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
29
+ 0 11486 1666750135832836000 scaled_masked_softmax.o b6378099b518f069
30
+ 0 47221 1666750171561699000 scaled_masked_softmax_cuda.cuda.o 4faef3fa30fe1e1d
31
+ 47223 48124 1666750172469721000 scaled_masked_softmax_cuda.so d6611febaa933d3d
32
+ 0 11564 1666750184410010000 scaled_softmax.o a90db6c821074406
33
+ 0 46461 1666750219302852000 scaled_softmax_cuda.cuda.o bf0ec8bfec64157c
34
+ 46464 47488 1666750220334877000 scaled_softmax_cuda.so e7199387ed26e64e
35
+ 0 12007 1666750232439170000 layer_norm_cuda.o 6a55ca87d1a8c0b2
36
+ 0 12020 1666750232447170000 layer_norm_cuda_kernel.cuda.o 655c31ba3cbc10c2
37
+ 12022 13866 1666750234299214000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
38
+ 0 10929 1666856346511840836 scaled_upper_triang_masked_softmax.o 5a6ff0631bbc2735
39
+ 0 43626 1666856379199840197 scaled_upper_triang_masked_softmax_cuda.cuda.o c36b22f9d9fc2117
40
+ 43626 44026 1666856379607840189 scaled_upper_triang_masked_softmax_cuda.so f84b117fa81963e5
41
+ 0 10190 1666856389955839987 scaled_masked_softmax.o e032e1ed28e01e30
42
+ 0 44582 1666856424331839315 scaled_masked_softmax_cuda.cuda.o 2c8db7df38489475
43
+ 44582 44961 1666856424723839308 scaled_masked_softmax_cuda.so a24b9ad6a01f9db5
44
+ 0 10142 1666856435015839107 scaled_softmax.o 446947b66b18fa33
45
+ 0 44480 1666856469343838436 scaled_softmax_cuda.cuda.o 4fb07733497ecc29
46
+ 44480 44879 1666856469751838428 scaled_softmax_cuda.so d8fa0ebfd78e8bd8
47
+ 0 10396 1666856480295838222 layer_norm_cuda.o 73da6101a07a24a7
48
+ 1 11899 1666856481791838193 layer_norm_cuda_kernel.cuda.o 9ec8eab79e592ff4
49
+ 11899 12298 1666856482195838185 fused_mix_prec_layer_norm_cuda.so a21986e1b00b3401
50
+ 1 12100 1666925285098117232 scaled_upper_triang_masked_softmax.o f73d4cea858af8b1
51
+ 2 45529 1666925318557711559 scaled_upper_triang_masked_softmax_cuda.cuda.o 2fa8c20456ca471c
52
+ 45531 47926 1666925320961971387 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
53
+ 1 12395 1666925333323307273 scaled_masked_softmax.o ae4878f941317da4
54
+ 1 46831 1666925368027057696 scaled_masked_softmax_cuda.cuda.o 94fda8fac85d606d
55
+ 46833 48525 1666925369731241867 scaled_masked_softmax_cuda.so d6611febaa933d3d
56
+ 1 12276 1666925382212590722 scaled_softmax.o 7a00c61166684714
57
+ 1 47263 1666925417188370543 scaled_softmax_cuda.cuda.o dfe840df0dc3178d
58
+ 47265 50987 1666925420916773471 scaled_softmax_cuda.so e7199387ed26e64e
59
+ 1 13036 1666925434150203603 layer_norm_cuda_kernel.cuda.o 128560bba544b6cb
60
+ 1 14561 1666925435354333732 layer_norm_cuda.o e02c8859a84e70db
61
+ 14569 15455 1666925436574465592 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
62
+ 1 62343 1666962134003529000 scaled_upper_triang_masked_softmax_cuda.cuda.o abba0fca57f22344
63
+ 62363 63833 1666962135511529000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
64
+ 0 12215 1667460191194232000 scaled_upper_triang_masked_softmax.o 96d879ae2bf7b993
65
+ 1 49211 1667460228190231000 scaled_upper_triang_masked_softmax_cuda.cuda.o bc4b370b3c3d5c9e
66
+ 49213 50896 1667460229886231000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
67
+ 0 13297 1667460243334231000 scaled_masked_softmax.o a6809b1177c7ca02
68
+ 1 50055 1667460280082230000 scaled_masked_softmax_cuda.cuda.o dfbe364852f092fc
69
+ 50057 65422 1667460295430230000 scaled_masked_softmax_cuda.so d6611febaa933d3d
70
+ 1 12055 1667460307682230000 scaled_softmax.o cd4f40829964c3cb
71
+ 1 48489 1667460344126229000 scaled_softmax_cuda.cuda.o a6917d3b3ea80f97
72
+ 48526 49856 1667460345502229000 scaled_softmax_cuda.so e7199387ed26e64e
73
+ 0 13966 1667460359626229000 layer_norm_cuda.o e644ccb47b3615c
74
+ 1 15506 1667460361158229000 layer_norm_cuda_kernel.cuda.o 2d32cb24bea852c7
75
+ 15509 40152 1667460385810228000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
76
+ 0 11006 1676876089228927885 scaled_upper_triang_masked_softmax.o 9c8f7b7399ab2d1f
77
+ 0 42047 1676876120260456898 scaled_upper_triang_masked_softmax_cuda.cuda.o 12331cadf47cb899
78
+ 42047 42264 1676876120488482829 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
79
+ 1 10823 1676876131453729835 scaled_masked_softmax.o 6e7c1e4df8bc11a8
80
+ 1 47902 1676876168521945360 scaled_masked_softmax_cuda.cuda.o c86371ad5d3e19ff
81
+ 47902 48087 1676876168713967197 scaled_masked_softmax_cuda.so d6611febaa933d3d
82
+ 0 13926 1676876182787567696 scaled_softmax.o eeaf8300d7ba52f3
83
+ 0 42044 1676876210894764143 scaled_softmax_cuda.cuda.o 379540e7c9ee343a
84
+ 42044 42242 1676876211098787345 scaled_softmax_cuda.so e7199387ed26e64e
85
+ 0 11125 1676876222348066652 layer_norm_cuda_kernel.cuda.o 5fa2a5b112be408c
86
+ 0 11375 1676876222600095314 layer_norm_cuda.o e369f8c0d20bc213
87
+ 11375 11572 1676876222800118061 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
88
+ 2 12764 1682404121987049000 scaled_upper_triang_masked_softmax.o 20d2ceb970d5c522
89
+ 3 48533 1682404157745832000 scaled_upper_triang_masked_softmax_cuda.cuda.o e3bb3890927e826b
90
+ 48592 49557 1682404158781913000 scaled_upper_triang_masked_softmax_cuda.so 1faf6b7eefee1ef4
91
+ 2 13484 1682404172586987000 scaled_masked_softmax.o 8b2e2a7ca1fd841b
92
+ 2 48141 1682404207237685000 scaled_masked_softmax_cuda.cuda.o 6215ee08e9bfd383
93
+ 48143 49151 1682404208257765000 scaled_masked_softmax_cuda.so d6611febaa933d3d
94
+ 2 12243 1682404220754738000 scaled_softmax.o 2842472d594d0d1d
95
+ 2 51180 1682404259661769000 scaled_softmax_cuda.cuda.o 78419d022eea9ad2
96
+ 51184 52124 1682404260637845000 scaled_softmax_cuda.so e7199387ed26e64e
97
+ 2 12762 1682404273698863000 layer_norm_cuda_kernel.cuda.o dec411c038eb6254
98
+ 1 13361 1682404274238905000 layer_norm_cuda.o 14db0d087e6f7321
99
+ 13415 14196 1682404275138975000 fused_mix_prec_layer_norm_cuda.so 55e2400ed170f0da
megatron/fused_kernels/build/build.ninja ADDED
@@ -0,0 +1,28 @@
1
+ ninja_required_version = 1.3
2
+ cxx = c++
3
+ nvcc = /usr/local/cuda/bin/nvcc
4
+
5
+ cflags = -DTORCH_EXTENSION_NAME=fused_mix_prec_layer_norm_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++14 -O3
6
+ post_cflags =
7
+ cuda_cflags = -DTORCH_EXTENSION_NAME=fused_mix_prec_layer_norm_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -maxrregcount=50 -gencode arch=compute_80,code=sm_80 -std=c++14
8
+ cuda_post_cflags =
9
+ ldflags = -shared -L/opt/conda/lib/python3.8/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart
10
+
11
+ rule compile
12
+ command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
13
+ depfile = $out.d
14
+ deps = gcc
15
+
16
+ rule cuda_compile
17
+ command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
18
+
19
+ rule link
20
+ command = $cxx $in $ldflags -o $out
21
+
22
+ build layer_norm_cuda.o: compile /root/ouyangxuan/project/big_model_finetune/Megatrion-LM-clear/megatron/fused_kernels/layer_norm_cuda.cpp
23
+ build layer_norm_cuda_kernel.cuda.o: cuda_compile /root/ouyangxuan/project/big_model_finetune/Megatrion-LM-clear/megatron/fused_kernels/layer_norm_cuda_kernel.cu
24
+
25
+ build fused_mix_prec_layer_norm_cuda.so: link layer_norm_cuda.o layer_norm_cuda_kernel.cuda.o
26
+
27
+ default fused_mix_prec_layer_norm_cuda.so
28
+
megatron/fused_kernels/build/fused_mix_prec_layer_norm_cuda.so ADDED
Binary file (700 kB). View file
 
megatron/fused_kernels/build/layer_norm_cuda.o ADDED
Binary file (293 kB). View file
 
megatron/fused_kernels/build/layer_norm_cuda_kernel.cuda.o ADDED
Binary file (545 kB). View file
 
megatron/fused_kernels/build/scaled_masked_softmax.o ADDED
Binary file (239 kB). View file
 
megatron/fused_kernels/build/scaled_masked_softmax_cuda.cuda.o ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23117ca78a4427a192781d5bff9b2ddc34711ec3c57f6bd5c7c4b7d3b634e429
3
+ size 1196624
megatron/fused_kernels/build/scaled_masked_softmax_cuda.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:409dbcba44e37f1e791e8ce1cb82f4a6884c6eb68aa281311a675462e91762ea
3
+ size 1283032
megatron/fused_kernels/build/scaled_softmax.o ADDED
Binary file (229 kB). View file
 
megatron/fused_kernels/build/scaled_softmax_cuda.cuda.o ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c2e292b61e4060b6a2c263bd95dcc08c4ab9c29ae4cf74df98bbd2ad4b566ee
3
+ size 1084024
megatron/fused_kernels/build/scaled_softmax_cuda.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c0b73baf4b0a10ccf124c4e20aabefc122e2f8b0b0887dccfb7eafe3cd5e39c
3
+ size 1170600
megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o ADDED
Binary file (230 kB). View file
 
megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.cuda.o ADDED
Binary file (944 kB). View file
 
megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e1b57adfdbb4254303f89ad4c8f786f9f7b4516c8fa2e95339ae5177a69e4a5
3
+ size 1032720
megatron/fused_kernels/compat.h ADDED
@@ -0,0 +1,31 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /*This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+
22
+
23
+ #ifndef TORCH_CHECK
24
+ #define TORCH_CHECK AT_CHECK
25
+ #endif
26
+
27
+ #ifdef VERSION_GE_1_3
28
+ #define DATA_PTR data_ptr
29
+ #else
30
+ #define DATA_PTR data
31
+ #endif
megatron/fused_kernels/fused_weight_gradient_dense.cpp ADDED
@@ -0,0 +1,47 @@
1
+ #include <torch/torch.h>
2
+ #include <torch/extension.h>
3
+
4
+ #include <vector>
5
+ #include <stdio.h>
6
+
7
+ #include "type_shim.h"
8
+
9
+
10
+ template <typename T>
11
+ int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
12
+
13
+ void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
14
+ at::Tensor input_2d, d_output_2d;
15
+ // input tensor: collapse to the first dim
16
+ auto in_sizes = input.sizes();
17
+ if (input.dim() > 2) {
18
+ input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
19
+ } else {
20
+ input_2d = input;
21
+ }
22
+ // d_output tensor: collapse to the first dim
23
+ auto d_out_sizes = d_output.sizes();
24
+ if (d_output.dim() > 2) {
25
+ d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
26
+ } else {
27
+ d_output_2d = d_output;
28
+ }
29
+
30
+ int hidden_dim = input_2d.size(0);
31
+ int in_dim = input_2d.size(1);
32
+ int out_dim = d_weight.size(0);
33
+
34
+ DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
35
+ int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
36
+ input_2d.data_ptr<scalar_t>(),
37
+ d_output_2d.data_ptr<scalar_t>(),
38
+ d_weight.data_ptr<float>(),
39
+ in_dim,
40
+ hidden_dim,
41
+ out_dim);
42
+ );
43
+ }
44
+
45
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
46
+ m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
47
+ }
megatron/fused_kernels/fused_weight_gradient_dense.cu ADDED
@@ -0,0 +1,157 @@
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <assert.h>
4
+ #include <stdio.h>
5
+ #include <stdlib.h>
6
+ #include <string.h>
7
+ #include <torch/torch.h>
8
+
9
+ /* Includes, cuda */
10
+ #include <cublas_v2.h>
11
+ #include <cuda_runtime.h>
12
+
13
+
14
+ // BF16 Tensor core wrapper around cublas GEMMEx
15
+ cublasStatus_t gemmex_wrapper(
16
+ cublasHandle_t handle,
17
+ cublasOperation_t transa,
18
+ cublasOperation_t transb,
19
+ int m,
20
+ int n,
21
+ int k,
22
+ const float* alpha,
23
+ at::BFloat16* A,
24
+ int lda,
25
+ at::BFloat16* B,
26
+ int ldb,
27
+ const float* beta,
28
+ float* C,
29
+ int ldc) {
30
+ return cublasGemmEx(
31
+ handle,
32
+ transa,
33
+ transb,
34
+ m,
35
+ n,
36
+ k,
37
+ alpha,
38
+ A,
39
+ CUDA_R_16BF,
40
+ lda,
41
+ B,
42
+ CUDA_R_16BF,
43
+ ldb,
44
+ beta,
45
+ C,
46
+ CUDA_R_32F,
47
+ ldc,
48
+ CUDA_R_32F,
49
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP);
50
+ }
51
+
52
+ // FP16 Tensor core wrapper around cublas GEMMEx
53
+ cublasStatus_t gemmex_wrapper(
54
+ cublasHandle_t handle,
55
+ cublasOperation_t transa,
56
+ cublasOperation_t transb,
57
+ int m,
58
+ int n,
59
+ int k,
60
+ const float* alpha,
61
+ at::Half* A,
62
+ int lda,
63
+ at::Half* B,
64
+ int ldb,
65
+ const float* beta,
66
+ float* C,
67
+ int ldc) {
68
+ return cublasGemmEx(
69
+ handle,
70
+ transa,
71
+ transb,
72
+ m,
73
+ n,
74
+ k,
75
+ alpha,
76
+ A,
77
+ CUDA_R_16F,
78
+ lda,
79
+ B,
80
+ CUDA_R_16F,
81
+ ldb,
82
+ beta,
83
+ C,
84
+ CUDA_R_32F,
85
+ ldc,
86
+ CUDA_R_32F,
87
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP);
88
+ }
89
+
90
+ // FP32 Tensor core wrapper around cublas GEMMEx
91
+ cublasStatus_t gemmex_wrapper(
92
+ cublasHandle_t handle,
93
+ cublasOperation_t transa,
94
+ cublasOperation_t transb,
95
+ int m,
96
+ int n,
97
+ int k,
98
+ const float* alpha,
99
+ float* A,
100
+ int lda,
101
+ float* B,
102
+ int ldb,
103
+ const float* beta,
104
+ float* C,
105
+ int ldc) {
106
+ return cublasGemmEx(
107
+ handle,
108
+ transa,
109
+ transb,
110
+ m,
111
+ n,
112
+ k,
113
+ alpha,
114
+ A,
115
+ CUDA_R_32F,
116
+ lda,
117
+ B,
118
+ CUDA_R_32F,
119
+ ldb,
120
+ beta,
121
+ C,
122
+ CUDA_R_32F,
123
+ ldc,
124
+ CUDA_R_32F,
125
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP);
126
+ }
127
+
128
+ template <typename T>
129
+ int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
130
+ cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
131
+ cudaStream_t stream;
132
+ cublasGetStream(handle, &stream);
133
+ const float alpha = 1.0;
134
+ const float beta = 1.0;
135
+ int status = 1;
136
+
137
+ status = gemmex_wrapper(
138
+ handle,
139
+ CUBLAS_OP_N,
140
+ CUBLAS_OP_T,
141
+ in_dim,
142
+ out_dim,
143
+ hidden_dim,
144
+ &alpha,
145
+ input,
146
+ in_dim,
147
+ d_output,
148
+ out_dim,
149
+ &beta,
150
+ d_weight,
151
+ in_dim);
152
+ return status;
153
+ }
154
+
155
+ template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
156
+ template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
157
+ template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
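For reference, a small PyTorch sketch (illustrative sizes, not the CUDA path) of the accumulation the cuBLAS call above performs: because `beta == 1`, the GemmEx accumulates `d_output^T @ input` into the fp32 weight-gradient buffer in place, even when the activations are fp16/bf16:

```python
import torch

hidden, in_dim, out_dim = 8, 4, 6                              # illustrative sizes
inp      = torch.randn(hidden, in_dim, dtype=torch.float16)    # collapsed-to-2D input
d_output = torch.randn(hidden, out_dim, dtype=torch.float16)   # collapsed-to-2D grad
d_weight = torch.zeros(out_dim, in_dim, dtype=torch.float32)   # fp32 main gradient

# Equivalent of wgrad_gemm_accum_fp32: accumulate in fp32 without a separate cast buffer.
d_weight += d_output.t().float() @ inp.float()
```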
megatron/fused_kernels/layer_norm_cuda.cpp ADDED
@@ -0,0 +1,201 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /*This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+ #include <torch/extension.h>
22
+ #include <vector>
23
+ #include <cassert>
24
+ #include "compat.h"
25
+
26
+ namespace {
27
+
28
+ void compute_n1_n2(
29
+ at::Tensor input,
30
+ at::IntArrayRef normalized_shape,
31
+ int& n1,
32
+ int& n2) {
33
+ int idiff = input.ndimension() - normalized_shape.size();
34
+ n2 = 1;
35
+ for (int i = 0; i < (int)normalized_shape.size(); ++i) {
36
+ assert( input.sizes()[i+idiff] == normalized_shape[i] );
37
+ n2 *= normalized_shape[i];
38
+ }
39
+ n1 = 1;
40
+ for (int i = 0; i < idiff; ++i) {
41
+ n1 *= input.sizes()[i];
42
+ }
43
+ }
44
+
45
+ void check_args(
46
+ at::IntArrayRef normalized_shape,
47
+ at::Tensor gamma,
48
+ at::Tensor beta
49
+ )
50
+ {
51
+ TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
52
+ TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
53
+ }
54
+
55
+ void check_args(
56
+ at::Tensor input,
57
+ at::IntArrayRef normalized_shape,
58
+ int& n1,
59
+ int& n2
60
+ )
61
+ {
62
+ int64_t normalized_ndim = normalized_shape.size();
63
+
64
+ if (normalized_ndim < 1) {
65
+ std::stringstream ss;
66
+ ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
67
+ << "containing at least one element, but got normalized_shape="
68
+ << normalized_shape;
69
+ throw std::runtime_error(ss.str());
70
+ }
71
+
72
+ auto input_shape = input.sizes();
73
+ auto input_ndim = input.dim();
74
+
75
+ if (input_ndim < normalized_ndim ||
76
+ !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
77
+ std::stringstream ss;
78
+ ss << "Given normalized_shape=" << normalized_shape
79
+ << ", expected input with shape [*";
80
+ for (auto size : normalized_shape) {
81
+ ss << ", " << size;
82
+ }
83
+ ss << "], but got input of size" << input_shape;
84
+ throw std::runtime_error(ss.str());
85
+ }
86
+
87
+ compute_n1_n2(input,normalized_shape,n1,n2);
88
+ }
89
+
90
+
91
+ void check_args(
92
+ at::Tensor input,
93
+ at::IntArrayRef normalized_shape,
94
+ at::Tensor gamma,
95
+ at::Tensor beta,
96
+ int& n1,
97
+ int& n2
98
+ )
99
+ {
100
+ check_args(input,normalized_shape,n1,n2);
101
+ check_args(normalized_shape,gamma,beta);
102
+ }
103
+ }
104
+
105
+ void cuda_layer_norm(
106
+ at::Tensor* output,
107
+ at::Tensor* mean,
108
+ at::Tensor* invvar,
109
+ at::Tensor* input,
110
+ int n1,
111
+ int n2,
112
+ at::IntArrayRef normalized_shape,
113
+ at::Tensor* gamma,
114
+ at::Tensor* beta,
115
+ double epsilon);
116
+
117
+ #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
118
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
119
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
120
+
121
+ std::vector<at::Tensor> layer_norm_affine(
122
+ at::Tensor input,
123
+ at::IntArrayRef normalized_shape,
124
+ at::Tensor gamma,
125
+ at::Tensor beta,
126
+ double epsilon) {
127
+
128
+ CHECK_INPUT(input);
129
+ CHECK_INPUT(gamma);
130
+ CHECK_INPUT(beta);
131
+ int n1, n2;
132
+ check_args(input, normalized_shape, gamma, beta, n1, n2);
133
+
134
+ at::Tensor output = at::empty_like(
135
+ input, gamma.options().dtype(gamma.scalar_type()));
136
+ at::Tensor mean = at::empty(
137
+ {n1}, input.options().dtype(at::ScalarType::Float));
138
+ at::Tensor invvar = at::empty_like(mean);
139
+
140
+ cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
141
+ normalized_shape, &gamma, &beta, epsilon);
142
+
143
+ return {output, mean, invvar};
144
+
145
+ }
146
+
147
+
148
+ void cuda_layer_norm_gradient(
149
+ at::Tensor* dout,
150
+ at::Tensor* mean,
151
+ at::Tensor* invvar,
152
+ at::Tensor* input,
153
+ int n1,
154
+ int n2,
155
+ at::IntArrayRef normalized_shape,
156
+ at::Tensor* gamma,
157
+ at::Tensor* beta,
158
+ double epsilon,
159
+ at::Tensor* grad_input,
160
+ at::Tensor* grad_gamma,
161
+ at::Tensor* grad_beta
162
+ );
163
+
164
+ std::vector<at::Tensor> layer_norm_gradient_affine(
165
+ at::Tensor dout,
166
+ at::Tensor mean,
167
+ at::Tensor invvar,
168
+ at::Tensor input,
169
+ at::IntArrayRef normalized_shape,
170
+ at::Tensor gamma,
171
+ at::Tensor beta,
172
+ double epsilon) {
173
+
174
+ CHECK_INPUT(dout);
175
+ CHECK_INPUT(mean);
176
+ CHECK_INPUT(invvar);
177
+ CHECK_INPUT(input);
178
+ CHECK_INPUT(gamma);
179
+ CHECK_INPUT(beta);
180
+ int n1, n2;
181
+ check_args(input, normalized_shape, gamma, beta, n1, n2);
182
+
183
+ at::Tensor grad_input = at::empty_like(input);
184
+ at::Tensor grad_gamma = at::empty_like(gamma);
185
+ at::Tensor grad_beta = at::empty_like(beta);
186
+
187
+ cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
188
+ normalized_shape, &gamma, &beta, epsilon,
189
+ &grad_input, &grad_gamma, &grad_beta);
190
+
191
+ return {grad_input, grad_gamma, grad_beta};
192
+
193
+ }
194
+
195
+
196
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
197
+ m.def("forward_affine", &layer_norm_affine,
198
+ "LayerNorm forward (CUDA)");
199
+ m.def("backward_affine", &layer_norm_gradient_affine,
200
+ "LayerNorm backward (CUDA)");
201
+ }
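A hedged sketch of exercising the two bindings above once the extension is built, assuming the compiled `fused_mix_prec_layer_norm_cuda` module is importable (Megatron normally obtains it through the fused_kernels loader); the result should match `torch.nn.functional.layer_norm` up to numerical tolerance:

```python
import torch
import torch.nn.functional as F
import fused_mix_prec_layer_norm_cuda as ln   # extension name from the JIT build

hidden = 16
x     = torch.randn(4, hidden, device='cuda')
gamma = torch.ones(hidden, device='cuda')
beta  = torch.zeros(hidden, device='cuda')

out, mean, invvar = ln.forward_affine(x, [hidden], gamma, beta, 1e-5)
ref = F.layer_norm(x, [hidden], gamma, beta, 1e-5)
print(torch.allclose(out, ref, atol=1e-5))
```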
megatron/fused_kernels/layer_norm_cuda_kernel.cu ADDED
@@ -0,0 +1,832 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /*This code is copied fron NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+ #include "ATen/ATen.h"
22
+ #include "ATen/AccumulateType.h"
23
+ #include "ATen/cuda/CUDAContext.h"
24
+ #include "ATen/cuda/DeviceUtils.cuh"
25
+
26
+ #include <cuda.h>
27
+ #include <cuda_runtime.h>
28
+
29
+ #include "type_shim.h"
30
+
31
+ template<typename U> __device__
32
+ void cuWelfordOnlineSum(
33
+ const U curr,
34
+ U& mu,
35
+ U& sigma2,
36
+ U& count)
37
+ {
38
+ count = count + U(1);
39
+ U delta = curr - mu;
40
+ U lmean = mu + delta / count;
41
+ mu = lmean;
42
+ U delta2 = curr - lmean;
43
+ sigma2 = sigma2 + delta * delta2;
44
+ }
45
+
46
+ template<typename U> __device__
47
+ void cuChanOnlineSum(
48
+ const U muB,
49
+ const U sigma2B,
50
+ const U countB,
51
+ U& mu,
52
+ U& sigma2,
53
+ U& count)
54
+ {
55
+ U delta = muB - mu;
56
+ U nA = count;
57
+ U nB = countB;
58
+ count = count + countB;
59
+ U nX = count;
60
+ if (nX > U(0)) {
61
+ nA = nA / nX;
62
+ nB = nB / nX;
63
+ mu = nA*mu + nB*muB;
64
+ sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
65
+ } else {
66
+ mu = U(0);
67
+ sigma2 = U(0);
68
+ }
69
+ }
70
+
71
+ template<typename T, typename U> __device__
72
+ void cuWelfordMuSigma2(
73
+ const T* __restrict__ vals,
74
+ const int n1,
75
+ const int n2,
76
+ const int i1,
77
+ U& mu,
78
+ U& sigma2,
79
+ U* buf)
80
+ {
81
+ // Assumptions:
82
+ // 1) blockDim.x == warpSize
83
+ // 2) Tensor is contiguous
84
+ // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
85
+ //
86
+ // compute variance and mean over n2
87
+ U count = U(0);
88
+ mu= U(0);
89
+ sigma2 = U(0);
90
+ if (i1 < n1) {
91
+ // one warp normalizes one n1 index,
92
+ // synchronization is implicit
93
+ // initialize with standard Welford algorithm
94
+ const int numx = blockDim.x * blockDim.y;
95
+ const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
96
+ const T* lvals = vals + i1*n2;
97
+ int l = 4*thrx;
98
+ for (; l+3 < n2; l+=4*numx) {
99
+ for (int k = 0; k < 4; ++k) {
100
+ U curr = static_cast<U>(lvals[l+k]);
101
+ cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
102
+ }
103
+ }
104
+ for (; l < n2; ++l) {
105
+ U curr = static_cast<U>(lvals[l]);
106
+ cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
107
+ }
108
+ // intra-warp reductions
109
+ for (int l = 0; l <= 4; ++l) {
110
+ int srcLaneB = (threadIdx.x+(1<<l))&31;
111
+ U muB = WARP_SHFL(mu, srcLaneB);
112
+ U countB = WARP_SHFL(count, srcLaneB);
113
+ U sigma2B = WARP_SHFL(sigma2, srcLaneB);
114
+ cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
115
+ }
116
+ // threadIdx.x == 0 has correct values for each warp
117
+ // inter-warp reductions
118
+ if (blockDim.y > 1) {
119
+ U* ubuf = (U*)buf;
120
+ U* ibuf = (U*)(ubuf + blockDim.y);
121
+ for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
122
+ // upper half of warps write to shared
123
+ if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
124
+ const int wrt_y = threadIdx.y - offset;
125
+ ubuf[2*wrt_y] = mu;
126
+ ubuf[2*wrt_y+1] = sigma2;
127
+ ibuf[wrt_y] = count;
128
+ }
129
+ __syncthreads();
130
+ // lower half merges
131
+ if (threadIdx.x == 0 && threadIdx.y < offset) {
132
+ U muB = ubuf[2*threadIdx.y];
133
+ U sigma2B = ubuf[2*threadIdx.y+1];
134
+ U countB = ibuf[threadIdx.y];
135
+ cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
136
+ }
137
+ __syncthreads();
138
+ }
139
+ // threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
140
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
141
+ ubuf[0] = mu;
142
+ ubuf[1] = sigma2;
143
+ }
144
+ __syncthreads();
145
+ mu = ubuf[0];
146
+ sigma2 = ubuf[1]/U(n2);
147
+ // don't care about final value of count, we know count == n2
148
+ } else {
149
+ mu = WARP_SHFL(mu, 0);
150
+ sigma2 = WARP_SHFL(sigma2/U(n2), 0);
151
+ }
152
+ }
153
+ }
154
+
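
The intra-warp loop above is a five-step ring exchange: at step l each lane merges with lane (threadIdx.x + 2^l) mod 32, so after log2(32) = 5 steps every lane carries the statistics of the whole warp. WARP_SHFL is defined in type_shim.h; the standalone sketch below uses __shfl_sync directly and a plain sum instead of the Welford merge, only to illustrate the exchange pattern:

#include <cstdio>
#include <cuda_runtime.h>

// Each lane starts with its lane id; after five ring-exchange steps every lane
// holds the warp-wide sum 0 + 1 + ... + 31 = 496.
__global__ void warp_ring_sum(int* out) {
    int v = threadIdx.x;
    for (int l = 0; l <= 4; ++l) {
        int srcLane = (threadIdx.x + (1 << l)) & 31;   // same indexing as above
        v += __shfl_sync(0xffffffffu, v, srcLane);
    }
    if (threadIdx.x == 0) *out = v;
}

int main() {
    int* d_out = nullptr;
    int h_out = 0;
    cudaMalloc(&d_out, sizeof(int));
    warp_ring_sum<<<1, 32>>>(d_out);
    cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("warp sum = %d (expect 496)\n", h_out);
    cudaFree(d_out);
    return 0;
}
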
155
+ template<> __device__
156
+ void cuWelfordMuSigma2(
157
+ const at::Half* __restrict__ vals,
158
+ const int n1,
159
+ const int n2,
160
+ const int i1,
161
+ float& mu,
162
+ float& sigma2,
163
+ float* buf)
164
+ {
165
+ // Assumptions:
166
+ // 1) blockDim.x == warpSize
167
+ // 2) Tensor is contiguous
168
+ // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
169
+ //
170
+ // compute variance and mean over n2
171
+ float count = 0.0f;
172
+ mu= float(0);
173
+ sigma2 = float(0);
174
+ if (i1 < n1) {
175
+ // one warp normalizes one n1 index,
176
+ // synchronization is implicit
177
+ // initialize with standard Welford algorithm
178
+ const int numx = blockDim.x * blockDim.y;
179
+ const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
180
+ const at::Half* lvals = vals + i1*n2;
181
+ int l = 8*thrx;
182
+ if ((((size_t)lvals)&3) != 0) {
183
+ // lvals is not 32-bit aligned here (only 16-bit)
184
+ // first thread consumes first point
185
+ if (thrx == 0) {
186
+ float curr = static_cast<float>(lvals[0]);
187
+ cuWelfordOnlineSum(curr,mu,sigma2,count);
188
+ }
189
+ ++l;
190
+ }
191
+ // at this point, lvals[l] are 32 bit aligned for all threads.
192
+ for (; l+7 < n2; l+=8*numx) {
193
+ for (int k = 0; k < 8; k+=2) {
194
+ float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
195
+ cuWelfordOnlineSum(curr.x,mu,sigma2,count);
196
+ cuWelfordOnlineSum(curr.y,mu,sigma2,count);
197
+ }
198
+ }
199
+ for (; l < n2; ++l) {
200
+ float curr = static_cast<float>(lvals[l]);
201
+ cuWelfordOnlineSum(curr,mu,sigma2,count);
202
+ }
203
+ // intra-warp reductions
204
+ for (int l = 0; l <= 4; ++l) {
205
+ int srcLaneB = (threadIdx.x+(1<<l))&31;
206
+ float muB = WARP_SHFL(mu, srcLaneB);
207
+ float countB = WARP_SHFL(count, srcLaneB);
208
+ float sigma2B = WARP_SHFL(sigma2, srcLaneB);
209
+ cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
210
+ }
211
+ // threadIdx.x == 0 has correct values for each warp
212
+ // inter-warp reductions
213
+ if (blockDim.y > 1) {
214
+ float* ubuf = (float*)buf;
215
+ float* ibuf = (float*)(ubuf + blockDim.y);
216
+ for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
217
+ // upper half of warps write to shared
218
+ if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
219
+ const int wrt_y = threadIdx.y - offset;
220
+ ubuf[2*wrt_y] = mu;
221
+ ubuf[2*wrt_y+1] = sigma2;
222
+ ibuf[wrt_y] = count;
223
+ }
224
+ __syncthreads();
225
+ // lower half merges
226
+ if (threadIdx.x == 0 && threadIdx.y < offset) {
227
+ float muB = ubuf[2*threadIdx.y];
228
+ float sigma2B = ubuf[2*threadIdx.y+1];
229
+ float countB = ibuf[threadIdx.y];
230
+ cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
231
+ }
232
+ __syncthreads();
233
+ }
234
+ // threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
235
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
236
+ ubuf[0] = mu;
237
+ ubuf[1] = sigma2;
238
+ }
239
+ __syncthreads();
240
+ mu = ubuf[0];
241
+ sigma2 = ubuf[1]/float(n2);
242
+ // don't care about final value of count, we know count == n2
243
+ } else {
244
+ mu = WARP_SHFL(mu, 0);
245
+ sigma2 = WARP_SHFL(sigma2/float(n2), 0);
246
+ }
247
+ }
248
+ }
249
+
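
The at::Half specialization peels off one leading element when the pointer is not 32-bit aligned, then widens two packed fp16 values per load with __half22float2, halving the number of global loads. A minimal standalone sketch of that widening load, illustrative only and assuming an fp16-capable GPU:

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Sums an even-length, 4-byte-aligned fp16 array in float, two values per load,
// mirroring the vectorized inner loop of the half-precision kernel above.
__global__ void sum_half2(const __half* vals, int n, float* out) {
    float acc = 0.f;
    for (int i = 0; i < n; i += 2) {
        float2 v = __half22float2(*reinterpret_cast<const __half2*>(vals + i));
        acc += v.x + v.y;
    }
    *out = acc;   // single-thread kernel, for illustration only
}

int main() {
    const int n = 8;
    __half h_vals[n];
    for (int i = 0; i < n; ++i) h_vals[i] = __float2half(float(i));   // 0..7, sum 28
    __half* d_vals = nullptr;
    float* d_out = nullptr;
    float h_out = 0.f;
    cudaMalloc(&d_vals, n * sizeof(__half));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_vals, h_vals, n * sizeof(__half), cudaMemcpyHostToDevice);
    sum_half2<<<1, 1>>>(d_vals, n, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %.1f (expect 28.0)\n", h_out);
    cudaFree(d_vals);
    cudaFree(d_out);
    return 0;
}
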
250
+ template<typename U> U rsqrt(U v) {
251
+ return U(1) / sqrt(v);
252
+ }
253
+ template<> float rsqrt(float v) {
254
+ return rsqrtf(v);
255
+ }
256
+ template<> double rsqrt(double v) {
257
+ return rsqrt(v);
258
+ }
259
+
260
+ namespace {
261
+ // This is the un-specialized struct. Note that we prevent instantiation of this
262
+ // struct by putting an undefined symbol in the function body so it won't compile.
263
+ // template <typename T>
264
+ // struct SharedMemory
265
+ // {
266
+ // // Ensure that we won't compile any un-specialized types
267
+ // __device__ T *getPointer()
268
+ // {
269
+ // extern __device__ void error(void);
270
+ // error();
271
+ // return NULL;
272
+ // }
273
+ // };
274
+ // https://github.com/NVIDIA/apex/issues/246
275
+ template <typename T>
276
+ struct SharedMemory;
277
+
278
+ template <>
279
+ struct SharedMemory <float>
280
+ {
281
+ __device__ float *getPointer()
282
+ {
283
+ extern __shared__ float s_float[];
284
+ return s_float;
285
+ }
286
+ };
287
+
288
+ }
289
+
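
Only a float specialization of SharedMemory is provided because the accumulation type U is float in every dispatch below; the primary template is declared but never defined (per the linked apex issue) so that any other type fails at compile time instead of silently aliasing the same extern __shared__ symbol. If a double accumulation type were ever needed, the analogous specialization would look like the hypothetical sketch below, which is not part of the source:

#include <cuda_runtime.h>

// Hypothetical: the SharedMemory pattern extended to double. The primary template
// stays undefined so unsupported element types do not compile.
template <typename T> struct SharedMemoryDemo;

template <>
struct SharedMemoryDemo<double>
{
    __device__ double *getPointer()
    {
        extern __shared__ double s_double[];   // distinct symbol per specialization
        return s_double;
    }
};

// Example kernel: buf aliases the dynamically sized shared-memory allocation.
__global__ void fill_ones(double* out, int n) {
    SharedMemoryDemo<double> shared;
    double* buf = shared.getPointer();
    for (int i = threadIdx.x; i < n; i += blockDim.x) buf[i] = 1.0;
    __syncthreads();
    for (int i = threadIdx.x; i < n; i += blockDim.x) out[i] = buf[i];
}

int main() {
    const int n = 4;
    double* d_out = nullptr;
    double h_out[n];
    cudaMalloc(&d_out, n * sizeof(double));
    fill_ones<<<1, 32, n * sizeof(double)>>>(d_out, n);   // third arg: dynamic shared bytes
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    // h_out is now {1, 1, 1, 1}
    cudaFree(d_out);
    return 0;
}
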
290
+ template<typename T, typename U, typename V> __global__
291
+ void cuApplyLayerNorm(
292
+ V* __restrict__ output_vals,
293
+ U* __restrict__ mean,
294
+ U* __restrict__ invvar,
295
+ const T* __restrict__ vals,
296
+ const int n1,
297
+ const int n2,
298
+ const U epsilon,
299
+ const V* __restrict__ gamma,
300
+ const V* __restrict__ beta
301
+ )
302
+ {
303
+ // Assumptions:
304
+ // 1) blockDim.x == warpSize
305
+ // 2) Tensors are contiguous
306
+ //
307
+ for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
308
+ SharedMemory<U> shared;
309
+ U* buf = shared.getPointer();
310
+ U mu,sigma2;
311
+ cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
312
+ const T* lvals = vals + i1*n2;
313
+ V* ovals = output_vals + i1*n2;
314
+ U c_invvar = rsqrt(sigma2 + epsilon);
315
+ const int numx = blockDim.x * blockDim.y;
316
+ const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
317
+ if (gamma != NULL && beta != NULL) {
318
+ for (int i = thrx; i < n2; i+=numx) {
319
+ U curr = static_cast<U>(lvals[i]);
320
+ ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
321
+ }
322
+ } else {
323
+ for (int i = thrx; i < n2; i+=numx) {
324
+ U curr = static_cast<U>(lvals[i]);
325
+ ovals[i] = static_cast<V>(c_invvar * (curr - mu));
326
+ }
327
+ }
328
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
329
+ mean[i1] = mu;
330
+ invvar[i1] = c_invvar;
331
+ }
332
+ __syncthreads();
333
+ }
334
+ }
335
+
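
cuApplyLayerNorm assigns rows of length n2 to block-rows: the warp team first obtains mu and sigma2 through the Welford/Chan reduction above, then writes gamma * (x - mu) * rsqrt(sigma2 + epsilon) + beta and stores mean and invvar for the backward pass. A scalar host reference of the same math, handy for checking the kernel on small inputs (illustrative C++ only):

#include <cmath>
#include <cstdio>
#include <vector>

// Reference layer norm over a single row of length n2; gamma/beta may be empty.
void layer_norm_row(const std::vector<float>& x, const std::vector<float>& gamma,
                    const std::vector<float>& beta, float eps,
                    std::vector<float>& y, float& mean, float& invvar) {
    const int n2 = (int)x.size();
    double mu = 0.0, var = 0.0;
    for (float v : x) mu += v;
    mu /= n2;
    for (float v : x) var += (v - mu) * (v - mu);
    var /= n2;                                   // biased variance, as in the kernel
    mean = (float)mu;
    invvar = 1.0f / std::sqrt((float)var + eps);
    y.resize(n2);
    for (int i = 0; i < n2; ++i) {
        float norm = (x[i] - mean) * invvar;
        y[i] = gamma.empty() ? norm : gamma[i] * norm + beta[i];
    }
}

int main() {
    std::vector<float> x = {1.f, 2.f, 3.f, 4.f}, y, g(4, 1.f), b(4, 0.f);
    float mean = 0.f, invvar = 0.f;
    layer_norm_row(x, g, b, 1e-5f, y, mean, invvar);
    printf("mean=%.3f invvar=%.3f y=[%.3f %.3f %.3f %.3f]\n",
           mean, invvar, y[0], y[1], y[2], y[3]);
    return 0;
}
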
336
+ template<typename T, typename U, typename V> __device__
337
+ void cuLoadWriteStridedInputs(
338
+ const int i1_block,
339
+ const int thr_load_row_off,
340
+ const int thr_load_col_off,
341
+ const int i2_off,
342
+ const int row_stride,
343
+ U* warp_buf1,
344
+ U* warp_buf2,
345
+ const T* input,
346
+ const V* dout,
347
+ const int i1_end,
348
+ const int n2,
349
+ const U* __restrict__ mean,
350
+ const U* __restrict__ invvar
351
+ )
352
+ {
353
+ int i1 = i1_block+thr_load_row_off;
354
+ if (i1 < i1_end) {
355
+ U curr_mean = mean[i1];
356
+ U curr_invvar = invvar[i1];
357
+ for (int k = 0; k < blockDim.y; ++k) {
358
+ int i2 = i2_off + k;
359
+ int load_idx = i1*n2+i2;
360
+ int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
361
+ if (i2<n2) {
362
+ U curr_input = static_cast<U>(input[load_idx]);
363
+ U curr_dout = static_cast<U>(dout[load_idx]);
364
+ warp_buf1[write_idx] = curr_dout;
365
+ warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
366
+ } else {
367
+ warp_buf1[write_idx] = U(0);
368
+ warp_buf2[write_idx] = U(0);
369
+ }
370
+ }
371
+ } else {
372
+ for (int k = 0; k < blockDim.y; ++k) {
373
+ int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
374
+ warp_buf1[write_idx] = U(0);
375
+ warp_buf2[write_idx] = U(0);
376
+ }
377
+ }
378
+ }
379
+
380
+ template<typename T, typename U, typename V> __device__
381
+ void cuLoadAddStridedInputs(
382
+ const int i1_block,
383
+ const int thr_load_row_off,
384
+ const int thr_load_col_off,
385
+ const int i2_off,
386
+ const int row_stride,
387
+ U* warp_buf1,
388
+ U* warp_buf2,
389
+ const T* input,
390
+ const V* dout,
391
+ const int i1_end,
392
+ const int n2,
393
+ const U* __restrict__ mean,
394
+ const U* __restrict__ invvar
395
+ )
396
+ {
397
+ int i1 = i1_block+thr_load_row_off;
398
+ if (i1 < i1_end) {
399
+ U curr_mean = mean[i1];
400
+ U curr_invvar = invvar[i1];
401
+ for (int k = 0; k < blockDim.y; ++k) {
402
+ int i2 = i2_off + k;
403
+ int load_idx = i1*n2+i2;
404
+ int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
405
+ if (i2<n2) {
406
+ U curr_input = static_cast<U>(input[load_idx]);
407
+ U curr_dout = static_cast<U>(dout[load_idx]);
408
+ warp_buf1[write_idx] += curr_dout;
409
+ warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
410
+ }
411
+ }
412
+ }
413
+ }
414
+
415
+ template<typename T, typename U, typename V> __global__
416
+ void cuComputePartGradGammaBeta(
417
+ const V* __restrict__ dout,
418
+ const T* __restrict__ input,
419
+ const int n1,
420
+ const int n2,
421
+ const U* __restrict__ mean,
422
+ const U* __restrict__ invvar,
423
+ U epsilon,
424
+ U* part_grad_gamma,
425
+ U* part_grad_beta)
426
+ {
427
+ const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
428
+ const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
429
+ const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
430
+ const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
431
+ const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
432
+ const int row_stride = blockDim.x+1;
433
+ const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
434
+ const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
435
+ const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
436
+ SharedMemory<U> shared;
437
+ U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
438
+ U* warp_buf1 = (U*)buf;
439
+ U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
440
+ // compute partial sums from strided inputs
441
+ // do this to increase number of loads in flight
442
+ cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
443
+ for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
444
+ cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
445
+ }
446
+ __syncthreads();
447
+ // inter-warp reductions
448
+ // sum within each warp
449
+ U acc1 = U(0);
450
+ U acc2 = U(0);
451
+ for (int k = 0; k < blockDim.y; ++k) {
452
+ int row1 = threadIdx.y + k*blockDim.y;
453
+ int idx1 = row1*row_stride + threadIdx.x;
454
+ acc1 += warp_buf1[idx1];
455
+ acc2 += warp_buf2[idx1];
456
+ }
457
+ warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
458
+ warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
459
+ __syncthreads();
460
+ // sum all warps
461
+ for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
462
+ if (threadIdx.y < offset) {
463
+ int row1 = threadIdx.y;
464
+ int row2 = threadIdx.y + offset;
465
+ int idx1 = row1*row_stride + threadIdx.x;
466
+ int idx2 = row2*row_stride + threadIdx.x;
467
+ warp_buf1[idx1] += warp_buf1[idx2];
468
+ warp_buf2[idx1] += warp_buf2[idx2];
469
+ }
470
+ __syncthreads();
471
+ }
472
+ int i2 = blockIdx.x * blockDim.x + threadIdx.x;
473
+ if (threadIdx.y == 0 && i2 < n2) {
474
+ int row1 = threadIdx.y;
475
+ int row2 = threadIdx.y + 1;
476
+ int idx1 = row1*row_stride + threadIdx.x;
477
+ int idx2 = row2*row_stride + threadIdx.x;
478
+ part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
479
+ part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
480
+ }
481
+ }
482
+
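
cuComputePartGradGammaBeta tiles the n1 rows into blockDim.y*blockDim.y chunks, accumulates per-warp partial sums for the gamma and beta gradients in shared memory, and writes one partial row per blockIdx.y into part_grad_gamma / part_grad_beta. The row_stride = blockDim.x + 1 padding matters: with blockDim.x = 32 and no padding, reading a column of the shared buffer would hit a single 4-byte shared-memory bank. A tiny host calculation showing the effect of that padding (illustrative C++ only):

#include <cstdio>

// Shared memory has 32 four-byte banks; element i lives in bank i % 32. Reading one
// column across 32 rows with stride 32 touches a single bank (a 32-way conflict);
// padding the stride to 33 spreads the column over all 32 banks.
int main() {
    const int banks = 32, rows = 32, col = 5;
    for (int stride : {32, 33}) {
        bool hit[32] = {false};
        int distinct = 0;
        for (int r = 0; r < rows; ++r) {
            int bank = (r * stride + col) % banks;
            if (!hit[bank]) { hit[bank] = true; ++distinct; }
        }
        printf("stride %d -> column touches %d distinct banks\n", stride, distinct);
    }
    return 0;   // prints 1 for stride 32 and 32 for stride 33
}
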
483
+ template<typename U, typename V> __global__
484
+ void cuComputeGradGammaBeta(
485
+ const U* part_grad_gamma,
486
+ const U* part_grad_beta,
487
+ const int part_size,
488
+ const int n1,
489
+ const int n2,
490
+ V* grad_gamma,
491
+ V* grad_beta)
492
+ {
493
+ // sum partial gradients for gamma and beta
494
+ SharedMemory<U> shared;
495
+ U* buf = shared.getPointer();
496
+ int i2 = blockIdx.x * blockDim.x + threadIdx.x;
497
+ if (i2 < n2) {
498
+ // each warp does sequential reductions until reduced part_size is num_warps
499
+ int num_warp_reductions = part_size / blockDim.y;
500
+ U sum_gamma = U(0);
501
+ U sum_beta = U(0);
502
+ const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
503
+ const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
504
+ for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
505
+ sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
506
+ sum_beta += part_grad_beta_ptr[warp_offset*n2];
507
+ }
508
+ // inter-warp reductions
509
+ const int nbsize3 = blockDim.x * blockDim.y / 2;
510
+ for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
511
+ // top half write to shared memory
512
+ if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
513
+ const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
514
+ buf[write_idx] = sum_gamma;
515
+ buf[write_idx+nbsize3] = sum_beta;
516
+ }
517
+ __syncthreads();
518
+ // bottom half sums
519
+ if (threadIdx.y < offset) {
520
+ const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
521
+ sum_gamma += buf[read_idx];
522
+ sum_beta += buf[read_idx+nbsize3];
523
+ }
524
+ __syncthreads();
525
+ }
526
+ // write out fully summed gradients
527
+ if (threadIdx.y == 0) {
528
+ grad_gamma[i2] = sum_gamma;
529
+ grad_beta[i2] = sum_beta;
530
+ }
531
+ }
532
+ }
533
+
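
Taken together, the two kernels above form a two-stage column reduction: stage one leaves part_size partial rows per column, stage two collapses them into grad_gamma and grad_beta. The quantities being reduced are dgamma_j = sum_i dout_ij * (x_ij - mean_i) * invvar_i and dbeta_j = sum_i dout_ij. A host reference of the same two-stage reduction (illustrative C++ only, with made-up invvar values):

#include <cstdio>
#include <vector>

int main() {
    // Toy problem: n1 = 4 rows, n2 = 3 columns, split into part_size = 2 row chunks.
    const int n1 = 4, n2 = 3, part_size = 2;
    std::vector<float> x    = {1,2,3, 4,5,6, 7,8,9, 10,11,12};
    std::vector<float> dout = {1,1,1, 2,2,2, 3,3,3, 4,4,4};
    std::vector<float> mean = {2, 5, 8, 11}, invvar = {1, 1, 1, 1};

    // Stage 1: each part owns a slice of rows and produces one partial row per column.
    std::vector<float> part_dgamma(part_size * n2, 0.f), part_dbeta(part_size * n2, 0.f);
    for (int p = 0; p < part_size; ++p)
        for (int i = p * (n1 / part_size); i < (p + 1) * (n1 / part_size); ++i)
            for (int j = 0; j < n2; ++j) {
                part_dgamma[p * n2 + j] += dout[i * n2 + j] * (x[i * n2 + j] - mean[i]) * invvar[i];
                part_dbeta [p * n2 + j] += dout[i * n2 + j];
            }

    // Stage 2: collapse the partial rows into the final per-column gradients.
    for (int j = 0; j < n2; ++j) {
        float dgamma = 0.f, dbeta = 0.f;
        for (int p = 0; p < part_size; ++p) {
            dgamma += part_dgamma[p * n2 + j];
            dbeta  += part_dbeta [p * n2 + j];
        }
        printf("col %d: dgamma=%.1f dbeta=%.1f\n", j, dgamma, dbeta);
    }
    return 0;
}
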
534
+ template<typename T, typename U, typename V> __global__
535
+ void cuComputeGradInput(
536
+ const V* __restrict__ dout,
537
+ const T* __restrict__ input,
538
+ const int n1,
539
+ const int n2,
540
+ const U* __restrict__ mean,
541
+ const U* __restrict__ invvar,
542
+ U epsilon,
543
+ const V* gamma,
544
+ T* grad_input)
545
+ {
546
+ for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
547
+ U sum_loss1 = U(0);
548
+ U sum_loss2 = U(0);
549
+ const U c_mean = mean[i1];
550
+ const U c_invvar = invvar[i1];
551
+ const T* k_input = input + i1*n2;
552
+ const V* k_dout = dout + i1*n2;
553
+ const int numx = blockDim.x * blockDim.y;
554
+ const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
555
+ if (gamma != NULL) {
556
+ int l = 4*thrx;
557
+ for (; l+3 < n2; l+=4*numx) {
558
+ for (int k = 0; k < 4; ++k) {
559
+ const U c_h = static_cast<U>(k_input[l+k]);
560
+ const U c_loss = static_cast<U>(k_dout[l+k]);
561
+ sum_loss1 += c_loss * gamma[l+k];
562
+ sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
563
+ }
564
+ }
565
+ for (; l < n2; ++l) {
566
+ const U c_h = static_cast<U>(k_input[l]);
567
+ const U c_loss = static_cast<U>(k_dout[l]);
568
+ sum_loss1 += c_loss * gamma[l];
569
+ sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
570
+ }
571
+ } else {
572
+ int l = 4*thrx;
573
+ for (; l+3 < n2; l+=4*numx) {
574
+ for (int k = 0; k < 4; ++k) {
575
+ const U c_h = static_cast<U>(k_input[l+k]);
576
+ const U c_loss = static_cast<U>(k_dout[l+k]);
577
+ sum_loss1 += c_loss;
578
+ sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
579
+ }
580
+ }
581
+ for (; l < n2; ++l) {
582
+ const U c_h = static_cast<U>(k_input[l]);
583
+ const U c_loss = static_cast<U>(k_dout[l]);
584
+ sum_loss1 += c_loss;
585
+ sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
586
+ }
587
+ }
588
+ // intra-warp reductions
589
+ for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
590
+ sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
591
+ sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
592
+ }
593
+ // inter-warp reductions
594
+ if (blockDim.y > 1) {
595
+ SharedMemory<U> shared;
596
+ U* buf = shared.getPointer();
597
+ for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
598
+ // upper half of warps write to shared
599
+ if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
600
+ const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
601
+ buf[2*wrt_i] = sum_loss1;
602
+ buf[2*wrt_i+1] = sum_loss2;
603
+ }
604
+ __syncthreads();
605
+ // lower half merges
606
+ if (threadIdx.y < offset) {
607
+ const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
608
+ sum_loss1 += buf[2*read_i];
609
+ sum_loss2 += buf[2*read_i+1];
610
+ }
611
+ __syncthreads();
612
+ }
613
+ if (threadIdx.y == 0) {
614
+ buf[2*threadIdx.x] = sum_loss1;
615
+ buf[2*threadIdx.x+1] = sum_loss2;
616
+ }
617
+ __syncthreads();
618
+ if (threadIdx.y !=0) {
619
+ sum_loss1 = buf[2*threadIdx.x];
620
+ sum_loss2 = buf[2*threadIdx.x+1];
621
+ }
622
+ }
623
+ // all threads now have the two sums over l
624
+ U fH = (U)n2;
625
+ U term1 = (U(1) / fH) * c_invvar;
626
+ T* k_grad_input = grad_input + i1*n2;
627
+ if (gamma != NULL) {
628
+ for (int l = thrx; l < n2; l+=numx) {
629
+ const U c_h = static_cast<U>(k_input[l]);
630
+ const U c_loss = static_cast<U>(k_dout[l]);
631
+ U f_grad_input = fH * c_loss * gamma[l];
632
+ f_grad_input -= sum_loss1;
633
+ f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
634
+ f_grad_input *= term1;
635
+ k_grad_input[l] = static_cast<T>(f_grad_input);
636
+ }
637
+ } else {
638
+ for (int l = thrx; l < n2; l+=numx) {
639
+ const U c_h = static_cast<U>(k_input[l]);
640
+ const U c_loss = static_cast<U>(k_dout[l]);
641
+ U f_grad_input = fH * c_loss;
642
+ f_grad_input -= sum_loss1;
643
+ f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
644
+ f_grad_input *= term1;
645
+ k_grad_input[l] = static_cast<T>(f_grad_input);
646
+ }
647
+ }
648
+ // prevent race where buf is written again before reads are done
649
+ __syncthreads();
650
+ }
651
+ }
652
+
653
+
654
+
655
+
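
cuComputeGradInput first reduces two per-row sums, sum_loss1 = sum_j dy_j * gamma_j and sum_loss2 = sum_j dy_j * gamma_j * (x_j - mean) * invvar, and then applies the standard layer norm input gradient dx_j = (invvar / n2) * (n2 * dy_j * gamma_j - sum_loss1 - (x_j - mean) * invvar * sum_loss2). A scalar host reference of that formula (illustrative C++ only):

#include <cmath>
#include <cstdio>
#include <vector>

// Reference input gradient for one row, matching cuComputeGradInput when gamma is given.
std::vector<float> layer_norm_grad_input(const std::vector<float>& x,
                                         const std::vector<float>& dy,
                                         const std::vector<float>& gamma,
                                         float mean, float invvar) {
    const int n2 = (int)x.size();
    float sum1 = 0.f, sum2 = 0.f;
    for (int j = 0; j < n2; ++j) {
        sum1 += dy[j] * gamma[j];
        sum2 += dy[j] * gamma[j] * (x[j] - mean) * invvar;
    }
    std::vector<float> dx(n2);
    const float term1 = invvar / n2;
    for (int j = 0; j < n2; ++j) {
        float g = n2 * dy[j] * gamma[j] - sum1 - (x[j] - mean) * invvar * sum2;
        dx[j] = term1 * g;
    }
    return dx;
}

int main() {
    std::vector<float> x = {1, 2, 3, 4}, dy = {0.1f, -0.2f, 0.3f, 0.4f}, g(4, 1.f);
    float mean = 2.5f, invvar = 1.f / std::sqrt(1.25f + 1e-5f);
    auto dx = layer_norm_grad_input(x, dy, g, mean, invvar);
    printf("dx = [%.4f %.4f %.4f %.4f]\n", dx[0], dx[1], dx[2], dx[3]);
    return 0;
}
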
656
+ template<typename T, typename U, typename V>
657
+ void HostApplyLayerNorm(
658
+ V* output,
659
+ U* mean,
660
+ U* invvar,
661
+ const T* input,
662
+ int n1,
663
+ int n2,
664
+ double epsilon,
665
+ const V* gamma,
666
+ const V* beta
667
+ )
668
+ {
669
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
670
+ const dim3 threads(32,4,1);
671
+ const uint64_t maxGridY =
672
+ at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
673
+ const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
674
+ int nshared =
675
+ threads.y > 1 ?
676
+ threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
677
+ 0;
678
+ cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
679
+ output,
680
+ mean,
681
+ invvar,
682
+ input,
683
+ n1,n2,
684
+ U(epsilon),
685
+ gamma,beta);
686
+ }
687
+
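
HostApplyLayerNorm launches 32x4 thread blocks (one warp wide, four warps deep); grid-y covers the rows, and dynamic shared memory is only requested when threads.y > 1. The (n1, n2) pair itself is produced by the C++ wrapper (layer_norm_cuda.cpp in this directory): presumably n2 is the product of the trailing normalized_shape dimensions and n1 the product of everything in front of them. A small sketch of that flattening, stated as an assumption rather than the wrapper's actual code:

#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed flattening: for an input of shape [d0, ..., dk-1] normalized over its
// trailing dimensions, n2 = product of the normalized dims, n1 = the rest.
void flatten_for_layer_norm(const std::vector<int64_t>& input_shape,
                            const std::vector<int64_t>& normalized_shape,
                            int& n1, int& n2) {
    const size_t lead = input_shape.size() - normalized_shape.size();
    int64_t a = 1, b = 1;
    for (size_t i = 0; i < lead; ++i) a *= input_shape[i];
    for (size_t i = lead; i < input_shape.size(); ++i) b *= input_shape[i];
    n1 = (int)a;
    n2 = (int)b;
}

int main() {
    int n1 = 0, n2 = 0;
    // e.g. a [batch=8, seq=1024, hidden=2048] activation normalized over hidden:
    flatten_for_layer_norm({8, 1024, 2048}, {2048}, n1, n2);
    printf("n1=%d n2=%d\n", n1, n2);   // 8192 rows of 2048 elements each
    return 0;
}
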
688
+
689
+ void cuda_layer_norm(
690
+ at::Tensor* output,
691
+ at::Tensor* mean,
692
+ at::Tensor* invvar,
693
+ at::Tensor* input,
694
+ int n1,
695
+ int n2,
696
+ #ifdef VERSION_GE_1_1
697
+ at::IntArrayRef normalized_shape,
698
+ #else
699
+ at::IntList normalized_shape,
700
+ #endif
701
+ at::Tensor* gamma,
702
+ at::Tensor* beta,
703
+ double epsilon)
704
+ {
705
+ using namespace at;
706
+ DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
707
+ input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
708
+ HostApplyLayerNorm(
709
+ output->DATA_PTR<scalar_t_out>(),
710
+ mean->DATA_PTR<float>(),
711
+ invvar->DATA_PTR<float>(),
712
+ input->DATA_PTR<scalar_t_in>(),
713
+ n1,n2,
714
+ epsilon,
715
+ gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
716
+ beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
717
+ )
718
+ }
719
+
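
A rough sketch of calling cuda_layer_norm from C++ once the extension is compiled and linked, just to make the n1/n2 and mean/invvar conventions concrete; the shapes and option handling below are illustrative assumptions, not code from this repository:

#include <ATen/ATen.h>

// Declaration matching the definition above, in its IntArrayRef (VERSION_GE_1_1) form;
// normally this would come from a shared header.
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
                     at::Tensor* input, int n1, int n2,
                     at::IntArrayRef normalized_shape,
                     at::Tensor* gamma, at::Tensor* beta, double epsilon);

int main() {
    const int n1 = 8, n2 = 1024;
    auto opts   = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
    auto input  = at::randn({n1, n2}, opts);
    auto gamma  = at::ones({n2}, opts);
    auto beta   = at::zeros({n2}, opts);
    auto output = at::empty_like(input);
    auto mean   = at::empty({n1}, opts);   // the kernels store mean/invvar as float
    auto invvar = at::empty({n1}, opts);

    cuda_layer_norm(&output, &mean, &invvar, &input,
                    n1, n2, {n2}, &gamma, &beta, 1e-5);
    // output now holds gamma * (input - mean) * invvar + beta, row by row.
    return 0;
}
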
720
+
721
+ template<typename T, typename U, typename V>
722
+ void HostLayerNormGradient(
723
+ const V* dout,
724
+ const U* mean,
725
+ const U* invvar,
726
+ at::Tensor* input,
727
+ int n1,
728
+ int n2,
729
+ const V* gamma,
730
+ const V* beta,
731
+ double epsilon,
732
+ T* grad_input,
733
+ V* grad_gamma,
734
+ V* grad_beta
735
+ )
736
+ {
737
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
738
+
739
+ if (gamma != NULL && beta != NULL) {
740
+ // compute grad_gamma(j) and grad_beta(j)
741
+ const int part_size = 16;
742
+ const dim3 threads2(32,4,1);
743
+ const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
744
+ const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
745
+ (threads2.x + 1);
746
+ const int nshared2_b = threads2.x * threads2.y * sizeof(U);
747
+ const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
748
+ at::Tensor part_grad_gamma = at::empty(
749
+ {part_size,n2}, input->options().dtype(at::ScalarType::Float));
750
+ at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
751
+ cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
752
+ dout,
753
+ input->DATA_PTR<T>(),
754
+ n1,n2,
755
+ mean,
756
+ invvar,
757
+ U(epsilon),
758
+ part_grad_gamma.DATA_PTR<U>(),
759
+ part_grad_beta.DATA_PTR<U>());
760
+
761
+ const dim3 threads3(32,8,1);
762
+ const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
763
+ const int nshared3 = threads3.x * threads3.y * sizeof(U);
764
+ cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
765
+ part_grad_gamma.DATA_PTR<U>(),
766
+ part_grad_beta.DATA_PTR<U>(),
767
+ part_size,
768
+ n1,n2,
769
+ grad_gamma,
770
+ grad_beta);
771
+ }
772
+
773
+ // compute grad_input
774
+ const uint64_t maxGridY =
775
+ at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
776
+ const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
777
+ const dim3 threads1(32,4,1);
778
+ int nshared =
779
+ threads1.y > 1 ?
780
+ threads1.y*threads1.x*sizeof(U) :
781
+ 0;
782
+ cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
783
+ dout,
784
+ input->DATA_PTR<T>(),
785
+ n1,n2,
786
+ mean,
787
+ invvar,
788
+ U(epsilon),
789
+ gamma,
790
+ grad_input);
791
+ }
792
+
793
+
794
+ void cuda_layer_norm_gradient(
795
+ at::Tensor* dout,
796
+ at::Tensor* mean,
797
+ at::Tensor* invvar,
798
+ at::Tensor* input,
799
+ int n1,
800
+ int n2,
801
+ #ifdef VERSION_GE_1_1
802
+ at::IntArrayRef normalized_shape,
803
+ #else
804
+ at::IntList normalized_shape,
805
+ #endif
806
+ at::Tensor* gamma,
807
+ at::Tensor* beta,
808
+ double epsilon,
809
+ at::Tensor* grad_input,
810
+ at::Tensor* grad_gamma,
811
+ at::Tensor* grad_beta)
812
+ {
813
+ using namespace at;
814
+ DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
815
+ input->scalar_type(), gamma->scalar_type(),
816
+ "cuda_layer_norm_gradient_kernel",
817
+ HostLayerNormGradient(
818
+ dout->DATA_PTR<scalar_t_out>(),
819
+ mean->DATA_PTR<float>(),
820
+ invvar->DATA_PTR<float>(),
821
+ input,
822
+ n1,n2,
823
+ // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
824
+ // if gamma Tensor is NULL on input.
825
+ gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
826
+ gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
827
+ epsilon,
828
+ grad_input->DATA_PTR<scalar_t_in>(),
829
+ gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
830
+ gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
831
+ )
832
+ }
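
These two entry points are what the extension's C++ binding file (layer_norm_cuda.cpp in the same directory) wraps and exposes to Python. The snippet below is only a hypothetical sketch of what such a torch-extension binding typically looks like, with a made-up wrapper name; it is not the actual contents of that file:

#include <torch/extension.h>
#include <vector>

// Forward declaration of the CUDA entry point defined above (IntArrayRef form).
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
                     at::Tensor* input, int n1, int n2,
                     at::IntArrayRef normalized_shape,
                     at::Tensor* gamma, at::Tensor* beta, double epsilon);

// Hypothetical thin wrapper: allocate outputs, flatten to (n1, n2), call the driver.
std::vector<at::Tensor> layer_norm_forward(at::Tensor input,
                                           std::vector<int64_t> normalized_shape,
                                           at::Tensor gamma, at::Tensor beta,
                                           double epsilon) {
    int64_t n2 = 1;
    for (int64_t d : normalized_shape) n2 *= d;
    int64_t n1 = input.numel() / n2;
    auto output = at::empty_like(input);
    auto stats  = input.options().dtype(at::ScalarType::Float);
    auto mean   = at::empty({n1}, stats);
    auto invvar = at::empty({n1}, stats);
    cuda_layer_norm(&output, &mean, &invvar, &input, (int)n1, (int)n2,
                    at::IntArrayRef(normalized_shape), &gamma, &beta, epsilon);
    return {output, mean, invvar};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &layer_norm_forward,
          "LayerNorm forward (CUDA); hypothetical binding for illustration");
}
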