Shawn001 committed on
Commit
1101a21
1 Parent(s): c2c125c

Upload 21 files

tools/bert-vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
tools/checkpoint_loader_megatron.py ADDED
@@ -0,0 +1,273 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ import types
5
+
6
+ import torch
7
+
8
+ def add_arguments(parser):
9
+ group = parser.add_argument_group(title='Megatron loader')
10
+
11
+ group.add_argument('--true-vocab-size', type=int, default=None,
12
+ help='original size of vocab, if specified will trim padding from embedding table.')
13
+ group.add_argument('--vocab-file', type=str, default=None,
14
+ help='Path to the vocab file. If specified will use this to get vocab size and '
15
+ 'trim padding from the embedding table.')
16
+ group.add_argument('--megatron-path', type=str, default=None,
17
+ help='Base directory of the Megatron repository')
18
+
19
+ def _load_checkpoint(queue, args):
20
+
21
+ # Search in directory above this
22
+ sys.path.append(os.path.abspath(
23
+ os.path.join(os.path.dirname(__file__),
24
+ os.path.pardir)))
25
+ if args.megatron_path is not None:
26
+ sys.path.insert(0, args.megatron_path)
27
+
28
+ try:
29
+ from megatron.arguments import parse_args, validate_args
30
+ from megatron.global_vars import set_args, set_global_variables
31
+ from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
32
+ from megatron.model import ModelType, module
33
+ from megatron import mpu, fused_kernels
34
+ except ModuleNotFoundError:
35
+ print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
36
+ queue.put("exit")
37
+ exit(1)
38
+
39
+ # We want all arguments to come from us
40
+ sys.argv = ['script.py',
41
+ '--no-masked-softmax-fusion',
42
+ '--no-bias-gelu-fusion',
43
+ '--no-bias-dropout-fusion',
44
+ '--use-cpu-initialization',
45
+ '--micro-batch-size', '1',
46
+ '--no-load-optim',
47
+ '--no-load-rng',
48
+ '--no-save-optim',
49
+ '--no-save-rng',
50
+ '--no-initialization',
51
+ '--load', args.load_dir
52
+ ]
53
+
54
+ margs = parse_args()
55
+ margs = load_args_from_checkpoint(margs)
56
+
57
+ # Arguments do sanity checks on the world size, but we don't care,
58
+ # so trick it into thinking we have plenty of processes
59
+ margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
60
+
61
+ margs = validate_args(margs)
62
+
63
+ def check_for_arg(arg_name):
64
+ if getattr(margs, arg_name, None) is None:
65
+ print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
66
+ print(f"Arguments: {margs}")
67
+ queue.put("exit")
68
+ exit(1)
69
+
70
+ check_for_arg('tensor_model_parallel_size')
71
+ check_for_arg('pipeline_model_parallel_size')
72
+ check_for_arg('num_layers')
73
+ check_for_arg('hidden_size')
74
+ check_for_arg('seq_length')
75
+ check_for_arg('num_attention_heads')
76
+ check_for_arg('max_position_embeddings')
77
+ check_for_arg('tokenizer_type')
78
+ check_for_arg('iteration')
79
+ check_for_arg('bert_binary_head')
80
+ check_for_arg('params_dtype')
81
+
82
+ # Determine how to make our models
83
+ if args.model_type == 'GPT':
84
+ from pretrain_gpt import model_provider
85
+ margs.model_type = ModelType.encoder_or_decoder
86
+ elif args.model_type == 'BERT':
87
+ from pretrain_bert import model_provider
88
+ margs.model_type = ModelType.encoder_or_decoder
89
+ else:
90
+ raise Exception(f'unrecognized model type: {args.model_type}')
91
+
92
+ # suppress warning about torch.distributed not being initialized
93
+ module.MegatronModule.embedding_warning_printed = True
94
+
95
+ consumed_train_samples = None
96
+ consumed_valid_samples = None
97
+ def get_models(count, dtype, pre_process, post_process):
98
+ nonlocal consumed_train_samples
99
+ nonlocal consumed_valid_samples
100
+ models = []
101
+ for rank in range(count):
102
+ mpu.initialize.set_tensor_model_parallel_rank(rank)
103
+ model_ = [model_provider(pre_process, post_process).to(dtype)]
104
+ margs.consumed_train_samples = 0
105
+ margs.consumed_valid_samples = 0
106
+ load_checkpoint(model_, None, None)
107
+ assert(len(model_) == 1)
108
+ model_ = model_[0]
109
+ if consumed_train_samples is not None:
110
+ assert(margs.consumed_train_samples == consumed_train_samples)
111
+ else:
112
+ consumed_train_samples = margs.consumed_train_samples
113
+ if consumed_valid_samples is not None:
114
+ assert(margs.consumed_valid_samples == consumed_valid_samples)
115
+ else:
116
+ consumed_valid_samples = margs.consumed_valid_samples
117
+ models.append(model_)
118
+ return models
119
+
120
+ if margs.num_layers_per_virtual_pipeline_stage is not None:
121
+ print("Model with an interleaved pipeline schedule are not yet supported.")
122
+ queue.put("exit")
123
+ exit(1)
124
+
125
+ set_global_variables(margs)
126
+ mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
127
+ mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
128
+ fused_kernels.load(margs)
129
+
130
+ # Get true (non-padded) vocab size
131
+ if args.true_vocab_size is not None:
132
+ true_vocab_size = args.true_vocab_size
133
+ elif args.vocab_file is not None:
134
+ vocab = json.load(open(args.vocab_file))
135
+ true_vocab_size = len(vocab)
136
+ if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
137
+ print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
138
+ queue.put("exit")
139
+ exit(1)
140
+ else:
141
+ true_vocab_size = None
142
+
143
+ # short aliases
144
+ tp_size = margs.tensor_model_parallel_size
145
+ pp_size = margs.pipeline_model_parallel_size
146
+
147
+ # metadata
148
+ md = types.SimpleNamespace()
149
+ md.model_type = args.model_type
150
+ md.num_layers = margs.num_layers
151
+ md.hidden_size = margs.hidden_size
152
+ md.seq_length = margs.seq_length
153
+ md.num_attention_heads = margs.num_attention_heads
154
+ md.max_position_embeddings = margs.max_position_embeddings
155
+ md.tokenizer_type = margs.tokenizer_type
156
+ md.iteration = margs.iteration
157
+ md.params_dtype = margs.params_dtype
158
+ md.bert_binary_head = margs.bert_binary_head
159
+ md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
160
+ md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
161
+ md.true_vocab_size = true_vocab_size
162
+ md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
163
+
164
+ # Get first pipe stage
165
+ mpu.initialize.set_pipeline_model_parallel_rank(0)
166
+ post_process = pp_size == 1
167
+ models = get_models(tp_size, md.params_dtype, True, post_process)
168
+
169
+ md.consumed_train_samples = consumed_train_samples
170
+ md.consumed_valid_samples = consumed_valid_samples
171
+ queue.put(md)
172
+
173
+ def queue_put(name, msg):
174
+ print(f"sending {name}")
175
+ msg["name"] = name
176
+ queue.put(msg)
177
+
178
+ # Send embeddings
179
+ message = {
180
+ "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
181
+ "word embeddings": torch.cat(
182
+ [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
183
+ dim = 0)
184
+ }
185
+
186
+ queue_put("embeddings", message)
187
+
188
+ total_layer_num = 0
189
+ for pp_rank in range(pp_size):
190
+ if pp_rank > 0:
191
+ mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
192
+ post_process = pp_rank == pp_size - 1
193
+ models = get_models(tp_size, md.params_dtype, False, post_process)
194
+ for layer_num in range(len(models[0].language_model.encoder.layers)):
195
+ message = {}
196
+
197
+ # Get non-parallel tensors from tp_rank 0
198
+ layer = models[0].language_model.encoder.layers[layer_num]
199
+ message["input layernorm weight"] = layer.input_layernorm.weight.data
200
+ message["input layernorm bias"] = layer.input_layernorm.bias.data
201
+ message["dense bias"] = layer.self_attention.dense.bias.data
202
+ message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
203
+ message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
204
+ message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
205
+
206
+ # Grab all parallel tensors for this layer
207
+ qkv_weight = []
208
+ qkv_bias = []
209
+ dense_weight = []
210
+ mlp_l0_weight = []
211
+ mlp_l0_bias = []
212
+ mlp_l1_weight = []
213
+ for tp_rank, model in enumerate(models):
214
+ layer = model.language_model.encoder.layers[layer_num]
215
+ qkv_weight.append(layer.self_attention.query_key_value.weight.data)
216
+ qkv_bias.append(layer.self_attention.query_key_value.bias.data)
217
+ dense_weight.append(layer.self_attention.dense.weight.data)
218
+ mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
219
+ mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
220
+ mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
221
+
222
+ # concat them
223
+ message["qkv weight"] = torch.cat(qkv_weight, dim=0)
224
+ message["qkv bias"] = torch.cat(qkv_bias, dim=0)
225
+ message["dense weight"] = torch.cat(dense_weight, dim=1)
226
+ message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
227
+ message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
228
+ message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
229
+
230
+ queue_put(f"transformer layer {total_layer_num}", message)
231
+
232
+ total_layer_num = total_layer_num + 1
233
+
234
+ # Send final layernorm from tp_rank 0
235
+ message = {
236
+ "weight": models[0].language_model.encoder.final_layernorm.weight.data,
237
+ "bias": models[0].language_model.encoder.final_layernorm.bias.data
238
+ }
239
+ queue_put("final layernorm", message)
240
+
241
+ # Send BERT lm head and binary head if it exists
242
+ if md.model_type == 'BERT':
243
+ print("Sending LM Pooler")
244
+ message = {
245
+ "weight": models[0].language_model.pooler.dense.weight.data,
246
+ "bias": models[0].language_model.pooler.dense.bias.data
247
+ }
248
+ queue_put("pooler", message)
249
+
250
+ message = {
251
+ "dense weight": models[0].lm_head.dense.weight.data,
252
+ "dense bias": models[0].lm_head.dense.bias.data,
253
+ "layernorm weight": models[0].lm_head.layernorm.weight.data,
254
+ "layernorm bias": models[0].lm_head.layernorm.bias.data
255
+ }
256
+ queue_put("lm head", message)
257
+
258
+ if md.bert_binary_head:
259
+ print("Sending BERT Binary head")
261
+ message = {
262
+ "weight": models[0].binary_head.weight.data,
263
+ "bias": models[0].binary_head.bias.data
264
+ }
265
+ queue_put("binary head", message)
266
+ queue.put("done")
267
+
268
+ def load_checkpoint(queue, args):
269
+ try:
270
+ _load_checkpoint(queue, args)
271
+ except:
272
+ queue.put("exit")
273
+ raise
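
For orientation, the loader above talks to its counterpart purely through the multiprocessing queue: it sends a metadata `SimpleNamespace` first, then one dictionary per weight group (each carrying a `"name"` key and full, unsplit tensors), and finally the string `"done"`, with `"exit"` signalling failure. A minimal sketch of the receiving side of that stream (the saver below is the real consumer; this is only for illustration):

```
import types

def drain(queue):
    """Minimal sketch of a consumer for the loader's message stream."""
    md = queue.get()                      # metadata namespace is sent first
    if md == "exit":
        raise RuntimeError("loader exited before sending metadata")
    assert isinstance(md, types.SimpleNamespace)
    print(f"model_type={md.model_type}, num_layers={md.num_layers}")

    while True:
        msg = queue.get()
        if msg == "done":                 # normal end of the stream
            return
        if msg == "exit":                 # loader hit an error mid-stream
            raise RuntimeError("loader exited")
        name = msg.pop("name")            # e.g. "embeddings", "transformer layer 0"
        print(name, {k: tuple(v.shape) for k, v in msg.items()})
```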
tools/checkpoint_saver_megatron.py ADDED
@@ -0,0 +1,322 @@
1
+ import argparse
2
+ from collections.abc import Mapping
3
+ import concurrent.futures
4
+ import os
5
+ import sys
6
+
7
+ import torch
8
+
9
+ def add_arguments(parser):
10
+ group = parser.add_argument_group(title='Megatron saver')
11
+
12
+ group.add_argument('--megatron-path', type=str, default=None,
13
+ help='Base directory of Megatron repository')
14
+
15
+ group.add_argument('--target-tensor-parallel-size', type=int,
16
+ help='Target tensor model parallel size, defaults to the tensor parallel size '
17
+ 'in the input checkpoint if provided by the loader, otherwise to 1')
18
+ group.add_argument('--target-pipeline-parallel-size', type=int,
19
+ help='Target pipeline model parallel size, defaults to the pipeline parallel size '
20
+ 'in the input checkpoint if provided by the loader, otherwise to 1')
21
+
22
+ def save_checkpoint(queue, args):
23
+
24
+ # Search in directory above this
25
+ sys.path.append(os.path.abspath(
26
+ os.path.join(os.path.dirname(__file__),
27
+ os.path.pardir)))
28
+ if args.megatron_path is not None:
29
+ sys.path.insert(0, args.megatron_path)
30
+
31
+ try:
32
+ from megatron.arguments import (parse_args, validate_args)
33
+ from megatron.checkpointing import save_checkpoint
34
+ from megatron.global_vars import set_global_variables, get_args
35
+ from megatron.model import ModelType
36
+ from megatron.tokenizer.tokenizer import _vocab_size_with_padding
37
+ from megatron import mpu, fused_kernels
38
+ except ModuleNotFoundError:
39
+ print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
40
+ exit(1)
41
+
42
+ def queue_get(name=None):
43
+ val = queue.get()
44
+ if val == "exit":
45
+ print("Loader exited, exiting saver")
46
+ exit(1)
47
+ if name is not None and args.checking and val["name"] != name:
48
+ val_name = val["name"]
49
+ print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
50
+ exit(1)
51
+ if name is not None:
52
+ print(f"received {name}")
53
+ return val
54
+
55
+ def check_message(msg):
56
+ if not args.checking:
57
+ return
58
+ msg_name = msg.pop("name")
59
+ if len(msg.keys()) > 0:
60
+ print(f"Unexpected values in {msg_name}:")
61
+ for key in msg.keys():
62
+ print(f" {key}")
63
+ print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
64
+ exit(1)
65
+
66
+
67
+ md = queue_get()
68
+
69
+ if args.target_tensor_parallel_size is None:
70
+ if hasattr(md, 'previous_tensor_parallel_size'):
71
+ args.target_tensor_parallel_size = md.previous_tensor_parallel_size
72
+ else:
73
+ print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
74
+ "Default to 1.")
75
+ args.target_tensor_parallel_size = 1
76
+
77
+ if args.target_pipeline_parallel_size is None:
78
+ if hasattr(md, 'previous_pipeline_parallel_size'):
79
+ args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
80
+ else:
81
+ print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
82
+ "Default to 1.")
83
+ args.target_pipeline_parallel_size = 1
84
+
85
+
86
+ # Arguments do sanity checks on the world size, but we don't care,
87
+ # so trick it into thinking we are plenty of processes
88
+ if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
89
+ os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'
90
+
91
+ # We want all arguments to come from us
92
+ sys.argv = ['script.py',
93
+ '--num-layers', str(md.num_layers),
94
+ '--hidden-size', str(md.hidden_size),
95
+ '--seq-length', str(md.seq_length),
96
+ '--num-attention-heads', str(md.num_attention_heads),
97
+ '--max-position-embeddings', str(md.max_position_embeddings),
98
+ '--tokenizer-type', str(md.tokenizer_type),
99
+ '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
100
+ '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
101
+ '--no-masked-softmax-fusion',
102
+ '--no-bias-gelu-fusion',
103
+ '--no-bias-dropout-fusion',
104
+ '--use-cpu-initialization',
105
+ '--micro-batch-size', '1',
106
+ '--no-load-optim',
107
+ '--no-load-rng',
108
+ '--no-save-optim',
109
+ '--no-save-rng',
110
+ '--no-initialization',
111
+ '--save-interval', '1',
112
+ '--save', args.save_dir
113
+ ]
114
+
115
+ if md.make_vocab_size_divisible_by is not None:
116
+ sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
117
+ if md.params_dtype == torch.float16:
118
+ sys.argv.append('--fp16')
119
+ elif md.params_dtype == torch.bfloat16:
120
+ sys.argv.append('--bf16')
121
+
122
+ if md.model_type == 'BERT' and not md.bert_binary_head:
123
+ sys.argv.append('--bert-no-binary-head')
124
+
125
+ margs = parse_args()
126
+ validate_args(margs)
127
+ set_global_variables(margs)
128
+
129
+ # margs = megatron args
130
+ margs = get_args()
131
+
132
+ if hasattr(md, 'consumed_train_samples'):
133
+ margs.consumed_train_samples = md.consumed_train_samples
134
+ margs.consumed_valid_samples = md.consumed_valid_samples
135
+ print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
136
+ f" and consumed_valid_samples to {margs.consumed_valid_samples}")
137
+ else:
138
+ print("consumed_train_samples not provided.")
139
+
140
+ # Determine how to make our models
141
+ if md.model_type == 'GPT':
142
+ from pretrain_gpt import model_provider
143
+ margs.model_type = ModelType.encoder_or_decoder
144
+ elif md.model_type == 'BERT':
145
+ from pretrain_bert import model_provider
146
+ margs.model_type = ModelType.encoder_or_decoder
147
+ else:
148
+ raise Exception(f'unrecognized model type: {md.model_type}')
149
+
150
+ def get_models(count, dtype, pre_process, post_process):
151
+ models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
152
+ return models
153
+
154
+ # fake initializing distributed
155
+ mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
156
+ mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
157
+ mpu.initialize.set_tensor_model_parallel_rank(0)
158
+ mpu.initialize.set_pipeline_model_parallel_rank(0)
159
+ fused_kernels.load(margs)
160
+
161
+ # Embeddings
162
+ #-----------
163
+ embeddings_msg = queue_get("embeddings")
164
+
165
+ pos_embed = embeddings_msg.pop("position embeddings")
166
+ orig_word_embed = embeddings_msg.pop("word embeddings")
167
+ check_message(embeddings_msg)
168
+
169
+ # Deal with padding
170
+ if md.true_vocab_size is not None:
171
+ # figure out what our padded vocab size is
172
+ orig_vocab_size = orig_word_embed.shape[0]
173
+ margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
174
+
175
+ # Cut out extra padding we don't need
176
+ if orig_vocab_size > margs.padded_vocab_size:
177
+ full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]
178
+
179
+ # Expanding embedding to larger size by replicating final entry
180
+ elif orig_vocab_size < margs.padded_vocab_size:
181
+ padding_size = margs.padded_vocab_size - orig_vocab_size
182
+
183
+ full_word_embed = torch.cat((
184
+ orig_word_embed,
185
+ orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
186
+
187
+ # Same size!
188
+ else:
189
+ full_word_embed = orig_word_embed
190
+ else:
191
+ print("Original vocab size not specified, leaving embedding table as-is. "
192
+ "If you've changed the tensor parallel size this could cause problems.")
193
+ margs.padded_vocab_size = orig_word_embed.shape[0]
194
+ full_word_embed = orig_word_embed
195
+
196
+ # Split into new tensor model parallel sizes
197
+ out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
198
+
199
+ # Make models for first pipeline stage and fill in embeddings
200
+ mpu.initialize.set_pipeline_model_parallel_rank(0)
201
+ post_process = args.target_pipeline_parallel_size == 1
202
+ models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
203
+ for tp_rank, model in enumerate(models):
204
+ print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
205
+ model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
206
+ model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
207
+
208
+ # Transformer layers
209
+ #-------------------
210
+ total_layer_num = 0
211
+ for pp_rank in range(args.target_pipeline_parallel_size):
212
+ # For later pipeline parallel ranks, make the new models
213
+ if pp_rank > 0:
214
+ mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
215
+ post_process = pp_rank == args.target_pipeline_parallel_size - 1
216
+ models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
217
+
218
+ for layer in range(len(models[0].language_model.encoder.layers)):
219
+ msg = queue_get(f"transformer layer {total_layer_num}")
220
+
221
+ # duplicated tensors
222
+ input_layernorm_weight = msg.pop("input layernorm weight")
223
+ input_layernorm_bias = msg.pop("input layernorm bias")
224
+ dense_bias = msg.pop("dense bias")
225
+ post_layernorm_weight = msg.pop("post layernorm weight")
226
+ post_layernorm_bias = msg.pop("post layernorm bias")
227
+ mlp_l1_bias = msg.pop("mlp l1 bias")
228
+
229
+ # Split up the parallel tensors
230
+ qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
231
+ qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
232
+ dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
233
+ mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
234
+ mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
235
+ mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
236
+
237
+ # Save them to the model
238
+ for tp_rank in range(args.target_tensor_parallel_size):
239
+ l = models[tp_rank].language_model.encoder.layers[layer]
240
+ l.input_layernorm.weight.data.copy_(input_layernorm_weight)
241
+ l.input_layernorm.bias.data.copy_(input_layernorm_bias)
242
+ l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
243
+ l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
244
+ l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
245
+ l.self_attention.dense.bias.data.copy_(dense_bias)
246
+ l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
247
+ l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
248
+ l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
249
+ l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
250
+ l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
251
+ l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
252
+ total_layer_num = total_layer_num + 1
253
+ check_message(msg)
254
+
255
+
256
+ if post_process:
257
+ msg = queue_get("final layernorm")
258
+ final_layernorm_weight = msg.pop("weight")
259
+ final_layernorm_bias = msg.pop("bias")
260
+ for tp_rank in range(args.target_tensor_parallel_size):
261
+ models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
262
+ models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
263
+ if pp_rank != 0:
264
+ # Copy word embeddings to final pipeline rank
265
+ models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
266
+ del final_layernorm_weight
267
+ del final_layernorm_bias
268
+ check_message(msg)
269
+
270
+ msg = queue_get()
271
+ if msg != "done" and msg["name"] == "pooler":
272
+ if not hasattr(models[0].language_model, 'pooler'):
273
+ print("ERROR: got a pooler, but model does not have one")
274
+ exit(1)
275
+ print("received pooler")
276
+ pooler_weight = msg.pop("weight")
277
+ pooler_bias = msg.pop("bias")
278
+ for tp_rank in range(args.target_tensor_parallel_size):
279
+ models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
280
+ models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
281
+ del pooler_weight
282
+ del pooler_bias
283
+ check_message(msg)
284
+ msg = queue_get()
285
+
286
+ if msg != "done" and msg["name"] == "lm head":
287
+ if not hasattr(models[0], 'lm_head'):
288
+ print("ERROR: got an lm head, but model does not have one")
289
+ exit(1)
290
+ print("received lm head")
291
+ lm_head_dense_weight = msg.pop("dense weight")
292
+ lm_head_dense_bias = msg.pop("dense bias")
293
+ lm_head_layernorm_weight = msg.pop("layernorm weight")
294
+ lm_head_layernorm_bias = msg.pop("layernorm bias")
295
+ for tp_rank in range(args.target_tensor_parallel_size):
296
+ models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
297
+ models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
298
+ models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
299
+ models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
300
+ check_message(msg)
301
+ msg = queue_get()
302
+
303
+ if msg != "done" and msg["name"] == "binary head":
304
+ if not hasattr(models[0], 'binary_head'):
305
+ print("ERROR: got a binary head, but model does not have one")
306
+ exit(1)
307
+ print("received binary head")
308
+ binary_head_weight = msg.pop("weight")
309
+ binary_head_bias = msg.pop("bias")
310
+ for tp_rank in range(args.target_tensor_parallel_size):
311
+ models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
312
+ models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
313
+ check_message(msg)
314
+ msg = queue_get()
315
+
316
+ if msg != "done":
317
+ print("ERROR: got some more data but was expecting to be done")
318
+
319
+ for tp_rank in range(args.target_tensor_parallel_size):
320
+ mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
321
+ save_checkpoint(md.iteration, [models[tp_rank]], None, None)
322
+ print("Done!")
tools/checkpoint_util.py ADDED
@@ -0,0 +1,151 @@
1
+ import argparse
2
+ import importlib
3
+ import torch.multiprocessing as mp
4
+ import os
5
+ import sys
6
+
7
+ # A loader is a python file with at least two functions
8
+ # - add_arguments - takes in a parser and adds any arguments needed
9
+ # - load_checkpoint - takes in the queue and parsed arguments
10
+
11
+ # A saver is similar but has save_checkpoint instead of
12
+ # load_checkpoint
13
+
14
+ # The loader and saver process are each given a queue, the loader
15
+ # should load the checkpoint and send the weights in messages in the
16
+ # following order, the saver should receive them in this order and
17
+ # save the checkpoints. A message consists of a python dictionary with
18
+ # a "name" for error checking and an entry for each tensor as
19
+ # indicated below. Note that the weights sent over the queue are the
20
+ # full model weights, not split across parallel ranks.
21
+
22
+ # If the loader ever sends "exit" to the queue, that means something
23
+ # went wrong and it is exiting.
24
+
25
+ # - Metadata Namespace with the following attributes:
26
+ # model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
27
+ # num_layers - Number of transformer layers
28
+ # hidden_size
29
+ # seq_length
30
+ # num_attention_heads
31
+ # max_position_embeddings
32
+ # tokenizer_type
33
+ # iteration
34
+ # params_dtype
35
+ # bert_binary_head - Used only if model_type is BERT
36
+ # previous_tensor_parallel_size - Optional
37
+ # previous_pipeline_parallel_size - Optional
38
+ # true_vocab_size
39
+ # make_vocab_size_divisible_by
40
+ # consumed_train_samples
41
+ # consumed_valid_samples
42
+ # messages
43
+ # {
44
+ # "name": "embeddings"
45
+ # "position embeddings"
46
+ # "word embeddings"
47
+ # }
48
+ # (for each transformer layer):
49
+ # {
50
+ # "name": "transformer layer N"
51
+ # "input layernorm weight"
52
+ # "input layernorm bias"
53
+ # "qkv weight"
54
+ # "qkv bias"
55
+ # "dense weight"
56
+ # "dense bias"
57
+ # "post layernorm weight"
58
+ # "post layernorm bias"
59
+ # "mlp l0 weight"
60
+ # "mlp l0 bias"
61
+ # "mlp l1 weight"
62
+ # "mlp l1 bias"
63
+ # }
64
+ # {
65
+ # "name": "final layer norm"
66
+ # "weight"
67
+ # "bias"
68
+ # }
69
+ # if present (i.e. for BERT):
70
+ # {
71
+ # "name": "pooler"
72
+ # "weight"
73
+ # "bias"
74
+ # }
75
+ # {
76
+ # "name": "lm head"
77
+ # "dense weight"
78
+ # "dense bias"
79
+ # "layernorm weight"
80
+ # "layernorm bias"
81
+ # }
82
+ # {
83
+ # "name": "binary head"
84
+ # "weight"
85
+ # "bias"
86
+ # }
87
+ # - "done"
88
+
89
+ def load_plugin(plugin_type, name):
90
+ module_name = f"checkpoint_{plugin_type}_{name}"
91
+ try:
92
+ plugin = importlib.import_module(module_name)
93
+ except ModuleNotFoundError:
94
+ module_name = name
95
+ try:
96
+ plugin = importlib.import_module(module_name)
97
+ except ModuleNotFoundError:
98
+ sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")
99
+
100
+ if not hasattr(plugin, 'add_arguments'):
101
+ sys.exit(f"{module_name} module is not a plugin. Exiting.")
102
+
103
+ print(f"Loaded {module_name} as the {plugin_type}.")
104
+ return plugin
105
+
106
+ def main():
107
+ import argparse
108
+ parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
109
+ allow_abbrev=False, conflict_handler='resolve')
110
+
111
+ parser.add_argument('--model-type', type=str, required=True,
112
+ choices=['GPT', 'BERT'],
113
+ help='Type of the model')
114
+ parser.add_argument('--loader', type=str, default='megatron',
115
+ help='Module name to load checkpoint, should be on python path')
116
+ parser.add_argument('--saver', type=str, default='megatron',
117
+ help='Module name to save checkpoint, should be on python path'
118
+ parser.add_argument('--load-dir', type=str, required=True,
119
+ help='Directory to load model checkpoint from')
120
+ parser.add_argument('--save-dir', type=str, required=True,
121
+ help='Directory to save model checkpoint to')
122
+ parser.add_argument('--max-queue-size', type=int, default=50,
123
+ help='Maximum number of tensors in the queue')
124
+ parser.add_argument('--no-checking', action='store_false',
125
+ help='Do not perform checking on the name and ordering of weights',
126
+ dest='checking')
127
+
128
+ known_args, _ = parser.parse_known_args()
129
+ loader = load_plugin('loader', known_args.loader)
130
+ saver = load_plugin('saver', known_args.saver)
131
+
132
+ loader.add_arguments(parser)
133
+ saver.add_arguments(parser)
134
+
135
+ args = parser.parse_args()
136
+
137
+ queue = mp.Queue(maxsize=args.max_queue_size)
138
+
139
+ print("Starting saver...")
140
+ saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
141
+ saver_proc.start()
142
+
143
+ print("Starting loader...")
144
+ loader.load_checkpoint(queue, args)
145
+
146
+ print("Waiting for saver to complete...")
147
+ saver_proc.join()
148
+
149
+
150
+ if __name__ == '__main__':
151
+ main()
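
To make the plugin contract described in the comments at the top of this file concrete, a minimal custom loader could look like the sketch below. The module name `checkpoint_loader_toy.py`, the argument `--toy-hidden-size`, and all tensor shapes are invented for illustration; a real loader must emit the full message sequence (per-layer messages, final layernorm, and the BERT heads where applicable) before sending `"done"`. Such a module would be selected with `--loader toy`, since `load_plugin` looks for a module named `checkpoint_<type>_<name>`.

```
# checkpoint_loader_toy.py -- hypothetical plugin, for illustration only
import types
import torch

def add_arguments(parser):
    group = parser.add_argument_group(title='toy loader')
    group.add_argument('--toy-hidden-size', type=int, default=8)

def load_checkpoint(queue, args):
    # 1. Metadata namespace, sent before any weights.
    md = types.SimpleNamespace()
    md.model_type = 'GPT'
    md.num_layers = 1
    md.hidden_size = args.toy_hidden_size
    md.seq_length = 16
    md.num_attention_heads = 2
    md.max_position_embeddings = 16
    md.tokenizer_type = 'GPT2BPETokenizer'
    md.iteration = 1
    md.params_dtype = torch.float32
    md.bert_binary_head = False
    md.true_vocab_size = None
    md.make_vocab_size_divisible_by = 1
    md.consumed_train_samples = 0
    md.consumed_valid_samples = 0
    queue.put(md)

    # 2. One dict per weight group: full (unsplit) tensors plus a "name" key.
    queue.put({
        "name": "embeddings",
        "position embeddings": torch.zeros(md.seq_length, md.hidden_size),
        "word embeddings": torch.zeros(32, md.hidden_size),
    })
    # ... "transformer layer 0", "final layernorm", etc. would follow here ...

    # 3. Finish the stream.
    queue.put("done")
```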
tools/linter.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+ import os.path as osp
3
+ import pathlib
4
+ import subprocess
5
+
6
+
7
+ def recursively_lint_files():
8
+ """Recursively lint all python files in chosen subdirectories of megatron-lm"""
9
+
10
+ try:
11
+ import autopep8
12
+ except ModuleNotFoundError:
13
+ print("Please first install autopep8 via `pip install autopep8`")
14
+ return
15
+
16
+ # get all python file paths from top level directory
17
+ file_dir = str(pathlib.Path(__file__).parent.absolute())
18
+ working_dir = osp.join(file_dir, os.pardir)
19
+ all_py_paths = set(os.path.join(working_dir, fname)
20
+ for fname in os.listdir(working_dir) if ".py" in fname)
21
+
22
+ # get all python file paths from chosen subdirectories
23
+ check_dirs = ['docker', 'megatron', 'openwebtext', 'scripts', 'tasks']
24
+ for sub_dir in check_dirs:
25
+ for path, _, fnames in os.walk(osp.join(working_dir, sub_dir)):
26
+ all_py_paths.update(set(osp.join(path, fname) for fname in fnames if ".py" in fname))
27
+
28
+ print("Linting the following: ")
29
+ for py_path in all_py_paths:
30
+ print(py_path)
31
+ command = ['autopep8', '--max-line-length', '100', '--aggressive', '--in-place', py_path]  # list form so check_call works without shell=True
32
+ subprocess.check_call(command)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ recursively_lint_files()
tools/merge_datasets.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import argparse
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
6
+ os.path.pardir)))
7
+
8
+ from megatron.data import indexed_dataset
9
+
10
+
11
+ def main(args):
12
+
13
+ prefixes = set()
14
+ for basename in os.listdir(args.input):
15
+ prefix, ext = os.path.splitext(basename)
16
+
17
+ if prefix in prefixes:
18
+ continue
19
+
20
+ if not os.path.isfile(os.path.join(args.input, basename)):
21
+ continue
22
+
23
+ ext_pair = '.bin' if ext == '.idx' else '.idx'
24
+ assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \
25
+ f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}'
26
+
27
+ prefixes.add(prefix)
28
+
29
+ builder = None
30
+ for prefix in sorted(prefixes):
31
+ if builder is None:
32
+ dataset = indexed_dataset.make_dataset(os.path.join(args.input, prefix), 'infer')
33
+
34
+ if isinstance(dataset, indexed_dataset.MMapIndexedDataset):
35
+ builder = indexed_dataset.MMapIndexedDatasetBuilder(args.output_prefix + '.bin', dtype=dataset._index.dtype)
36
+ else:
37
+ builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin')
38
+
39
+ del dataset
40
+
41
+ builder.merge_file_(os.path.join(args.input, prefix))
42
+
43
+ builder.finalize(args.output_prefix + '.idx')
44
+
45
+
46
+ if __name__ == '__main__':
47
+ parser = argparse.ArgumentParser()
48
+
49
+ group = parser.add_argument_group(title='input data')
50
+ group.add_argument('--input', type=str, required=True,
51
+ help='Path to directory containing all document files to merge')
52
+
53
+ group = parser.add_argument_group(title='output data')
54
+ group.add_argument('--output-prefix', type=str, required=True,
55
+ help='Path to binary output file without suffix')
56
+
57
+ args = parser.parse_args()
58
+
59
+ assert os.path.isdir(args.input), \
60
+ f'ERROR: {args.input} is not a directory or does not exist'
61
+
62
+ assert os.path.isdir(os.path.dirname(args.output_prefix)), \
63
+ f'ERROR: {os.path.dirname(args.output_prefix)} is not a directory or does not exist'
64
+
65
+ main(args)
66
+
tools/merge_mp_partitions.py ADDED
@@ -0,0 +1,352 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Merge model parallel partitions."""
17
+
18
+ import os
19
+ import re
20
+ import sys
21
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
22
+ os.path.pardir)))
23
+
24
+ import torch
25
+
26
+ from megatron import mpu
27
+ from megatron.checkpointing import load_checkpoint, save_checkpoint
28
+ from megatron.checkpointing import ensure_directory_exists
29
+ from megatron.checkpointing import get_checkpoint_name
30
+ from megatron.checkpointing import get_checkpoint_version
31
+ from megatron.checkpointing import get_checkpoint_tracker_filename
32
+ from megatron.global_vars import set_global_variables, get_args
33
+ from megatron.global_vars import rebuild_tokenizer
34
+
35
+
36
+ def split_into_partitions(tensor, num_partitions, partition_dim, stride):
37
+
38
+ per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
39
+ num_partitions)
40
+ per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
41
+
42
+ partitions_list = torch.split(tensor,
43
+ per_partition_per_stride_size,
44
+ dim=partition_dim)
45
+
46
+ partitions = []
47
+ for i in range(num_partitions):
48
+ partition = torch.cat(partitions_list[i::num_partitions],
49
+ dim=partition_dim)
50
+ partitions.append(partition)
51
+
52
+ return partitions
53
+
54
+
55
+ def merge_partitions(merged, partitions, partition_dim, stride):
56
+
57
+ # Number and size of each partition.
58
+ num_partitions = len(partitions)
59
+ per_partition_size = None
60
+ for partition in partitions:
61
+ if per_partition_size is None:
62
+ per_partition_size = partition.size(partition_dim)
63
+ else:
64
+ assert per_partition_size == partition.size(partition_dim)
65
+
66
+ def concat_partitions(partitions_):
67
+ with torch.no_grad():
68
+ if (per_partition_size * num_partitions) == merged.size(
69
+ partition_dim):
70
+ torch.cat(partitions_, dim=partition_dim, out=merged)
71
+ else:
72
+ print(' ***WARNING*** sizes do not match. Will cut '
73
+ 'the merged partitions by {} along dimension {} '
74
+ 'to reduce the size from {} to {} ...'.format(
75
+ (per_partition_size * num_partitions) - \
76
+ merged.size(partition_dim), partition_dim,
77
+ per_partition_size * num_partitions,
78
+ merged.size(partition_dim)))
79
+ merged_ = torch.cat(partitions_, dim=partition_dim)
80
+ merged_split = torch.split(merged_, merged.size(partition_dim),
81
+ dim=partition_dim)
82
+ merged_ = merged_split[0]
83
+ assert merged_.size(partition_dim) == merged.size(partition_dim)
84
+ merged.data.copy_(merged_.data)
85
+
86
+ # If stride is 1, then do simple concatenation.
87
+ if stride == 1:
88
+ concat_partitions(partitions)
89
+ return
90
+
91
+ # For non-unity strides, first split based on stride and then group.
92
+ per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
93
+ # Chunk and build a list.
94
+ chunks = None
95
+ for i, partition in enumerate(partitions):
96
+ chunk = torch.split(partition,
97
+ per_partition_per_stride_size,
98
+ dim=partition_dim)
99
+
100
+ if chunks is None:
101
+ chunks = [0]*(num_partitions*len(chunk))
102
+ chunks[i::num_partitions] = chunk
103
+
104
+ # Concatenate.
105
+ concat_partitions(chunks)
106
+
107
+ return
108
+
109
+
110
+ def get_model(model_type):
111
+
112
+ if model_type == 'BERT':
113
+ from pretrain_bert import model_provider
114
+ elif model_type == 'GPT':
115
+ from pretrain_gpt import model_provider
116
+ elif model_type == 'RACE':
117
+ from tasks.race.finetune import model_provider
118
+ elif model_type in ['MNLI', 'QQP']:
119
+ num_classes = 2
120
+ if model_type == 'MNLI':
121
+ num_classes = 3
122
+ from megatron.model.classification import Classification
123
+ def model_provider():
124
+ return Classification(num_classes=num_classes, num_tokentypes=2)
125
+ else:
126
+ raise Exception('unrecognized model type: {}'.format(model_type))
127
+
128
+ model = model_provider()
129
+ model = model.half()
130
+
131
+ return model
132
+
133
+
134
+ def get_parallel_checkpoint_name(path):
135
+
136
+ tracker_filename = get_checkpoint_tracker_filename(path)
137
+ iteration = 0
138
+ with open(tracker_filename, 'r') as f:
139
+ metastring = f.read().strip()
140
+ iteration = int(metastring)
141
+ assert iteration > 0
142
+ checkpoint_name = get_checkpoint_name(path, iteration)
143
+
144
+ return checkpoint_name, iteration
145
+
146
+
147
+ def test_split_merge():
148
+
149
+ print('testing split and merge ...')
150
+
151
+ #[QKV.ROW-COL]
152
+ tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
153
+ [1.21, 1.22, 1.23, 1.24, 1.25],
154
+ [1.31, 1.32, 1.33, 1.34, 1.35],
155
+ [1.41, 1.42, 1.43, 1.44, 1.45],
156
+ [2.11, 2.12, 2.13, 2.14, 2.15],
157
+ [2.21, 2.22, 2.23, 2.24, 2.25],
158
+ [2.31, 2.32, 2.33, 2.34, 2.35],
159
+ [2.41, 2.42, 2.43, 2.44, 2.45],
160
+ [3.11, 3.12, 3.13, 3.14, 3.15],
161
+ [3.21, 3.22, 3.23, 3.24, 3.25],
162
+ [3.31, 3.32, 3.33, 3.34, 3.35],
163
+ [3.41, 3.42, 3.43, 3.44, 3.45]])
164
+
165
+ num_partitions = 2
166
+ partition_dim = 0
167
+ stride = 3
168
+ partitions = split_into_partitions(tensor, num_partitions,
169
+ partition_dim, stride)
170
+
171
+ merged = torch.zeros_like(tensor)
172
+ merge_partitions(merged, partitions, partition_dim, stride)
173
+
174
+ max_error = (merged - tensor).abs().max()
175
+ print(' > max error (should be zero): {}'.format(max_error))
176
+
177
+
178
+ def get_mp_merge_args(parser):
179
+ """Provide extra arguments required for merging."""
180
+ group = parser.add_argument_group(title='mp merge')
181
+
182
+ group.add_argument('--model-type', type=str, required=True,
183
+ choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
184
+ help='Type of the model.')
185
+ group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
186
+ help='Degree of pipeline model parallelism in output model.')
187
+
188
+ return parser
189
+
190
+
191
+ def main():
192
+
193
+ # Arguments do sanity checks on the world size, but we don't care,
194
+ # so trick it into thinking we have plenty of processes
195
+ os.environ["WORLD_SIZE"] = f'{2**31}'
196
+
197
+ # Args
198
+ set_global_variables(extra_args_provider=get_mp_merge_args,
199
+ args_defaults = {'use_cpu_initialization': True,
200
+ 'micro_batch_size': 1,
201
+ 'no_load_optim': True,
202
+ 'no_load_rng': True,
203
+ 'no_save_optim': True,
204
+ 'no_save_rng': True,
205
+ 'save_interval': 1})
206
+ args = get_args()
207
+
208
+ if args.pipeline_model_parallel_size > 1:
209
+ print("Checkpoints with pipeline model parallelism are not currently supported.")
210
+ exit()
211
+
212
+ model_type = args.model_type
213
+ orig_tensor_model_parallel_size = args.tensor_model_parallel_size
214
+ args.tensor_model_parallel_size = 1
215
+ tokenizer = rebuild_tokenizer(args)
216
+
217
+ print('\n merging model parallel partitions ...')
218
+ print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
219
+ print(' > checkpoint path: {}'.format(args.load))
220
+ print(' > model parameters:')
221
+ print(' number of tokens ................ {} '.format(
222
+ tokenizer.vocab_size))
223
+ print(' number of layers ................ {}'.format(args.num_layers))
224
+ print(' hidden size ..................... {}'.format(args.hidden_size))
225
+ print(' number of attention heads ....... {}'.format(
226
+ args.num_attention_heads))
227
+ print(' maximum position embeddings ..... {}'.format(
228
+ args.max_position_embeddings))
229
+
230
+ # Full model.
231
+ print('> building the full model ...')
232
+ mpu.initialize.set_tensor_model_parallel_world_size(1)
233
+ mpu.initialize.set_tensor_model_parallel_rank(0)
234
+ mpu.initialize.set_pipeline_model_parallel_world_size(1)
235
+ mpu.initialize.set_pipeline_model_parallel_rank(0)
236
+ merged_model = get_model(model_type)
237
+
238
+ # Build and load partitions.
239
+ partitions = []
240
+ iteration = 0
241
+ args.tensor_model_parallel_size = orig_tensor_model_parallel_size
242
+ tokenizer = rebuild_tokenizer(args)
243
+ mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
244
+ for rank in range(args.tensor_model_parallel_size):
245
+ # Reset these since load_checkpoint asserts they are 0, but we are loading
246
+ # multiple checkpoints in the same process and they get set each time
247
+ args.consumed_train_samples = 0
248
+ args.consumed_valid_samples = 0
249
+
250
+ mpu.initialize.set_tensor_model_parallel_rank(rank)
251
+ checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
252
+ model_ = get_model(model_type)
253
+ print(f'> loading {checkpoint_name} ...')
254
+ load_checkpoint(model_, None, None)
255
+ print(f'> checkpoint version {get_checkpoint_version()}')
256
+ partitions.append(model_)
257
+
258
+ # Parameter generators so we can loop through them simultaneously.
259
+ merged_params_gen = merged_model.named_parameters()
260
+ partitions_params_gen = [partition.named_parameters()
261
+ for partition in partitions]
262
+ while True:
263
+ try:
264
+
265
+ # Get the params and check names.
266
+ name, merged_param = next(merged_params_gen)
267
+ print(' > working on {} ...'.format(name))
268
+ print(' merged type: {}, size: {}'.format(
269
+ merged_param.dtype, list(merged_param.size())))
270
+ partitions_param = []
271
+ for rank, partition_params_gen in enumerate(partitions_params_gen):
272
+ partition_name, partition_param = next(partition_params_gen)
273
+ assert partition_name == name
274
+ partitions_param.append(partition_param)
275
+ print(' partition {} type: {}, size: {}'.format(
276
+ rank, partition_param.dtype, list(partition_param.size())))
277
+
278
+ # For the non-parallel parameters, simply copy the rank 0 values.
279
+ if not hasattr(merged_param, 'tensor_model_parallel'):
280
+ print(' non-parallel parameter, simple copy from rank 0')
281
+ with torch.no_grad():
282
+ merged_param.data.copy_(partitions_param[0].data)
283
+ # For parallel parameters, merge the values
284
+ else:
285
+ dim = merged_param.partition_dim
286
+ stride = merged_param.partition_stride
287
+ print(f' parallel parameter merge with stride {stride} along '
288
+ f'dimension {dim}')
289
+ merge_partitions(merged_param,
290
+ partitions_param,
291
+ dim,
292
+ stride)
293
+
294
+ except StopIteration:
295
+ break
296
+
297
+ partitions = []
298
+ args.tensor_model_parallel_size = 1
299
+ args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
300
+
301
+ assert args.num_layers % args.pipeline_model_parallel_size == 0, \
302
+ 'num_layers must be divisible by target pipeline model parallel size'
303
+ layers_per_part = args.num_layers // args.pipeline_model_parallel_size
304
+
305
+ tokenizer = rebuild_tokenizer(args)
306
+ mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
307
+ mpu.initialize.set_tensor_model_parallel_rank(0)
308
+ mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
309
+
310
+ # regex to parse out layer number from param name
311
+ layer_re = re.compile(r'layers\.([0-9]+)')
312
+
313
+ if args.pipeline_model_parallel_size > 1:
314
+ merged_params = {}
315
+ for name, merged_param in merged_model.named_parameters():
316
+ merged_params[name] = merged_param
317
+
318
+ for rank in range(args.pipeline_model_parallel_size):
319
+ mpu.initialize.set_pipeline_model_parallel_rank(rank)
320
+ model = get_model(model_type)
321
+ def update_layer_num(m):
322
+ # TODO! This assumes no interleaved pipeline execution
323
+ layer = int(m.group(1))
324
+ layer += rank * layers_per_part
325
+ return f'layers.{layer}'
326
+
327
+ for dst_name, partition_param in model.named_parameters():
328
+ if dst_name == "word_embeddings.weight":
329
+ # See comment in MegatronModule.initialize_word_embeddings()
330
+ src_name = "language_model.embedding.word_embeddings.weight"
331
+ else:
332
+ # Translate destination layer number (0-N for each partition)
333
+ # to source layer number (single-model layer number)
334
+ src_name = re.sub(layer_re, update_layer_num, dst_name)
335
+ print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
336
+ partition_param.data.copy_(merged_params[src_name].data)
337
+
338
+ partitions.append(model)
339
+ else:
340
+ partitions = [merged_model]
341
+
342
+ for rank, model in enumerate(partitions):
343
+ mpu.initialize.set_pipeline_model_parallel_rank(rank)
344
+ print(f"> saving rank {rank}'s model")
345
+ save_checkpoint(iteration, model, None, None)
346
+
347
+ print('done :-)')
348
+
349
+
350
+ if __name__ == '__main__':
351
+
352
+ main()
tools/openwebtext/README.md ADDED
@@ -0,0 +1,59 @@
1
+ The following steps show how to prepare the training dataset used to train the model.
2
+
3
+ # Libraries to install
4
+
5
+ ```
6
+ pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
7
+ git clone https://github.com/mattilyra/LSH
8
+ cd LSH
9
+ python setup.py install
10
+ ```
11
+
12
+ # Download the dataset
13
+
14
+ 1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
15
+ 2. Remove blacklisted URLs.
16
+ ```
17
+ python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for clean urls, e.g. clean_urls.txt>
18
+ ```
19
+ 3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
20
+
21
+ 4. Merge the contents into one loose json file with one json object per line of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique.
22
+
23
+ # Prepare the data for GPT training:
24
+
25
+ 1. Perform ftfy, English language detection and remove documents with fewer than 128 tokens. This step can be sharded and run on shards.
26
+ ```
27
+ python cleanup_dataset.py <input data file> <output cleaned data filename>
28
+ ```
29
+ Additional cleanup (e.g. removing documents with fewer than 512 characters, or dataset-specific cleaning for the stories and realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
30
+ 2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplication, and is also multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`.
31
+ ```
32
+ python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
33
+ ```
34
+ 3. Based on the similarity measure defined inside the function `is_similar` (default threshold: 0.9), group urls that are similar. For each group, keep only one url and remove the rest.
35
+ ```
36
+ python group_duplicate_urls.py <possible duplicate urls file> <output file containing similar urls>
37
+ ```
38
+ 4. Remove similar documents that were detected in the last step.
39
+ ```
40
+ python remove_group_duplicates.py <file containing similar documents> <cleaned data file> <output file containing deduplicated data>
41
+ ```
42
+
43
+ 5. Shuffle the dataset.
44
+ ```
45
+ shuf <cleaned deduped data file> -o train_data.json
46
+ ```
47
+
48
+ # Deduplicating ngrams
49
+
50
+ To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
51
+
52
+ ```
53
+ python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
54
+ ```
55
+ We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters on each side of it. We also remove any split piece with fewer than 200 characters, and any document that gets split more than 10 times. These parameters can be changed using the corresponding arguments.
56
+
57
+ For the lambada task only, the path to the lambada test data must be provided via `--lambada-path <path of the lambada test data>`.
58
+
59
+ Several other features (e.g. saving and loading the dictionary) have been added; see `python filter_ngrams.py --help` for details.
tools/openwebtext/add_id.py ADDED
@@ -0,0 +1,67 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import json
18
+ import os
19
+ import time
20
+
21
+ """
22
+ This script adds an id to each json object in a json file. The user can add a prefix
23
+ to the ids.
24
+ """
25
+
26
+ if __name__ == '__main__':
27
+
28
+ print('parsing the arguments ...')
29
+
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument('--input-file', type=str, default=None, help='Input'\
32
+ ' json file where id needs to be added')
33
+ parser.add_argument('--output-file', type=str, default=None, help=\
34
+ 'Output file name with id')
35
+ parser.add_argument('--id-prefix', type=str, default=None, help=\
36
+ 'Id prefix')
37
+ parser.add_argument('--log-interval', type=int, default=100,
38
+ help='Log interval')
39
+ args = parser.parse_args()
40
+
41
+ print('Adding ids to dataset ...')
42
+
43
+ f_input = open(args.input_file, 'r', encoding='utf-8')
44
+ f_output = open(args.output_file, 'wb')
45
+
46
+ unique_ids = 1
47
+ start_time = time.time()
48
+ for row in f_input:
49
+ each_row = json.loads(row)
50
+ adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
51
+ each_row['adlr_id'] = adlr_id_string
52
+ myjson = json.dumps(each_row, ensure_ascii=False)
53
+
54
+ f_output.write(myjson.encode('utf-8'))
55
+ f_output.write('\n'.encode('utf-8'))
56
+
57
+ if unique_ids % args.log_interval == 0:
58
+ print(' processed {:9d} documents in {:.2f} seconds ...'.format( \
59
+ unique_ids, time.time() - start_time), flush=True)
60
+
61
+ unique_ids += 1
62
+
63
+ # Close the file.
64
+ f_input.close()
65
+ f_output.close()
66
+
67
+ print('done :-)', flush=True)
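
As a concrete (made-up) example, running this script with `--id-prefix corpus` turns an input line like the first one below into the second; ids start at 1 and are zero-padded to ten digits:

```
{"text": "some document text", "url": "https://example.com/a"}
{"text": "some document text", "url": "https://example.com/a", "adlr_id": "corpus-0000000001"}
```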
tools/openwebtext/blacklist_urls.py ADDED
@@ -0,0 +1,312 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import glob
18
+ import re
19
+ import time
20
+ import tldextract
21
+ import sys
22
+
23
+
24
+ # List of the domains to blacklist.
25
+ domain_blacklist = set([
26
+ '500px',
27
+ 'aapks',
28
+ 'akamaihd',
29
+ 'amazon',
30
+ 'apple',
31
+ 'artifactfire',
32
+ 'artstation',
33
+ 'awwni',
34
+ 'bandcamp',
35
+ 'battleforthenet',
36
+ 'coinscalendar',
37
+ 'dailymotion',
38
+ 'deviantart',
39
+ 'discord',
40
+ 'discordapp',
41
+ 'dlapkandroid',
42
+ 'dropbox',
43
+ 'e621',
44
+ 'ebay',
45
+ 'edealinfo',
46
+ 'erome',
47
+ 'eroshare',
48
+ 'explosm',
49
+ 'facebook',
50
+ 'fbcdn',
51
+ 'flickr',
52
+ 'furaffinity',
53
+ 'futhead',
54
+ 'gatopardo',
55
+ 'gfycat',
56
+ 'gifsound',
57
+ 'gifsoup',
58
+ 'giphy',
59
+ 'github',
60
+ 'google',
61
+ 'gunprime',
62
+ 'gyazo',
63
+ 'hotdealstar',
64
+ 'imagefap',
65
+ 'imageshack',
66
+ 'imgflip',
67
+ 'imgur',
68
+ 'instagram',
69
+ 'karmadecay',
70
+ 'kryptocal',
71
+ 'kym-cdn',
72
+ 'liveleak',
73
+ 'livememe',
74
+ 'lmgtfy',
75
+ 'magaimg',
76
+ 'memegenerator',
77
+ 'minorplanetcenter',
78
+ 'minus',
79
+ 'mobafire',
80
+ 'morejpeg',
81
+ 'nocookie',
82
+ 'pcpartpicker',
83
+ 'photobucket',
84
+ 'pinimg',
85
+ 'pinterest',
86
+ 'pixiv',
87
+ 'pornhub',
88
+ 'prntscr',
89
+ 'puu',
90
+ 'qkme',
91
+ 'quickmeme',
92
+ 'radd',
93
+ 'redd',
94
+ 'reddit',
95
+ 'reddit-stream',
96
+ 'redditlog',
97
+ 'redditmedia',
98
+ 'reddituploads',
99
+ 'redtube',
100
+ 'reupp',
101
+ 'reverb',
102
+ 'roanoke',
103
+ 'rollingstone',
104
+ 'sli',
105
+ 'soundcloud',
106
+ 'soundgasm',
107
+ 'spankbang',
108
+ 'spotify',
109
+ 'strawpoll',
110
+ 'streamable',
111
+ 'timeanddate',
112
+ 'tinypic',
113
+ 'touhouradio',
114
+ 'tumblr',
115
+ 'twimg',
116
+ 'twitch',
117
+ 'twitter',
118
+ 'vid',
119
+ 'vimeo',
120
+ 'vine',
121
+ 'vkaao',
122
+ 'vocaroo',
123
+ 'voyagefusion',
124
+ 'walmart',
125
+ 'wciu',
126
+ 'wikimedia',
127
+ 'wikipedia',
128
+ 'xhamster',
129
+ 'xkcd',
130
+ 'xvideos',
131
+ 'youtu',
132
+ 'youtube',
133
+ 'youtubedoubler',
134
+ 'ytimg',
135
+ 'zillexplorer',
136
+ ])
137
+
138
+ def domain_is_in_blacklist(url):
139
+ domain = tldextract.extract(url).domain
140
+ return domain in domain_blacklist
141
+
142
+
143
+ # List of extensions to blacklist.
144
+ extentions_blacklist = (
145
+ '.3gp',
146
+ '.7z',
147
+ '.ai',
148
+ '.aif',
149
+ '.apk',
150
+ '.app',
151
+ '.avi',
152
+ '.bin',
153
+ '.bmp',
154
+ '.bz2',
155
+ '.css',
156
+ '.csv',
157
+ '.dat',
158
+ '.deb',
159
+ '.dmg',
160
+ '.doc',
161
+ '.docx',
162
+ '.exe',
163
+ '.gif',
164
+ '.gifv',
165
+ '.gz',
166
+ '.iso',
167
+ '.jar',
168
+ '.jpeg',
169
+ '.jpg',
170
+ '.js',
171
+ '.log',
172
+ '.mid',
173
+ '.midi',
174
+ '.mkv',
175
+ '.mov',
176
+ '.mp3',
177
+ '.mp4',
178
+ '.mpeg',
179
+ '.mpg',
180
+ '.ogg',
181
+ '.ogv',
182
+ '.otf',
183
+ '.pdf',
184
+ '.pkg',
185
+ '.png',
186
+ '.pps',
187
+ '.ppt',
188
+ '.pptx',
189
+ '.psd',
190
+ '.py',
191
+ '.qt',
192
+ '.ram',
193
+ '.rar',
194
+ '.sql',
195
+ '.svg',
196
+ '.swf',
197
+ '.tar.gz',
198
+ '.tar',
199
+ '.tgz',
200
+ '.tiff',
201
+ '.ttf',
202
+ '.txt',
203
+ '.wav',
204
+ '.webm',
205
+ '.wma',
206
+ '.wmv',
207
+ '.xls',
208
+ '.xlsx',
209
+ '.xml',
210
+ '.xz',
211
+ '.zip',
212
+ )
213
+
214
+ def extention_is_in_blacklist(url):
215
+ if url.split('?')[0].lower().endswith(extentions_blacklist):
216
+ return True
217
+ return False
218
+
219
+
220
+ # Malformed urls.
221
+ # This function is adapted from:
222
+ # https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
223
+ url_regex = re.compile(
224
+ r'^(?:http)s?://' # http:// or https://
225
+ r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
226
+ r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
227
+ r'(?::\d+)?' # optional port
228
+ r'(?:/?|[/?]\S+)$', re.IGNORECASE)
229
+ def url_is_malformed(url):
230
+ return re.match(url_regex, url) is None
231
+
232
+
233
+ def print_progress(prefix, start_time, urls_counter,
234
+ domain_blacklist_counter,
235
+ extention_blacklist_counter,
236
+ short_url_counter, malformed_url_counter,
237
+ duplicate_url_counter):
238
+ string = prefix + ' | '
239
+ string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
240
+ string += 'number of urls: {} | '.format(urls_counter)
241
+ string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
242
+ string += 'extention blacklisted: {} | '.format(extention_blacklist_counter)
243
+ string += 'short urls (<=8): {} | '.format(short_url_counter)
244
+ string += 'malformed urls: {} | '.format(malformed_url_counter)
245
+ string += 'duplicate urls: {}'.format(duplicate_url_counter)
246
+ print(string, flush=True)
247
+
248
+
249
+ if __name__ == '__main__':
250
+
251
+
252
+ print('remove blacklisted urls ..')
253
+
254
+ # Path to the url files.
255
+ path = sys.argv[1]
256
+ # Output url file.
257
+ output = sys.argv[2]
258
+
259
+ # Get the list of url files.
260
+ files = glob.glob(path + '/*.txt')
261
+ print('> found {} files'.format(len(files)))
262
+
263
+ urls = set()
264
+ urls_counter = 0
265
+ domain_blacklist_counter = 0
266
+ extention_blacklist_counter = 0
267
+ short_url_counter = 0
268
+ malformed_url_counter = 0
269
+ duplicate_url_counter = 0
270
+ start_time = time.time()
271
+ for filename in files:
272
+ with open(filename, 'r') as f:
273
+ for line in f:
274
+ url = line.strip()
275
+ urls_counter += 1
276
+ if domain_is_in_blacklist(url):
277
+ print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
278
+ domain_blacklist_counter += 1
279
+ elif extention_is_in_blacklist(url):
280
+ print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
281
+ extention_blacklist_counter += 1
282
+ elif len(url) <= 8:
283
+ print('[SHORT URL]: {}'.format(url), flush=True)
284
+ short_url_counter += 1
285
+ elif url_is_malformed(url):
286
+ print('[MALFORMED URL]: {}'.format(url), flush=True)
287
+ malformed_url_counter += 1
288
+ elif url in urls:
289
+ print('[DUPLICATE URL]: {}'.format(url), flush=True)
290
+ duplicate_url_counter += 1
291
+ else:
292
+ urls.add(url)
293
+ if urls_counter % 100000 == 0:
294
+ print_progress('PROGRESS', start_time, urls_counter,
295
+ domain_blacklist_counter,
296
+ extention_blacklist_counter,
297
+ short_url_counter, malformed_url_counter,
298
+ duplicate_url_counter)
299
+
300
+ print_progress('FINAL', start_time, urls_counter,
301
+ domain_blacklist_counter,
302
+ extention_blacklist_counter,
303
+ short_url_counter, malformed_url_counter,
304
+ duplicate_url_counter)
305
+
306
+ # Write the final set of urls.
307
+ print('> writing cleaned up url list to {}'.format(output))
308
+ with open(output, 'w') as f:
309
+ for url in urls:
310
+ f.write(url + '\n')
311
+
312
+ print('done :-)')
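As a rough illustration of the two checks above, the sketch below runs the same domain and extension tests on made-up URLs; the small blacklists here are examples, not the full lists used by the script.

import tldextract

domain_blacklist = set(['imgur', 'youtube'])
extensions_blacklist = ('.jpg', '.png', '.pdf')

def domain_blacklisted(url):
    # same idea as domain_is_in_blacklist above
    return tldextract.extract(url).domain in domain_blacklist

def extension_blacklisted(url):
    # strip the query string before checking the suffix, as above
    return url.split('?')[0].lower().endswith(extensions_blacklist)

print(domain_blacklisted('https://i.imgur.com/abc'))          # True
print(extension_blacklisted('http://example.com/a.PNG?x=1'))  # True
print(domain_blacklisted('https://example.com/article'))      # False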
tools/openwebtext/cleanup_dataset.py ADDED
@@ -0,0 +1,115 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import ftfy
18
+ import json
19
+ from langdetect import detect
20
+ import numpy as np
21
+ import time
22
+ import os
23
+ import sys
24
+
25
+ from tokenizer import Tokenizer
26
+
27
+ MIN_DOCUMENT_LENGHT = 128
28
+
29
+
30
+ def print_progress(prefix, start_time, num_docs, num_fixed_text,
31
+ num_non_english_docs, chars_non_english_docs,
32
+ num_small_docs, chars_small_docs):
33
+
34
+ string = prefix + ' | '
35
+ string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
36
+ string += 'documents: {} | '.format(num_docs)
37
+ string += 'fixed text: {} | '.format(num_fixed_text)
38
+ string += 'non-english: {} | '.format(num_non_english_docs)
39
+ string += 'non-english chars: {} | '.format(chars_non_english_docs)
40
+ string += 'small docs: {} | '.format(num_small_docs)
41
+ string += 'small docs chars: {}'.format(chars_small_docs)
42
+ print(string, flush=True)
43
+
44
+
45
+ def filter_corpus(filename, out_filename, print_interval=10000):
46
+
47
+ print(' > filtering {}'.format(filename))
48
+
49
+ tokenizer = Tokenizer(cache_dir='./cache')
50
+
51
+ num_docs = 0
52
+ num_written_docs = 0
53
+ num_small_docs = 0
54
+ num_fixed_text = 0
55
+ num_non_english_docs = 0
56
+ chars_non_english_docs = 0
57
+ chars_small_docs = 0
58
+ start_time = time.time()
59
+ with open(out_filename, 'wb') as f:
60
+ with open(filename, 'r') as fin:
61
+ for line in fin:
62
+ try:
63
+ num_docs += 1
64
+ myjson = json.loads(line)
65
+ # Fix text
66
+ text = ftfy.fix_text(myjson['text'])
67
+ if text != myjson['text']:
68
+ num_fixed_text += 1
69
+ myjson['text'] = text
70
+ # Detect language.
71
+ if detect(text) != 'en':
72
+ print('[non-english text]', myjson)
73
+ num_non_english_docs += 1
74
+ chars_non_english_docs += len(text)
75
+ continue
76
+ # On average each token is about 5 characters, so 8 characters
77
+ # per token is an upper bound.
78
+ if len(text) < (8 * MIN_DOCUMENT_LENGHT):
79
+ tokens = tokenizer.tokenize_document(text)
80
+ if len(tokens) < MIN_DOCUMENT_LENGHT:
81
+ print('[small document, skipping]:', myjson)
82
+ num_small_docs += 1
83
+ chars_small_docs += len(text)
84
+ continue
85
+ myjson = json.dumps(myjson, ensure_ascii=False)
86
+ f.write(myjson.encode('utf-8'))
87
+ f.write('\n'.encode('utf-8'))
88
+ num_written_docs += 1
89
+ if num_docs % print_interval == 0:
90
+ print_progress('[PROGRESS]', start_time, num_docs,
91
+ num_fixed_text, num_non_english_docs,
92
+ chars_non_english_docs,
93
+ num_small_docs, chars_small_docs)
94
+ except Exception as e:
95
+ print(' skipping ', line, e)
96
+
97
+ print_progress('[FINAL]', start_time, num_docs,
98
+ num_fixed_text, num_non_english_docs,
99
+ chars_non_english_docs,
100
+ num_small_docs, chars_small_docs)
101
+
102
+
103
+ if __name__ == '__main__':
104
+
105
+ print('building gpt2 dataset ...')
106
+
107
+ input_filename = sys.argv[1]
108
+ output_filename = sys.argv[2]
109
+
110
+ print('will be reading {}'.format(input_filename))
111
+ print('and will write the results to {}'.format(output_filename))
112
+
113
+ filter_corpus(input_filename, output_filename)
114
+
115
+
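A minimal sketch of the two per-document checks used above (ftfy text repair and language detection), without the local Tokenizer dependency; the sample text is made up.

import ftfy
from langdetect import detect

doc = {'text': 'The Mona Lisa doesnâ€™t have eyebrows.'}

fixed = ftfy.fix_text(doc['text'])   # repairs mojibake such as "â€™"
was_fixed = fixed != doc['text']     # would be counted as num_fixed_text
is_english = detect(fixed) == 'en'   # non-English docs are skipped above

print(was_fixed, is_english, fixed)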
tools/openwebtext/cleanup_fix_dataset.py ADDED
@@ -0,0 +1,191 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Filter and clean documents:
18
+ Capable to clean docs with less than 512 characters, less than
19
+ 256 characters and contains javascript, fix text and dataset specific
20
+ cleaning like stories and realnews datasets.
21
+ Program arguments have the details.
22
+ """
23
+
24
+ import argparse
25
+ from functools import partial
26
+ import glob
27
+ import ftfy
28
+ import json
29
+ from langdetect import detect
30
+ import multiprocessing
31
+ import os
32
+ from pathlib import Path
33
+ import re
34
+ import time
35
+
36
+ def process_doc(json_line, args):
37
+
38
+ # Read the line.
39
+ document = json.loads(json_line)
40
+ text = document['text']
41
+
42
+ output = {'remove_512': False, 'remove_256_javascript': False, \
43
+ 'remove_512_non_english': False, 'ftfy_fix_text': False, \
44
+ 'general_cleaning': False}
45
+
46
+ try:
47
+ # Reomove all docs with less than 512 characters
48
+ if "remove_512" in args.tasks:
49
+ if len(text) < 512:
50
+ output['remove_512'] = True
51
+ return output, text, document, True
52
+
53
+ # Remove docs if less than 256 character length and contains Javascript
54
+ if "remove_256_javascript" in args.tasks:
55
+ if len(text) < 256 and 'javascript' in text.lower():
56
+ output['remove_256_javascript'] = True
57
+ return output, text, document, True
58
+
59
+ # Remove docs < 512 and nonenglish
60
+ if "remove_512_non_english" in args.tasks:
61
+ if len(text) < 512 and detect(text) != 'en':
62
+ output['remove_512_non_english'] = True
63
+ return output, text, document, True
64
+
65
+ # Fix the text using ftfy, don't remove the text, hence return False
66
+ if "ftfy_fix_text" in args.tasks:
67
+ fixed_text = ftfy.fix_text(text)
68
+ output['ftfy_fix_text'] = True
69
+ return output, fixed_text, document, False
70
+
71
+ # Cleaning extra spaces and newlines
72
+ if "general_cleaning" in args.tasks:
73
+ cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
74
+ #cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
75
+ #cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews
76
+
77
+ # stories datasets
78
+ #cleaned_text = re.sub(r" \'", "'", text)
79
+ #cleaned_text = re.sub(r" \!", "!", cleaned_text)
80
+ #cleaned_text = re.sub(r" \.", ".", cleaned_text)
81
+ #cleaned_text = re.sub(r" \?", "?", cleaned_text)
82
+ #cleaned_text = re.sub(r" - ", "-", cleaned_text)
83
+ ##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
84
+ #cleaned_text = re.sub(r" @ ", "@", cleaned_text)
85
+
86
+ output['general_cleaning'] = True
87
+ return output, cleaned_text, document, False
88
+
89
+ except Exception as e:
90
+ print('Error: *************************\n{}\ntext: {}'.format(e, \
91
+ text), flush=True)
92
+ return output, text, document, True
93
+
94
+ # don't remove
95
+ return output, text, document, False
96
+
97
+
98
+ def process_set(args, input_file, output_f_cleaned, output_f_filtered):
99
+
100
+ print(' > working on {} ...'.format(input_file), flush=True)
101
+
102
+ num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
103
+ = num_ftfy_fix_text = num_general_cleaning = 0
104
+
105
+ # Output file and counters.
106
+ output_cleaned = open(output_f_cleaned, 'wb')
107
+ output_filtered = open(output_f_filtered, 'wb')
108
+
109
+ start_time = time.time()
110
+
111
+ # Setup multi-processing.
112
+ num_workers = 40
113
+ fin = open(input_file, 'r', encoding='utf-8')
114
+ pool = multiprocessing.Pool(num_workers)
115
+ process_doc_partial = partial(process_doc, args=args)
116
+ processed_docs = pool.imap(process_doc_partial, fin, 500)
117
+
118
+ # Process documents.
119
+ for output, text, document, to_filter in processed_docs:
120
+ num_docs += 1
121
+
122
+ num_remove_512 += 1 if output['remove_512'] else 0
123
+ num_remove_java += 1 if output['remove_256_javascript'] else 0
124
+ num_remove_512_non_english += 1 if output['remove_512_non_english'] \
125
+ else 0
126
+ num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
127
+ num_general_cleaning += 1 if output['general_cleaning'] else 0
128
+
129
+ document['text'] = text
130
+ myjson = json.dumps(document, ensure_ascii=False)
131
+
132
+ if to_filter:
133
+ output_filtered.write(myjson.encode('utf-8'))
134
+ output_filtered.write('\n'.encode('utf-8'))
135
+ else:
136
+ output_cleaned.write(myjson.encode('utf-8'))
137
+ output_cleaned.write('\n'.encode('utf-8'))
138
+
139
+ if num_docs % args.log_interval == 0:
140
+ print(' processed {:9d} documents in {:.2f} seconds ...'.format(
141
+ num_docs, time.time() - start_time), flush=True)
142
+
143
+ # Close the file.
144
+ output_cleaned.close()
145
+ output_filtered.close()
146
+ fin.close()
147
+
148
+ # Print stats.
149
+ print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
150
+ 'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
151
+ format(num_docs, num_remove_512, num_remove_java,\
152
+ num_remove_512_non_english, num_ftfy_fix_text, \
153
+ num_general_cleaning), flush=True)
154
+
155
+ if __name__ == '__main__':
156
+
157
+
158
+ print('parsing the arguments ...')
159
+
160
+ parser = argparse.ArgumentParser()
161
+ parser.add_argument('--input-files', nargs = '*', required=True, default=\
162
+ None, help = 'Input json files that needs to be'\
163
+ ' cleaned')
164
+ parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
165
+ help = 'Tasks to perform on the input files, ' \
166
+ 'such as remove_512, remove_256_javascript, ' \
167
+ 'remove_512_non_english, ftfy_fix_text, and ' \
168
+ 'general_cleaning. 256 or 512 means the number' \
169
+ ' of characters.')
170
+
171
+ parser.add_argument('--output-path', type=str, default=None,
172
+ help='Directory where the output should go')
173
+ parser.add_argument('--log-interval', type=int, default=100,
174
+ help='Log interval')
175
+
176
+ args = parser.parse_args()
177
+
178
+ print('cleanup dataset ...')
179
+
180
+ for input_file in args.input_files:
181
+ input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
182
+ .name)
183
+
184
+ output_f_cleaned = os.path.join(args.output_path, input_filename + \
185
+ "_cleaned" + input_filename_ext)
186
+ output_f_filtered = os.path.join(args.output_path, input_filename + \
187
+ "_filtered" + input_filename_ext)
188
+
189
+ process_set(args, input_file, output_f_cleaned, output_f_filtered)
190
+
191
+ print('done :-)', flush=True)
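A small sketch of the length/javascript heuristics applied per document above; the sample documents are made up and the two flags mirror the remove_512 and remove_256_javascript tasks.

docs = [
    {'text': 'too short to keep'},
    {'text': 'Please enable JavaScript to view this page. ' * 4},
    {'text': 'a' * 600},
]
for doc in docs:
    text = doc['text']
    remove_512 = len(text) < 512                                       # remove_512 task
    remove_256_js = len(text) < 256 and 'javascript' in text.lower()   # remove_256_javascript task
    print(len(text), remove_512, remove_256_js)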
tools/openwebtext/filter_ngrams.py ADDED
@@ -0,0 +1,492 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Deduplicate downstream tasks from the training dataset using 13-grams.
18
+ Split documents shorter than 200 characters are filtered out, and so is
19
+ any document with more than 10 splits.
20
+ """
21
+
22
+ import argparse
23
+ from functools import partial
24
+ import json
25
+ import multiprocessing
26
+ import nltk
27
+ import pickle
28
+ import re
29
+ import string
30
+ import sys
31
+ import time
32
+
33
+ def get_words(text):
34
+ # get all the lowercase words from text
35
+ words, positions = [], []
36
+ for match in re.finditer(r'\w+', text.lower()):
37
+ words.append(match.group(0))
38
+ positions.append(match.start())
39
+ return words, positions
40
+
41
+ # splits the text
42
+ def split_text(text, start_position, remove_char_each_side, seq):
43
+ # first part of the text
44
+ punctuations = ".!?"
45
+ pos = start_position - remove_char_each_side
46
+ text_first = ""
47
+ while pos > 0 and not text[pos] in punctuations:
48
+ pos -= 1
49
+ if pos > 0:
50
+ text_first = text[0:pos+1]
51
+
52
+ # add length of seq and remove_char_each_side
53
+ pos = start_position + len(seq) + remove_char_each_side
54
+
55
+ # last part of the text
56
+ text_second = ""
57
+ while pos < len(text) and not text[pos] in punctuations:
58
+ pos += 1
59
+ if pos + 1 < len(text):
60
+ text_second = text[pos+1:len(text)]
61
+
62
+ return text_first, text_second
63
+
64
+ def check_and_clean_text(args, words, ngrams, text, start_position, \
65
+ text_buf_ngram_free, text_buf, local_ngram):
66
+
67
+ seq = " ".join(words)
68
+ if seq in ngrams:
69
+ print(" [matched]: {}".format(seq), flush=True)
70
+
71
+ if args.get_ngram_freq_only:
72
+ # increase freq of this seq and then only consider the later part
73
+ # of the text for further processing
74
+ if seq in local_ngram:
75
+ local_ngram[seq] += 1
76
+ else:
77
+ local_ngram[seq] = 1
78
+ #print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
79
+ if (start_position + len(seq) + 1) < len(text):
80
+ text_buf.append(text[start_position + len(seq) + 1:len(text)])
81
+ return False
82
+
83
+ # split the text
84
+ text_first, text_second = split_text(text, start_position, \
85
+ args.remove_char_each_side, seq)
86
+
87
+ # first part of ngrams free
88
+ if len(text_first) > args.filter_text_char_len:
89
+ text_buf_ngram_free.append(text_first)
90
+
91
+ # add second part for further processing
92
+ if len(text_second) > args.filter_text_char_len:
93
+ text_buf.append(text_second)
94
+
95
+ return False # not ngram free
96
+
97
+ # ngram free
98
+ return True
99
+
100
+
101
+ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
102
+ # remove all the ngrams
103
+
104
+ try:
105
+ myjson = json.loads(line)
106
+ text_buf = [myjson[key]]
107
+ except Exception as e:
108
+ print("Error: {}".format(e), flush=True)
109
+ text_buf = []
110
+
111
+ text_buf_ngram_free = []
112
+ local_ngram = {}
113
+ while len(text_buf) > 0:
114
+
115
+ # get the first one from the buffer
116
+ text = text_buf.pop(0)
117
+ words, positions = get_words(text)
118
+
119
+ ngram_free = True
120
+ # find each max n-grams and check dictionary
121
+ for i in range(len(words) - args.max_ngram_size + 1):
122
+ check_ngram_free = check_and_clean_text(args, words[i:\
123
+ i+args.max_ngram_size], ngrams, text, positions[i], \
124
+ text_buf_ngram_free, text_buf, local_ngram)
125
+
126
+ # the seq is ngram free? if yes, break
127
+ if not check_ngram_free:
128
+ ngram_free = False
129
+ break
130
+
131
+ # if the max ngram doesn't match, check if any other lower n-gram
132
+ # within the max ngram matches
133
+ for ngram_len, _ in ngrams_freq_sorted:
134
+ check_ngram_free = check_and_clean_text(args, words[i:\
135
+ i+ngram_len], ngrams, text, positions[i], \
136
+ text_buf_ngram_free, text_buf, local_ngram)
137
+
138
+ # same check as above
139
+ if not check_ngram_free:
140
+ ngram_free = False
141
+ break
142
+
143
+ # check break from lower than max ngram loop above
144
+ if not ngram_free:
145
+ break
146
+
147
+ # for the last max n-gram, check all the lower ngrams in it
148
+ if ngram_free and len(words) - args.max_ngram_size > 0:
149
+ # get the last words of the lax max ngram
150
+ last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
151
+ last_seq_start_position = len(words) - args.max_ngram_size
152
+
153
+ # check all n-grams lower than the max
154
+ for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
155
+
156
+ # ignore the max ngram as has been considered already
157
+ if ngram_len == args.max_ngram_size:
158
+ continue
159
+
160
+ # find each ngram of ngram_len in max n-grams and check
161
+ for i in range(len(last_seq_words) - ngram_len + 1):
162
+ check_ngram_free = check_and_clean_text(args, \
163
+ last_seq_words[i:i+ngram_len], ngrams, text,\
164
+ positions[last_seq_start_position+i], \
165
+ text_buf_ngram_free, text_buf, local_ngram)
166
+
167
+ if not check_ngram_free:
168
+ ngram_free = False
169
+ break
170
+
171
+ if not ngram_free:
172
+ break
173
+
174
+ # texts are ngram free
175
+ if ngram_free and not args.get_ngram_freq_only:
176
+ text_buf_ngram_free.append(text)
177
+
178
+ # check if the text has only been trimmed
179
+ trimmed = 0
180
+ if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
181
+ len(text_buf_ngram_free[0]) < len(myjson[key]):
182
+ trimmed = 1
183
+
184
+ return text_buf_ngram_free, trimmed, myjson, local_ngram
185
+
186
+ # insert word sequence into dictionary
187
+ def insert_dict(words, ngrams, pos):
188
+ seq = " ".join(words)
189
+ if seq not in ngrams:
190
+ ngrams[seq] = 0
191
+ #ngrams[seq] = pos
192
+
193
+ # insert each ngram from text into the ngrams dictionary
194
+ def compute_ngrams_insert_dict(args, text, ngrams):
195
+ words, positions = get_words(text)
196
+ if len(words) < args.min_ngram_size:
197
+ return
198
+
199
+ if len(words) < args.max_ngram_size:
200
+ insert_dict(words, ngrams, positions[0])
201
+
202
+ for i in range(len(words) - args.max_ngram_size+1):
203
+ insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])
204
+
205
+
206
+ # Build ngrams for the lambada dataset
207
+ def process_task_lambda(args, task_file, ngrams):
208
+ print(' reading from {} and computing ngrams'.format(task_file))
209
+ with open(task_file, 'r') as f:
210
+ for line in f:
211
+ try:
212
+ myjson = json.loads(line)
213
+ text = myjson['text']
214
+ compute_ngrams_insert_dict(args, text, ngrams)
215
+ except Exception as e:
216
+ print('Error:', e)
217
+ print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
218
+
219
+
220
+ # Build ngrams for the dataset of the given task
221
+ def process_task(args, task_name, ngrams):
222
+
223
+ print(' reading from {} and computing ngrams'.format('import datasets'))
224
+ print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
225
+ # using validation/test data from datasets
226
+ from datasets import load_dataset
227
+
228
+ entities_in_ngrams = len(ngrams)
229
+
230
+ # load the dataset
231
+ if task_name == 'squad':
232
+ dataset = load_dataset('squad_v2', split='validation')
233
+ elif task_name == 'natural_questions':
234
+ dataset = load_dataset('natural_questions', split='validation')
235
+ elif task_name == 'triviaqa':
236
+ dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
237
+ elif task_name == 'webqa':
238
+ dataset = load_dataset('web_questions', split='test')
239
+ elif task_name == 'race':
240
+ dataset = load_dataset('race', 'all', split='test')
241
+ elif task_name == 'drop':
242
+ dataset = load_dataset('drop', split='validation')
243
+ elif task_name == 'coqa':
244
+ dataset = load_dataset('coqa', split='validation')
245
+ elif task_name == 'piqa':
246
+ dataset = load_dataset('piqa', split='test')
247
+ else:
248
+ print("Invalid task name: {}".format(task_name), flush=True)
249
+ return
250
+
251
+ # read the dataset and add to ngrams
252
+ for line in dataset:
253
+ try:
254
+ if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
255
+ text = line['question']
256
+ compute_ngrams_insert_dict(args, text, ngrams)
257
+ elif task_name == 'natural_questions':
258
+ text = line['question']['text']
259
+ compute_ngrams_insert_dict(args, text, ngrams)
260
+ elif task_name == 'coqa':
261
+ all_questions = line['questions']
262
+ for question in all_questions:
263
+ compute_ngrams_insert_dict(args, question, ngrams)
264
+ elif task_name == 'piqa':
265
+ text = line['goal']
266
+ compute_ngrams_insert_dict(args, text, ngrams)
267
+ except Exception as e:
268
+ print('Error:', e)
269
+
270
+ print(" After task {} entities in ngrams {}, added {}".format(task_name, \
271
+ len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
272
+
273
+ def compute_tasks_ngrams(args, ngrams):
274
+ start_time = time.time()
275
+ for _, task_name in enumerate(args.tasks):
276
+ print('Task: {}'.format(task_name), flush=True)
277
+ if task_name == 'lambada':
278
+ assert args.lambada_path is not None
279
+ process_task_lambda(args, args.lambada_path, ngrams)
280
+ else:
281
+ process_task(args, task_name, ngrams)
282
+ print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
283
+ start_time), flush=True)
284
+
285
+ def compute_ngram_freq_sorted(args, ngrams):
286
+ ngrams_freq = {}
287
+ for ngram_key in ngrams.keys():
288
+ length = len(ngram_key.split())
289
+ ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
290
+ ngrams_freq else 1
291
+
292
+ ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
293
+ print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
294
+ print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
295
+ len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
296
+ ngrams_freq_sorted) -1 ][0]), flush=True)
297
+ return ngrams_freq_sorted
298
+
299
+ def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
300
+ dedup_file, dedup_key, ngrams_freq_sorted):
301
+
302
+ start_time = time.time()
303
+ # get the ngrams frequency
304
+ args.get_ngram_freq_only = True
305
+
306
+ # Open the large file to process in parallel
307
+ num_workers = args.num_threads
308
+ pool = multiprocessing.Pool(num_workers)
309
+ fin = open(dedup_file, 'r', encoding='utf-8')
310
+ free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
311
+ ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
312
+ free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
313
+
314
+ counter = 0
315
+ for _, _, _, local_ngram in free_ngrams_abt:
316
+ counter += 1
317
+ if counter % 1000 == 0:
318
+ print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
319
+ format(counter, time.time() - start_time), flush=True)
320
+ for local_key in local_ngram:
321
+ if local_key in ngrams:
322
+ ngrams[local_key] += 1
323
+ local_ngram = {}
324
+
325
+ print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
326
+ start_time), flush=True)
327
+ pool.close()
328
+ pool.join()
329
+
330
+ start_time = time.time()
331
+ counter_threshold = 0
332
+ # Get ngrams below threshold
333
+ for local_key, local_val in ngrams.items():
334
+ if ngrams[local_key] < args.key_threshold:
335
+ print(" [threshold] {} {}".format(local_key, local_val), flush=True)
336
+ counter_threshold += 1
337
+ ngrams_below_threshold[local_key] = 1
338
+
339
+ print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
340
+ fin.close()
341
+
342
+ def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file, \
343
+ dedup_key):
344
+
345
+ start_time = time.time()
346
+ # Now actually filter the dataset
347
+ args.get_ngram_freq_only = False
348
+ #id_prefix = '-'.join(args.tasks[::2])
349
+ id_prefix = '-'.join(args.tasks[::1])
350
+
351
+ # get the range of the size of the ngrams
352
+ ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)
353
+
354
+ # Open the large file to process in parallel
355
+ counter = splitted = ignored = split_mt_thld = trimmed_count = 0
356
+ num_workers = args.num_threads
357
+ pool = multiprocessing.Pool(num_workers)
358
+ fin = open(dedup_file, 'r', encoding='utf-8')
359
+ free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
360
+ ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
361
+ free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
362
+
363
+ out_f = open(args.output, 'wb')
364
+
365
+ for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
366
+ counter += 1
367
+ try:
368
+
369
+ trimmed_count += trimmed
370
+
371
+ if len(text_buf_ngram_free) > 1:
372
+ splitted += 1
373
+ if len(text_buf_ngram_free) == 0:
374
+ ignored += 1
375
+ # more than 10 splits ignored
376
+ if len(text_buf_ngram_free) > args.splits_count:
377
+ text_buf_ngram_free = []
378
+ split_mt_thld += 1
379
+
380
+ if args.output is not None:
381
+ if "split_id" in myjson:
382
+ use_prefix = myjson["split_id"] + "-"
383
+ else:
384
+ use_prefix = ""
385
+
386
+ for i in range(len(text_buf_ngram_free)):
387
+ split_id_string = id_prefix + '-{:010d}'.format(int(\
388
+ counter)) + '-{:04d}'.format(int(i))
389
+ myjson[dedup_key] = text_buf_ngram_free[i]
390
+ myjson["split_id"] = use_prefix + split_id_string
391
+ outjson = json.dumps(myjson, ensure_ascii=False)
392
+ #outjson = json.dumps({"text":text_buf_ngram_free[i],
393
+ # id_prefix+"_split_id":split_id_string},
394
+ # ensure_ascii=False)
395
+ out_f.write(outjson.encode('utf-8'))
396
+ out_f.write('\n'.encode('utf-8'))
397
+
398
+ if counter % 1000 == 0:
399
+ print(' [final]> processed {} documents in {:.2f} seconds ...'.
400
+ format(counter, time.time() - start_time), flush=True)
401
+ except Exception as e:
402
+ print('Error:', e)
403
+
404
+ print(' [final]> processed {} documents in {:.2f} seconds ...'.
405
+ format(counter, time.time() - start_time), flush=True)
406
+
407
+ print(' Total docs {} splitted {} ignored {} splits > threshold {} trimmed'\
408
+ ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
409
+ , flush=True)
410
+
411
+ pool.close()
412
+ pool.join()
413
+
414
+ out_f.close()
415
+ fin.close()
416
+
417
+ if __name__ == '__main__':
418
+
419
+ # we use 13-grams; any text shorter than 200 characters is removed,
420
+ # as is any text split into more than 10 pieces
421
+
422
+ print('parsing the arguments ...')
423
+
424
+ parser = argparse.ArgumentParser()
425
+ parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
426
+ help = 'Tasks to use for deduplication: currently '
427
+ ' support [lambada, squad, natural_questions,'
428
+ ' triviaqa, webqa, race, drop, coqa, and piqa]')
429
+ parser.add_argument('--lambada-path', type=str, default=None,
430
+ help='Only Lambada task needs the path')
431
+ parser.add_argument('--dedup-dataset', nargs = '*', default=None,
432
+ help='Dataset to deduplicate with the key to use'
433
+ ' e.g. cc.json text')
434
+ parser.add_argument('--output', type=str, default=None,
435
+ help='Output file name to save dedup dataset')
436
+ parser.add_argument('--num-threads', type=int, default=40,
437
+ help='Number of threads to use')
438
+ # Default dedup values
439
+ parser.add_argument('--max-ngram-size', type=int, default=13,
440
+ help='Maximum size of ngram to use.')
441
+ parser.add_argument('--min-ngram-size', type=int, default=8,
442
+ help='Minimum size of ngram to use.')
443
+ parser.add_argument('--filter-text-char-len', type=int, default=200,
444
+ help='Remove any text below this length.')
445
+ parser.add_argument('--key-threshold', type=int, default=10,
446
+ help='Only ngrams occurring fewer than this many times in the dataset are used for deduplication')
447
+ parser.add_argument('--save-dictionary', type=str, default=None,
448
+ help='Save the dictionary')
449
+ parser.add_argument('--load-dictionary', type=str, default=None,
450
+ help='Load the dictionary')
451
+ parser.add_argument('--splits-count', type=int, default=10,
452
+ help='Remove any documents more than this many splits')
453
+ parser.add_argument('--remove-char-each-side', type=int, default=200,
454
+ help='Number of characters to remove on each side of a matched ngram.')
455
+
456
+ args = parser.parse_args()
457
+
458
+ assert len(args.dedup_dataset) == 2
459
+ dedup_file = args.dedup_dataset[0]
460
+ dedup_key = args.dedup_dataset[1]
461
+
462
+ # Setup multi-processing
463
+ num_workers = args.num_threads
464
+ if args.load_dictionary is None:
465
+
466
+ # Build ngrams
467
+ ngrams = {}
468
+ compute_tasks_ngrams(args, ngrams)
469
+
470
+ # get the range of the size of the ngrams
471
+ ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)
472
+
473
+ # get ngram freq from large file in parallel
474
+ # get ngrams below threshold
475
+ ngrams_below_threshold = {}
476
+ get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
477
+ dedup_file, dedup_key, ngrams_freq_sorted)
478
+
479
+ # save the dictionary if needed
480
+ if args.save_dictionary is not None:
481
+ with open(args.save_dictionary, 'wb') as save_dict_handle:
482
+ pickle.dump(ngrams_below_threshold, save_dict_handle)
483
+ else:
484
+ with open(args.load_dictionary, 'rb') as load_dict_handle:
485
+ ngrams_below_threshold = pickle.load(load_dict_handle)
486
+
487
+ # filter the large file
488
+ if args.output is not None:
489
+ clean_ngrams_below_threshold(args, ngrams_below_threshold, \
490
+ dedup_file, dedup_key)
491
+
492
+ print('done :-)')
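To make the n-gram matching above concrete, here is a small sketch that builds the word-level 13-grams for a made-up task question, using the same regex as get_words; any training document containing one of these 13-grams would be trimmed or split by the script.

import re

def get_words(text):
    # same lowercase word extraction as get_words above
    return [m.group(0) for m in re.finditer(r'\w+', text.lower())]

question = ('Who painted the ceiling of the Sistine Chapel in Rome '
            'during the Renaissance?')
words = get_words(question)
max_ngram_size = 13
ngrams = set(' '.join(words[i:i + max_ngram_size])
             for i in range(len(words) - max_ngram_size + 1))
print(len(words), len(ngrams))  # 13 words -> a single 13-gram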
tools/openwebtext/find_duplicates.py ADDED
@@ -0,0 +1,305 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ from functools import partial
18
+ import itertools
19
+ import json
20
+ from lsh import cache, minhash
21
+ import multiprocessing
22
+ import numpy as np
23
+ import time
24
+ import pickle
25
+ import sys
26
+ import os
27
+
28
+ # This function is adapted from:
29
+ # https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
30
+ def shingles(text, char_ngram=5):
31
+ return set(text[head:head + char_ngram]
32
+ for head in range(0, len(text) - char_ngram))
33
+
34
+
35
+ # This function is adapted from:
36
+ # https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
37
+ def jaccard(set_a, set_b, args):
38
+ if len(set_a) < 1 or len(set_b) < 1:
39
+ return 0.0
40
+
41
+ intersection = set_a & set_b
42
+ union = set_a | set_b
43
+
44
+ if args.jaccard == 'min':
45
+ return len(intersection) / min(len(set_a), len(set_b))
46
+ elif args.jaccard == 'max':
47
+ return len(intersection) / max(len(set_a), len(set_b))
48
+ else:
49
+ return len(intersection) / len(union)
50
+
51
+ def compute_fingerprint(line, key):
52
+ try:
53
+ myjson = json.loads(line)
54
+ url = myjson[key]
55
+ text = myjson['text']
56
+ fingerprint = hasher.fingerprint(text)
57
+ except Exception as e:
58
+ print('Error:', e)
59
+ return None, None, None, False
60
+
61
+ return url, text, fingerprint, True
62
+
63
+ def url_pairs_to_remove(args, bucket_urls, url_doc):
64
+ remove_urls_list = []
65
+ deduped_local, counter_local = 0, 0
66
+ iteration = 0
67
+ while len(bucket_urls) > 1:
68
+ if args.heuristic_iter != -1 and \
69
+ iteration == args.heuristic_iter:
70
+ break
71
+
72
+ items = list(bucket_urls)
73
+ remove_urls = []
74
+ main_url = items[np.random.randint(0, len(items))]
75
+ main_dhingles = shingles(url_doc[main_url])
76
+
77
+ for i in range(0, len(items)):
78
+ counter_local += 1
79
+ other_url = items[i]
80
+ if other_url == main_url:
81
+ continue
82
+ other_shingles = shingles(url_doc[other_url])
83
+ try:
84
+ jaccard_sim = jaccard(main_dhingles, other_shingles, args)
85
+ except Exception as e:
86
+ print('Error:', e)
87
+ jaccard_sim = 0.0
88
+ if jaccard_sim > 0.5:
89
+ remove_urls.append({other_url: jaccard_sim})
90
+ deduped_local += 1
91
+ bucket_urls.remove(other_url)
92
+
93
+ bucket_urls.remove(main_url)
94
+ if len(remove_urls) > 0:
95
+ remove_urls_list.append({main_url: remove_urls})
96
+ iteration += 1
97
+ return remove_urls_list, deduped_local, counter_local
98
+
99
+ def write_remove_urls_list(remove_urls_list, f_out):
100
+ if len(remove_urls_list) > 0:
101
+ for each_url_remove in remove_urls_list:
102
+ myjson = json.dumps(each_url_remove, ensure_ascii=False)
103
+ f_out.write(myjson.encode('utf-8'))
104
+ f_out.write('\n'.encode('utf-8'))
105
+
106
+ def compute_jaccard(each_bin, num_bins, start_time_local):
107
+
108
+ remove_urls_list = []
109
+ deduped_local, counter_local, bucket_local = 0, 0, 0
110
+
111
+ for bucket_id in each_bin:
112
+ bucket_local += 1
113
+ if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
114
+ print("Counter {}, progress {:.2f} time {:.2f}".\
115
+ format(bucket_local, float(bucket_local)/float(len(each_bin)),\
116
+ time.time() - start_time_local), flush=True)
117
+
118
+ if len(each_bin[bucket_id]) <= 1:
119
+ continue
120
+
121
+ bucket_urls = each_bin[bucket_id].copy()
122
+ remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
123
+ url_pairs_to_remove(args, bucket_urls, url_doc)
124
+
125
+ deduped_local += deduped_local_sub
126
+ counter_local += counter_local_sub
127
+ if len(remove_urls_list_sub) > 0:
128
+ remove_urls_list.extend(remove_urls_list_sub)
129
+
130
+ return remove_urls_list, deduped_local, counter_local
131
+
132
+ def find_pair_urls_parallel(args, lshcache, url_doc):
133
+ start_time = time.time()
134
+ f_out = open(args.output, 'wb')
135
+ deduped, counter = 0, 0
136
+
137
+ # compute jaccards of buckets in bin in parallel (parallelism
138
+ # limited to # of bins)
139
+ num_bins = len(lshcache.bins)
140
+ pool = multiprocessing.Pool(num_bins)
141
+ compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
142
+ start_time_local=start_time)
143
+ # don't need to pass args and url_doc as they are already shared
144
+ compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)
145
+
146
+ print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
147
+ flush=True)
148
+ for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
149
+ deduped += deduped_local
150
+ counter += counter_local
151
+ write_remove_urls_list(remove_urls_list, f_out)
152
+ print(' [write]> processed {} documents in {:.2f} '
153
+ 'seconds and deduped {} documents ...'.format(counter, time.time()\
154
+ - start_time, deduped), flush=True)
155
+
156
+ pool.close()
157
+ pool.join()
158
+ f_out.close()
159
+
160
+ print(' Time taken for jaccard similarities {:.2f} seconds'.format(\
161
+ time.time() - start_time), flush=True)
162
+
163
+ def find_pair_urls_sequential(args, lshcache, url_doc):
164
+ start_time = time.time()
165
+ f_out = open(args.output, 'wb')
166
+ deduped, counter = 0, 0
167
+ for b in lshcache.bins:
168
+ for bucket_id in b:
169
+ if len(b[bucket_id]) <= 1:
170
+ continue
171
+
172
+ bucket_urls = b[bucket_id].copy()
173
+ remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
174
+ url_pairs_to_remove(args, bucket_urls, url_doc)
175
+
176
+ deduped += deduped_local_sub
177
+ counter += counter_local_sub
178
+ write_remove_urls_list(remove_urls_list_sub, f_out)
179
+ if counter % 10000 == 0:
180
+ print(' [write]> processed {} documents in {:.2f} '
181
+ 'seconds and deduped {} documents ...'.
182
+ format(counter, time.time() - start_time,
183
+ deduped), flush=True)
184
+ f_out.close()
185
+ print(' [write]> processed {} documents in {:.2f} '
186
+ 'seconds and deduped {} documents ...'.
187
+ format(counter, time.time() - start_time,
188
+ deduped), flush=True)
189
+
190
+ if __name__ == '__main__':
191
+
192
+ print('parsing the arguments ...')
193
+
194
+ parser = argparse.ArgumentParser()
195
+ parser.add_argument('--seed', type=int, default=1234,
196
+ help='Random seed used for python, numpy')
197
+ parser.add_argument('--inputs', nargs = '*', default=None, help = \
198
+ 'Pairwise list of the input files and keys, '
199
+ 'e.g. --inputs cc.json cc_id news.json news_id')
200
+ parser.add_argument('--load-fingerprints', nargs = '*', default=None,
201
+ help='Load fingerprints from a list of pickle files,'
202
+ ' e.g. cc.pkl news.pkl')
203
+ parser.add_argument('--save-fingerprints', type=str, default=None,
204
+ help='Save the fingerprints of the inputs.')
205
+ parser.add_argument('--output', type=str, default=None,
206
+ help='Output file name that consists of all ids'
207
+ ' with matching similarities')
208
+ parser.add_argument('--jaccard', type=str, default='union',
209
+ choices=['union', 'min', 'max'], help='Jaccard'\
210
+ ' similarity computation')
211
+ parser.add_argument('--heuristic-iter', type=int, default=1,
212
+ help='Number of iterations to run the heuristics'
213
+ ': use -1 for exact')
214
+ parser.add_argument('--num-bands', type=int, default=10,
215
+ help='Number of bands to use in cache')
216
+ parser.add_argument('--num-seeds', type=int, default=100,
217
+ help='Number of seeds to use for minhash. Note that'
218
+ ' this value should be divisible by num-bands')
219
+ parser.add_argument('--jaccard-parallel', action='store_true',
220
+ help='Use this to process large number of documents.')
221
+ args = parser.parse_args()
222
+
223
+ print('finding possible duplicate content ...')
224
+
225
+ # set seed and get an array of seeds of 100 integers
226
+ np.random.seed(args.seed)
227
+ seeds = np.random.randint(0, 1e6, size=args.num_seeds)
228
+
229
+ # initialize minhash and lsh cache
230
+ hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
231
+ lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
232
+
233
+ url_doc = {}
234
+
235
+ # load fingerprints from pickle file if needed
236
+ if args.load_fingerprints is not None:
237
+ for count_fp, fp_file_name in enumerate(args.load_fingerprints):
238
+ print("Loading fingerprints from pickle file {}".format(
239
+ fp_file_name), flush=True)
240
+ fp = open(fp_file_name, "rb")
241
+ if count_fp == 0:
242
+ # assign directory for the first pkl
243
+ lshcache = pickle.load(fp)
244
+ url_doc = pickle.load(fp)
245
+ else:
246
+ # append these to lshcache and url_doc
247
+ local_lshcache = pickle.load(fp)
248
+ local_url_doc = pickle.load(fp)
249
+ for url in local_lshcache.fingerprints.keys():
250
+ url_doc[url] = local_url_doc[url]
251
+ lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
252
+ fp.close()
253
+
254
+ counter = 0
255
+ start_time = time.time()
256
+
257
+ # compute finger prints of the inputs if any
258
+ # input file and the key to use as id
259
+ if args.inputs is not None:
260
+ print("Computing fingerprints", flush=True)
261
+ assert len(args.inputs) % 2 == 0
262
+ for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
263
+ print(' document processing {} with key {}'.format(input_file, key),
264
+ flush=True)
265
+
266
+ # compute fingerprints in parallel
267
+ num_workers = 40
268
+ pool = multiprocessing.Pool(num_workers)
269
+ fin = open(input_file, 'r', encoding='utf-8')
270
+ compute_fingerprint_partial = partial(compute_fingerprint, key=key)
271
+ compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
272
+ fin, 512)
273
+ # traverse all the texts and add fingerprints
274
+ for url, text, fingerprint, flag in compute_fingerprint_iter:
275
+ counter += 1
276
+ if flag:
277
+ url_doc[url] = text
278
+ lshcache.add_fingerprint(fingerprint, url)
279
+ if counter % 10000 == 0:
280
+ print(' [read]> processed {} documents in {:.2f} '
281
+ 'seconds ...'.format(counter, time.time() - \
282
+ start_time), flush=True)
283
+
284
+ fin.close()
285
+ pool.close()
286
+ pool.join()
287
+
288
+ # Save the fingerprints if needed
289
+ if args.save_fingerprints is not None:
290
+ print("Saving fingerprints to pickle file {}".format(
291
+ args.save_fingerprints), flush=True)
292
+ with open(args.save_fingerprints, 'wb') as f_save:
293
+ pickle.dump(lshcache, f_save)
294
+ pickle.dump(url_doc, f_save)
295
+
296
+ # compute jaccard index of the input texts and write to file if needed
297
+ if args.output is not None:
298
+ print("Compute jaccard similarity", flush=True)
299
+ if args.jaccard_parallel:
300
+ find_pair_urls_parallel(args, lshcache, url_doc)
301
+ else:
302
+ find_pair_urls_sequential(args, lshcache, url_doc)
303
+
304
+ print('done :-)')
305
+
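A self-contained sketch of the character-shingle Jaccard similarity defined above, on two made-up near-duplicate sentences; the script treats pairs scoring above 0.5 as duplicates.

def shingles(text, char_ngram=5):
    # same 5-character shingling as above
    return set(text[head:head + char_ngram]
               for head in range(0, len(text) - char_ngram))

a = shingles('the quick brown fox jumps over the lazy dog')
b = shingles('the quick brown fox leaps over the lazy dog')
similarity = len(a & b) / len(a | b)   # the default 'union' jaccard above
print('{:.2f}'.format(similarity))     # well above the 0.5 duplicate threshold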
tools/openwebtext/group_duplicate_url.py ADDED
@@ -0,0 +1,90 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import time
18
+ import sys
19
+
20
+
21
+ if __name__ == '__main__':
22
+
23
+
24
+ print('grouping duplicate urls ...')
25
+
26
+ input = sys.argv[1]
27
+ output = sys.argv[2]
28
+ if len(sys.argv) > 3:
29
+ jaccard_similarity_threshold = float(sys.argv[3])
30
+ else:
31
+ jaccard_similarity_threshold = 0.7
32
+
33
+ url_to_index = {}
34
+ index_to_urls = []
35
+ counter = 0
36
+ start_time = time.time()
37
+ with open(input, 'r') as f:
38
+ for line in f:
39
+ counter += 1
40
+ myjson = json.loads(line)
41
+ urls = []
42
+ for main_url in myjson.keys():
43
+ urls.append(main_url)
44
+ for value in myjson[main_url]:
45
+ for other_url, js in value.items():
46
+ if js >= jaccard_similarity_threshold:
47
+ urls.append(other_url)
48
+ current_index = -1
49
+ other_indices = set()
50
+ for url in urls:
51
+ if url in url_to_index:
52
+ if current_index == -1:
53
+ current_index = url_to_index[url]
54
+ elif current_index != url_to_index[url]:
55
+ other_indices.add(url_to_index[url])
56
+ if current_index == -1:
57
+ current_index = len(index_to_urls)
58
+ index_to_urls.append(set())
59
+ for url in urls:
60
+ url_to_index[url] = current_index
61
+ index_to_urls[current_index].add(url)
62
+ for index in other_indices:
63
+ for url in index_to_urls[index]:
64
+ index_to_urls[current_index].add(url)
65
+ url_to_index[url] = current_index
66
+ index_to_urls[index] = None
67
+
68
+ if counter % 100000 == 0:
69
+ print(' > processed {} lines in {} seconds ...'.format(
70
+ counter, time.time() - start_time))
71
+
72
+
73
+ total_remove = 0
74
+ total_remain = 0
75
+ for urls in index_to_urls:
76
+ if urls is not None:
77
+ if len(urls) > 1:
78
+ total_remove += (len(urls) - 1)
79
+ total_remain += 1
80
+ print('out of {} urls, only {} are unique and {} should be removed'.format(
81
+ total_remove+total_remain, total_remain, total_remove))
82
+
83
+ with open(output, 'wb') as f:
84
+ for i, urls in enumerate(index_to_urls):
85
+ if urls is not None:
86
+ if len(urls) > 1:
87
+ myjson = json.dumps({str(i): list(urls)},
88
+ ensure_ascii=False)
89
+ f.write(myjson.encode('utf-8'))
90
+ f.write('\n'.encode('utf-8'))
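For reference, a hypothetical example of the line formats involved: each input line (produced by find_duplicates.py) maps a main url to its similar urls with their Jaccard scores, and the grouping above emits one json object per group of duplicate urls. All urls below are made up.

import json

# one input line: {main_url: [{other_url: jaccard_similarity}, ...]}
input_line = json.dumps({'http://a.example/1': [{'http://b.example/2': 0.84},
                                                {'http://c.example/3': 0.71}]})

# one output line: {group_index: [all urls in the duplicate group]}
output_line = json.dumps({'0': ['http://a.example/1',
                                'http://b.example/2',
                                'http://c.example/3']})
print(input_line)
print(output_line)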
tools/openwebtext/merge_jsons.py ADDED
@@ -0,0 +1,55 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import glob
18
+ import sys
19
+ import json
20
+ import argparse
21
+
22
+ if __name__ == '__main__':
23
+
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--json_path", type=str, default=".",
26
+ help="path where all the json files are located")
27
+
28
+ parser.add_argument("--output_file", type=str, default="merged_output.json",
29
+ help="filename where the merged json should go")
30
+
31
+ args = parser.parse_args()
32
+
33
+ json_path = args.json_path
34
+ out_file = args.output_file
35
+
36
+ json_files = glob.glob(json_path + '/*.json')
37
+
38
+ counter = 0
39
+
40
+ with open(out_file, 'w') as outfile:
41
+ for fname in json_files:
42
+ counter += 1
43
+
44
+ if counter % 1024 == 0:
45
+ print("Merging at ", counter, flush=True)
46
+
47
+ with open(fname, 'r') as infile:
48
+ for row in infile:
49
+ each_row = json.loads(row)
50
+ outfile.write(row)
51
+
52
+
53
+ print("Merged file", out_file, flush=True)
54
+
55
+
tools/openwebtext/remove_group_duplicates.py ADDED
@@ -0,0 +1,69 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import json
18
+ import time
19
+ import sys
20
+
21
+
22
+ if __name__ == '__main__':
23
+
24
+ url_filename = sys.argv[1]
25
+ data_filename = sys.argv[2]
26
+ output_filename = sys.argv[3]
27
+
28
+ urls = set()
29
+ with open(url_filename, 'r') as f:
30
+ for line in f:
31
+ myjson = json.loads(line)
32
+ for key in myjson:
33
+ this_urls = myjson[key]
34
+ for i in range(1, len(this_urls)):
35
+ urls.add(this_urls[i])
36
+ print('will be removing {} urls'.format(len(urls)), flush=True)
37
+
38
+ written_docs = 0
39
+ removed_docs = 0
40
+ removed_chars = 0
41
+ start_time = time.time()
42
+ with open(output_filename, 'wb') as fout:
43
+ with open(data_filename, 'r') as fin:
44
+ for line in fin:
45
+ try:
46
+ myjson = json.loads(line)
47
+ url = myjson['url']
48
+ if url in urls:
49
+ print('removing', myjson)
50
+ removed_docs += 1
51
+ removed_chars += len(myjson['text'])
52
+ continue
53
+ myjson = json.dumps(myjson, ensure_ascii=False)
54
+ fout.write(myjson.encode('utf-8'))
55
+ fout.write('\n'.encode('utf-8'))
56
+ written_docs += 1
57
+ if written_docs % 10000 == 0:
58
+ print(' [PROCESSED] time (s): {:.2f} | written: {} '
59
+ '| removed: {} (char: {})'.format(
60
+ time.time() - start_time,
61
+ written_docs, removed_docs, removed_chars))
62
+ except Exception as e:
63
+ print('[SKIPPING]', line, e)
64
+
65
+ print(' [PROCESSED] time (s): {:.2f} | written: {} '
66
+ '| removed: {} (char: {})'.format(
67
+ time.time() - start_time,
68
+ written_docs, removed_docs, removed_chars))
69
+ print('done :-)')
tools/preprocess_data.py ADDED
@@ -0,0 +1,205 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Processing data for pretraining."""
17
+
18
+ import argparse
19
+ import json
20
+ import multiprocessing
21
+ import os
22
+ import sys
23
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
24
+ os.path.pardir)))
25
+ import time
26
+
27
+ import torch
28
+ try:
29
+ import nltk
30
+ nltk_available = True
31
+ except ImportError:
32
+ nltk_available = False
33
+
34
+ from megatron.tokenizer import build_tokenizer
35
+ from megatron.data import indexed_dataset
36
+
37
+
38
+ # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
39
+ class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
40
+
41
+ _period_context_fmt = r"""
42
+ \S* # some word material
43
+ %(SentEndChars)s # a potential sentence ending
44
+ \s* # <-- THIS is what I changed
45
+ (?=(?P<after_tok>
46
+ %(NonWord)s # either other punctuation
47
+ |
48
+ (?P<next_tok>\S+) # <-- Normally you would have \s+ here
49
+ ))"""
50
+
51
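+ # No-op splitter used when --split-sentences is not given: each document is kept as a single piece of text.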
+ class IdentitySplitter(object):
52
+ def tokenize(self, *text):
53
+ return text
54
+
55
+ class Encoder(object):
56
+ def __init__(self, args):
57
+ self.args = args
58
+
59
+ def initializer(self):
60
+ # Use Encoder class as a container for global data
61
+ Encoder.tokenizer = build_tokenizer(self.args)
62
+ if self.args.split_sentences:
63
+ if not nltk_available:
64
+ print("NLTK is not available to split sentences.")
65
+ exit()
66
+ splitter = nltk.load("tokenizers/punkt/english.pickle")
67
+ if self.args.keep_newlines:
68
+ # this prevents punkt from eating newlines after sentences
69
+ Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
70
+ train_text = splitter._params,
71
+ lang_vars = CustomLanguageVars())
72
+ else:
73
+ Encoder.splitter = splitter
74
+
75
+ else:
76
+ Encoder.splitter = IdentitySplitter()
77
+
78
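+ # Tokenize one json line: optionally split it into sentences, then convert each sentence to token ids.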
+ def encode(self, json_line):
79
+ data = json.loads(json_line)
80
+ ids = {}
81
+ for key in self.args.json_keys:
82
+ text = data[key]
83
+ doc_ids = []
84
+ for sentence in Encoder.splitter.tokenize(text):
85
+ sentence_ids = Encoder.tokenizer.tokenize(sentence)
86
+ if len(sentence_ids) > 0:
87
+ doc_ids.append(sentence_ids)
88
+ if len(doc_ids) > 0 and self.args.append_eod:
89
+ doc_ids[-1].append(Encoder.tokenizer.eod)
90
+ ids[key] = doc_ids
91
+ return ids, len(json_line)
92
+
93
+ def get_args():
94
+ parser = argparse.ArgumentParser()
95
+ group = parser.add_argument_group(title='input data')
96
+ group.add_argument('--input', type=str, required=True,
97
+ help='Path to input JSON')
98
+ group.add_argument('--json-keys', nargs='+', default=['text'],
99
+ help='space separated list of keys to extract from json')
100
+ group.add_argument('--split-sentences', action='store_true',
101
+ help='Split documents into sentences.')
102
+ group.add_argument('--keep-newlines', action='store_true',
103
+ help='Keep newlines between sentences when splitting.')
104
+
105
+ group = parser.add_argument_group(title='tokenizer')
106
+ group.add_argument('--tokenizer-type', type=str, required=True,
107
+ choices=['BertWordPieceLowerCase','BertWordPieceCase',
108
+ 'GPT2BPETokenizer'],
109
+ help='What type of tokenizer to use.')
110
+ group.add_argument('--vocab-file', type=str, default=None,
111
+ help='Path to the vocab file')
112
+ group.add_argument('--merge-file', type=str, default=None,
113
+ help='Path to the BPE merge file (if necessary).')
114
+ group.add_argument('--append-eod', action='store_true',
115
+ help='Append an <eod> token to the end of a document.')
116
+
117
+
118
+ group = parser.add_argument_group(title='output data')
119
+ group.add_argument('--output-prefix', type=str, required=True,
120
+ help='Path to binary output file without suffix')
121
+ group.add_argument('--dataset-impl', type=str, default='mmap',
122
+ choices=['lazy', 'cached', 'mmap'])
123
+
124
+ group = parser.add_argument_group(title='runtime')
125
+ group.add_argument('--workers', type=int, required=True,
126
+ help='Number of worker processes to launch')
127
+ group.add_argument('--chunk-size', type=int, required=True,
128
+ help='Chunk size assigned to each worker process')
129
+ group.add_argument('--log-interval', type=int, default=100,
130
+ help='Interval between progress updates')
131
+ args = parser.parse_args()
132
+ args.keep_empty = False
133
+
134
+ if args.tokenizer_type.lower().startswith('bert'):
135
+ if not args.split_sentences:
136
+ print("Bert tokenizer detected, are you sure you don't want to split sentences?")
137
+
138
+ # some default/dummy values for the tokenizer
139
+ args.rank = 0
140
+ args.make_vocab_size_divisible_by = 128
141
+ args.tensor_model_parallel_size = 1
142
+ args.vocab_extra_ids = 0
143
+
144
+ return args
145
+
146
+ def main():
147
+ args = get_args()
148
+ startup_start = time.time()
149
+
150
+ print("Opening", args.input)
151
+ fin = open(args.input, 'r', encoding='utf-8')
152
+
153
+ if nltk_available and args.split_sentences:
154
+ nltk.download("punkt", quiet=True)
155
+
156
+ encoder = Encoder(args)
157
+ tokenizer = build_tokenizer(args)
158
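+ # Each worker builds its own tokenizer/splitter in initializer(); imap streams input lines to the workers in chunks of --chunk-size.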
+ pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
159
+ encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
160
+ #encoded_docs = map(encoder.encode, fin)
161
+
162
+ level = "document"
163
+ if args.split_sentences:
164
+ level = "sentence"
165
+
166
+ print(f"Vocab size: {tokenizer.vocab_size}")
167
+ print(f"Output prefix: {args.output_prefix}")
168
+ output_bin_files = {}
169
+ output_idx_files = {}
170
+ builders = {}
171
+ for key in args.json_keys:
172
+ output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
173
+ key, level)
174
+ output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
175
+ key, level)
176
+ builders[key] = indexed_dataset.make_builder(output_bin_files[key],
177
+ impl=args.dataset_impl,
178
+ vocab_size=tokenizer.vocab_size)
179
+
180
+ startup_end = time.time()
181
+ proc_start = time.time()
182
+ total_bytes_processed = 0
183
+ print("Time to startup:", startup_end - startup_start)
184
+
185
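+ # Drain encoded documents from the worker pool and append them to the indexed-dataset builders (one builder per json key).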
+ for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
186
+ total_bytes_processed += bytes_processed
187
+ for key, sentences in doc.items():
188
+ if len(sentences) == 0:
189
+ continue
190
+ for sentence in sentences:
191
+ builders[key].add_item(torch.IntTensor(sentence))
192
+ builders[key].end_document()
193
+ if i % args.log_interval == 0:
194
+ current = time.time()
195
+ elapsed = current - proc_start
196
+ mbs = total_bytes_processed/elapsed/1024/1024
197
+ print(f"Processed {i} documents",
198
+ f"({i/elapsed} docs/s, {mbs} MB/s).",
199
+ file=sys.stderr)
200
+
201
+ for key in args.json_keys:
202
+ builders[key].finalize(output_idx_files[key])
203
+
204
+ if __name__ == '__main__':
205
+ main()
tools/run_build_data.sh ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/bash
2
+ #INPUT="roberta_train_data_raw/valid.json"
3
+ INPUT="/mnt/nvme0/ouyangxuan/project_pretrain/make_pretrain_data/roberta_train_data_raw/valid.json"
4
+ python preprocess_data.py \
5
+ --input ${INPUT} \
6
+ --output-prefix my-bert \
7
+ --vocab-file bert-vocab.txt \
8
+ --dataset-impl mmap \
9
+ --workers 1 \
10
+ --chunk-size 1 \
11
+ --tokenizer-type BertWordPieceLowerCase \
12
+ --split-sentences
13
+
14
+
15
+ #--input /mnt/nvme1/ouyangxuan/project_pretrain/find_framework/tmp_data/data.json \
16
+ #--input roberta_train_data_raw/train_1g.json \
tools/run_text_generation_server.py ADDED
@@ -0,0 +1,90 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Sample Generate GPT"""
17
+ import os
18
+ import sys
19
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
20
+ os.path.pardir)))
21
+ import socket
22
+ from megatron import get_args
23
+ from megatron import print_rank_0
24
+ from megatron import mpu
25
+ from megatron.checkpointing import load_checkpoint
26
+ from megatron.initialize import initialize_megatron
27
+ from megatron.model import GPTModel
28
+ from megatron.training import get_model
29
+ from megatron.text_generation_server import MegatronServer
30
+ from megatron.text_generation import generate_and_post_process
31
+ from megatron.text_generation import beam_search_and_post_process
32
+ import torch
33
+
34
+ def model_provider(pre_process=True, post_process=True):
35
+ """Build the model."""
36
+
37
+ print_rank_0('building GPT model ...')
38
+ model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
39
+
40
+ return model
41
+
42
+ def add_text_generate_args(parser):
43
+ group = parser.add_argument_group(title='text generation')
44
+
45
+ group.add_argument("--temperature", type=float, default=1.0,
46
+ help='Sampling temperature.')
47
+ group.add_argument("--top_p", type=float, default=0.0,
48
+ help='Top p sampling.')
49
+ group.add_argument("--top_k", type=int, default=0,
50
+ help='Top k sampling.')
51
+ group.add_argument("--out-seq-length", type=int, default=1024,
52
+ help='Size of the output generated text.')
53
+ return parser
54
+
55
+
56
+ if __name__ == "__main__":
57
+ initialize_megatron(extra_args_provider=add_text_generate_args,
58
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
59
+ 'no_load_rng': True,
60
+ 'no_load_optim': True})
61
+
62
+ args = get_args()
63
+ if args.num_layers_per_virtual_pipeline_stage is not None:
64
+ print("Interleaved pipeline schedule is not yet supported for text generation.")
65
+ exit()
66
+ # Set up model and load checkpoint
67
+ model = get_model(model_provider, wrap_with_ddp=False)
68
+
69
+ if args.load is not None:
70
+ _ = load_checkpoint(model, None, None)
71
+
72
+ assert len(model) == 1, "Above condition should have caught this"
73
+ model = model[0]
74
+ if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
75
+ server = MegatronServer(model)
76
+ server.run("0.0.0.0")
77
+
78
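+ # Ranks that do not host the HTTP server wait here for the coordinator to broadcast which path to run (0 = sampling, 1 = beam search) and join the collective generation call.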
+ while True:
79
+ choice = torch.cuda.LongTensor(1)
80
+ torch.distributed.broadcast(choice, 0)
81
+ if choice[0].item() == 0:
82
+ try:
83
+ generate_and_post_process(model)
84
+ except ValueError as ve:
85
+ pass
86
+ elif choice[0].item() == 1:
87
+ try:
88
+ beam_search_and_post_process(model)
89
+ except ValueError as ve:
90
+ pass
tools/text_generation_cli.py ADDED
@@ -0,0 +1,34 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ import sys
17
+ import urllib.request
18
+ class PutRequest(urllib.request.Request):
19
+ '''Request subclass for issuing PUT requests with urllib'''
20
+
21
+ def get_method(self, *args, **kwargs):
22
+ return 'PUT'
23
+
24
+ if __name__ == "__main__":
25
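+ # The single CLI argument is the URL of the running MegatronServer endpoint.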
+ url = sys.argv[1]
26
+ while True:
27
+ sentence = input("Enter prompt: ")
28
+ tokens_to_generate = int(input("Enter number of tokens to generate: "))
29
+ data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
30
+ req = PutRequest(url, data.encode('utf-8'), {'Content-Type': 'application/json'})
31
+ response = urllib.request.urlopen(req)
32
+ resp_sentences = json.load(response)
33
+ print("Megatron Response: ")
34
+ print(resp_sentences["text"][0])