Upload 21 files
Browse files- tools/bert-vocab.txt +0 -0
- tools/checkpoint_loader_megatron.py +273 -0
- tools/checkpoint_saver_megatron.py +322 -0
- tools/checkpoint_util.py +151 -0
- tools/linter.py +36 -0
- tools/merge_datasets.py +66 -0
- tools/merge_mp_partitions.py +352 -0
- tools/openwebtext/README.md +59 -0
- tools/openwebtext/add_id.py +67 -0
- tools/openwebtext/blacklist_urls.py +312 -0
- tools/openwebtext/cleanup_dataset.py +115 -0
- tools/openwebtext/cleanup_fix_dataset.py +191 -0
- tools/openwebtext/filter_ngrams.py +492 -0
- tools/openwebtext/find_duplicates.py +305 -0
- tools/openwebtext/group_duplicate_url.py +90 -0
- tools/openwebtext/merge_jsons.py +55 -0
- tools/openwebtext/remove_group_duplicates.py +69 -0
- tools/preprocess_data.py +205 -0
- tools/run_build_data.sh +16 -0
- tools/run_text_generation_server.py +90 -0
- tools/text_generation_cli.py +34 -0
tools/bert-vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tools/checkpoint_loader_megatron.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import types
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
def add_arguments(parser):
|
9 |
+
group = parser.add_argument_group(title='Megatron loader')
|
10 |
+
|
11 |
+
group.add_argument('--true-vocab-size', type=int, default=None,
|
12 |
+
help='original size of vocab, if specified will trim padding from embedding table.')
|
13 |
+
group.add_argument('--vocab-file', type=str, default=None,
|
14 |
+
help='Path to the vocab file. If specified will use this to get vocab size and '
|
15 |
+
'trim padding from the embedding table.')
|
16 |
+
group.add_argument('--megatron-path', type=str, default=None,
|
17 |
+
help='Base directory of deepspeed repository')
|
18 |
+
|
19 |
+
def _load_checkpoint(queue, args):
|
20 |
+
|
21 |
+
# Search in directory above this
|
22 |
+
sys.path.append(os.path.abspath(
|
23 |
+
os.path.join(os.path.dirname(__file__),
|
24 |
+
os.path.pardir)))
|
25 |
+
if args.megatron_path is not None:
|
26 |
+
sys.path.insert(0, args.megatron_path)
|
27 |
+
|
28 |
+
try:
|
29 |
+
from megatron.arguments import parse_args, validate_args
|
30 |
+
from megatron.global_vars import set_args, set_global_variables
|
31 |
+
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
|
32 |
+
from megatron.model import ModelType, module
|
33 |
+
from megatron import mpu, fused_kernels
|
34 |
+
except ModuleNotFoundError:
|
35 |
+
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
|
36 |
+
queue.put("exit")
|
37 |
+
exit(1)
|
38 |
+
|
39 |
+
# We want all arguments to come from us
|
40 |
+
sys.argv = ['script.py',
|
41 |
+
'--no-masked-softmax-fusion',
|
42 |
+
'--no-bias-gelu-fusion',
|
43 |
+
'--no-bias-dropout-fusion',
|
44 |
+
'--use-cpu-initialization',
|
45 |
+
'--micro-batch-size', '1',
|
46 |
+
'--no-load-optim',
|
47 |
+
'--no-load-rng',
|
48 |
+
'--no-save-optim',
|
49 |
+
'--no-save-rng',
|
50 |
+
'--no-initialization',
|
51 |
+
'--load', args.load_dir
|
52 |
+
]
|
53 |
+
|
54 |
+
margs = parse_args()
|
55 |
+
margs = load_args_from_checkpoint(margs)
|
56 |
+
|
57 |
+
# Arguments do sanity checks on the world size, but we don't care,
|
58 |
+
# so trick it into thinking we are plenty of processes
|
59 |
+
margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
|
60 |
+
|
61 |
+
margs = validate_args(margs)
|
62 |
+
|
63 |
+
def check_for_arg(arg_name):
|
64 |
+
if getattr(margs, arg_name, None) is None:
|
65 |
+
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
|
66 |
+
print(f"Arguments: {margs}")
|
67 |
+
queue.put("exit")
|
68 |
+
exit(1)
|
69 |
+
|
70 |
+
check_for_arg('tensor_model_parallel_size')
|
71 |
+
check_for_arg('pipeline_model_parallel_size')
|
72 |
+
check_for_arg('num_layers')
|
73 |
+
check_for_arg('hidden_size')
|
74 |
+
check_for_arg('seq_length')
|
75 |
+
check_for_arg('num_attention_heads')
|
76 |
+
check_for_arg('max_position_embeddings')
|
77 |
+
check_for_arg('tokenizer_type')
|
78 |
+
check_for_arg('iteration')
|
79 |
+
check_for_arg('bert_binary_head')
|
80 |
+
check_for_arg('params_dtype')
|
81 |
+
|
82 |
+
# Determine how to make our models
|
83 |
+
if args.model_type == 'GPT':
|
84 |
+
from pretrain_gpt import model_provider
|
85 |
+
margs.model_type = ModelType.encoder_or_decoder
|
86 |
+
elif args.model_type == 'BERT':
|
87 |
+
from pretrain_bert import model_provider
|
88 |
+
margs.model_type = ModelType.encoder_or_decoder
|
89 |
+
else:
|
90 |
+
raise Exception(f'unrecognized model type: {args.model_type}')
|
91 |
+
|
92 |
+
# supress warning about torch.distributed not being initialized
|
93 |
+
module.MegatronModule.embedding_warning_printed = True
|
94 |
+
|
95 |
+
consumed_train_samples = None
|
96 |
+
consumed_valid_samples = None
|
97 |
+
def get_models(count, dtype, pre_process, post_process):
|
98 |
+
nonlocal consumed_train_samples
|
99 |
+
nonlocal consumed_valid_samples
|
100 |
+
models = []
|
101 |
+
for rank in range(count):
|
102 |
+
mpu.initialize.set_tensor_model_parallel_rank(rank)
|
103 |
+
model_ = [model_provider(pre_process, post_process).to(dtype)]
|
104 |
+
margs.consumed_train_samples = 0
|
105 |
+
margs.consumed_valid_samples = 0
|
106 |
+
load_checkpoint(model_, None, None)
|
107 |
+
assert(len(model_) == 1)
|
108 |
+
model_ = model_[0]
|
109 |
+
if consumed_train_samples is not None:
|
110 |
+
assert(margs.consumed_train_samples == consumed_train_samples)
|
111 |
+
else:
|
112 |
+
consumed_train_samples = margs.consumed_train_samples
|
113 |
+
if consumed_valid_samples is not None:
|
114 |
+
assert(margs.consumed_valid_samples == consumed_valid_samples)
|
115 |
+
else:
|
116 |
+
consumed_valid_samples = margs.consumed_valid_samples
|
117 |
+
models.append(model_)
|
118 |
+
return models
|
119 |
+
|
120 |
+
if margs.num_layers_per_virtual_pipeline_stage is not None:
|
121 |
+
print("Model with an interleaved pipeline schedule are not yet supported.")
|
122 |
+
queue.put("exit")
|
123 |
+
exit(1)
|
124 |
+
|
125 |
+
set_global_variables(margs)
|
126 |
+
mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
|
127 |
+
mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
|
128 |
+
fused_kernels.load(margs)
|
129 |
+
|
130 |
+
# Get true (non-padded) vocab size
|
131 |
+
if args.true_vocab_size is not None:
|
132 |
+
true_vocab_size = args.true_vocab_size
|
133 |
+
elif args.vocab_file is not None:
|
134 |
+
vocab = json.load(open(args.vocab_file))
|
135 |
+
true_vocab_size = len(vocab)
|
136 |
+
if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
|
137 |
+
print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
|
138 |
+
queue.put("exit")
|
139 |
+
exit(1)
|
140 |
+
else:
|
141 |
+
true_vocab_size = None
|
142 |
+
|
143 |
+
# short aliases
|
144 |
+
tp_size = margs.tensor_model_parallel_size
|
145 |
+
pp_size = margs.pipeline_model_parallel_size
|
146 |
+
|
147 |
+
# metadata
|
148 |
+
md = types.SimpleNamespace()
|
149 |
+
md.model_type = args.model_type
|
150 |
+
md.num_layers = margs.num_layers
|
151 |
+
md.hidden_size = margs.hidden_size
|
152 |
+
md.seq_length = margs.seq_length
|
153 |
+
md.num_attention_heads = margs.num_attention_heads
|
154 |
+
md.max_position_embeddings = margs.max_position_embeddings
|
155 |
+
md.tokenizer_type = margs.tokenizer_type
|
156 |
+
md.iteration = margs.iteration
|
157 |
+
md.params_dtype = margs.params_dtype
|
158 |
+
md.bert_binary_head = margs.bert_binary_head
|
159 |
+
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
|
160 |
+
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
|
161 |
+
md.true_vocab_size = true_vocab_size
|
162 |
+
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
|
163 |
+
|
164 |
+
# Get first pipe stage
|
165 |
+
mpu.initialize.set_pipeline_model_parallel_rank(0)
|
166 |
+
post_process = pp_size == 1
|
167 |
+
models = get_models(tp_size, md.params_dtype, True, post_process)
|
168 |
+
|
169 |
+
md.consumed_train_samples = consumed_train_samples
|
170 |
+
md.consumed_valid_samples = consumed_valid_samples
|
171 |
+
queue.put(md)
|
172 |
+
|
173 |
+
def queue_put(name, msg):
|
174 |
+
print(f"sending {name}")
|
175 |
+
msg["name"] = name
|
176 |
+
queue.put(msg)
|
177 |
+
|
178 |
+
# Send embeddings
|
179 |
+
message = {
|
180 |
+
"position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
|
181 |
+
"word embeddings": torch.cat(
|
182 |
+
[models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
|
183 |
+
dim = 0)
|
184 |
+
}
|
185 |
+
|
186 |
+
queue_put("embeddings", message)
|
187 |
+
|
188 |
+
total_layer_num = 0
|
189 |
+
for pp_rank in range(pp_size):
|
190 |
+
if pp_rank > 0:
|
191 |
+
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
|
192 |
+
post_process = pp_rank == pp_size - 1
|
193 |
+
models = get_models(tp_size, md.params_dtype, False, post_process)
|
194 |
+
for layer_num in range(len(models[0].language_model.encoder.layers)):
|
195 |
+
message = {}
|
196 |
+
|
197 |
+
# Get non-parallel tensors from tp_rank 0
|
198 |
+
layer = models[0].language_model.encoder.layers[layer_num]
|
199 |
+
message["input layernorm weight"] = layer.input_layernorm.weight.data
|
200 |
+
message["input layernorm bias"] = layer.input_layernorm.bias.data
|
201 |
+
message["dense bias"] = layer.self_attention.dense.bias.data
|
202 |
+
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
|
203 |
+
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
|
204 |
+
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
|
205 |
+
|
206 |
+
# Grab all parallel tensors for this layer
|
207 |
+
qkv_weight = []
|
208 |
+
qkv_bias = []
|
209 |
+
dense_weight = []
|
210 |
+
mlp_l0_weight = []
|
211 |
+
mlp_l0_bias = []
|
212 |
+
mlp_l1_weight = []
|
213 |
+
for tp_rank, model in enumerate(models):
|
214 |
+
layer = model.language_model.encoder.layers[layer_num]
|
215 |
+
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
|
216 |
+
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
|
217 |
+
dense_weight.append(layer.self_attention.dense.weight.data)
|
218 |
+
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
|
219 |
+
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
|
220 |
+
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
|
221 |
+
|
222 |
+
# concat them
|
223 |
+
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
|
224 |
+
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
|
225 |
+
message["dense weight"] = torch.cat(dense_weight, dim=1)
|
226 |
+
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
|
227 |
+
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
|
228 |
+
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
|
229 |
+
|
230 |
+
queue_put(f"transformer layer {total_layer_num}", message)
|
231 |
+
|
232 |
+
total_layer_num = total_layer_num + 1
|
233 |
+
|
234 |
+
# Send final layernorm from tp_rank 0
|
235 |
+
message = {
|
236 |
+
"weight": models[0].language_model.encoder.final_layernorm.weight.data,
|
237 |
+
"bias": models[0].language_model.encoder.final_layernorm.bias.data
|
238 |
+
}
|
239 |
+
queue_put("final layernorm", message)
|
240 |
+
|
241 |
+
# Send BERT lm head and binary head if it exists
|
242 |
+
if md.model_type == 'BERT':
|
243 |
+
print("Sending LM Pooler")
|
244 |
+
message = {
|
245 |
+
"weight": models[0].language_model.pooler.dense.weight.data,
|
246 |
+
"bias": models[0].language_model.pooler.dense.bias.data
|
247 |
+
}
|
248 |
+
queue_put("pooler", message)
|
249 |
+
|
250 |
+
message = {
|
251 |
+
"dense weight": models[0].lm_head.dense.weight.data,
|
252 |
+
"dense bias": models[0].lm_head.dense.bias.data,
|
253 |
+
"layernorm weight": models[0].lm_head.layernorm.weight.data,
|
254 |
+
"layernorm bias": models[0].lm_head.layernorm.bias.data
|
255 |
+
}
|
256 |
+
queue_put("lm head", message)
|
257 |
+
|
258 |
+
if md.bert_binary_head:
|
259 |
+
print("Sending BERT Binary head")
|
260 |
+
queue.put("binary head")
|
261 |
+
message = {
|
262 |
+
"weight": models[0].binary_head.weight.data,
|
263 |
+
"bias": models[0].binary_head.bias.data
|
264 |
+
}
|
265 |
+
queue_put("binary head", message)
|
266 |
+
queue.put("done")
|
267 |
+
|
268 |
+
def load_checkpoint(queue, args):
|
269 |
+
try:
|
270 |
+
_load_checkpoint(queue, args)
|
271 |
+
except:
|
272 |
+
queue.put("exit")
|
273 |
+
raise
|
tools/checkpoint_saver_megatron.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from collections.abc import Mapping
|
3 |
+
import concurrent.futures
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
def add_arguments(parser):
|
10 |
+
group = parser.add_argument_group(title='Megatron saver')
|
11 |
+
|
12 |
+
group.add_argument('--megatron-path', type=str, default=None,
|
13 |
+
help='Base directory of Megatron repository')
|
14 |
+
|
15 |
+
group.add_argument('--target-tensor-parallel-size', type=int,
|
16 |
+
help='Target tensor model parallel size, defaults to the tensor parallel size '
|
17 |
+
'in the input checkpoint if provided by the loader, otherwise to 1')
|
18 |
+
group.add_argument('--target-pipeline-parallel-size', type=int,
|
19 |
+
help='Target tensor model parallel size, default to the pipeline parall size '
|
20 |
+
'in the input checkpoint if provided by the loader, otherwise to 1')
|
21 |
+
|
22 |
+
def save_checkpoint(queue, args):
|
23 |
+
|
24 |
+
# Search in directory above this
|
25 |
+
sys.path.append(os.path.abspath(
|
26 |
+
os.path.join(os.path.dirname(__file__),
|
27 |
+
os.path.pardir)))
|
28 |
+
if args.megatron_path is not None:
|
29 |
+
sys.path.insert(0, args.megatron_path)
|
30 |
+
|
31 |
+
try:
|
32 |
+
from megatron.arguments import (parse_args, validate_args)
|
33 |
+
from megatron.checkpointing import save_checkpoint
|
34 |
+
from megatron.global_vars import set_global_variables, get_args
|
35 |
+
from megatron.model import ModelType
|
36 |
+
from megatron.tokenizer.tokenizer import _vocab_size_with_padding
|
37 |
+
from megatron import mpu, fused_kernels
|
38 |
+
except ModuleNotFoundError:
|
39 |
+
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
|
40 |
+
exit(1)
|
41 |
+
|
42 |
+
def queue_get(name=None):
|
43 |
+
val = queue.get()
|
44 |
+
if val == "exit":
|
45 |
+
print("Loader exited, exiting saver")
|
46 |
+
exit(1)
|
47 |
+
if name is not None and args.checking and val["name"] != name:
|
48 |
+
val_name = val["name"]
|
49 |
+
print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
|
50 |
+
exit(1)
|
51 |
+
if name is not None:
|
52 |
+
print(f"received {name}")
|
53 |
+
return val
|
54 |
+
|
55 |
+
def check_message(msg):
|
56 |
+
if not args.checking:
|
57 |
+
return
|
58 |
+
msg_name = msg.pop("name")
|
59 |
+
if len(msg.keys()) > 0:
|
60 |
+
print(f"Unexpected values in {msg_name}:")
|
61 |
+
for key in msg.keys():
|
62 |
+
print(f" {key}")
|
63 |
+
print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
|
64 |
+
exit(1)
|
65 |
+
|
66 |
+
|
67 |
+
md = queue_get()
|
68 |
+
|
69 |
+
if args.target_tensor_parallel_size is None:
|
70 |
+
if hasattr(md, 'previous_tensor_parallel_size'):
|
71 |
+
args.target_tensor_parallel_size = md.previous_tensor_parallel_size
|
72 |
+
else:
|
73 |
+
print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
|
74 |
+
"Default to 1.")
|
75 |
+
args.target_tensor_parallel_size = 1
|
76 |
+
|
77 |
+
if args.target_pipeline_parallel_size is None:
|
78 |
+
if hasattr(md, 'previous_pipeline_parallel_size'):
|
79 |
+
args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
|
80 |
+
else:
|
81 |
+
print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
|
82 |
+
"Default to 1.")
|
83 |
+
args.target_pipeline_parallel_size = 1
|
84 |
+
|
85 |
+
|
86 |
+
# Arguments do sanity checks on the world size, but we don't care,
|
87 |
+
# so trick it into thinking we are plenty of processes
|
88 |
+
if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
|
89 |
+
os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'
|
90 |
+
|
91 |
+
# We want all arguments to come from us
|
92 |
+
sys.argv = ['script.py',
|
93 |
+
'--num-layers', str(md.num_layers),
|
94 |
+
'--hidden-size', str(md.hidden_size),
|
95 |
+
'--seq-length', str(md.seq_length),
|
96 |
+
'--num-attention-heads', str(md.num_attention_heads),
|
97 |
+
'--max-position-embeddings', str(md.max_position_embeddings),
|
98 |
+
'--tokenizer-type', str(md.tokenizer_type),
|
99 |
+
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
|
100 |
+
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
|
101 |
+
'--no-masked-softmax-fusion',
|
102 |
+
'--no-bias-gelu-fusion',
|
103 |
+
'--no-bias-dropout-fusion',
|
104 |
+
'--use-cpu-initialization',
|
105 |
+
'--micro-batch-size', '1',
|
106 |
+
'--no-load-optim',
|
107 |
+
'--no-load-rng',
|
108 |
+
'--no-save-optim',
|
109 |
+
'--no-save-rng',
|
110 |
+
'--no-initialization',
|
111 |
+
'--save-interval', '1',
|
112 |
+
'--save', args.save_dir
|
113 |
+
]
|
114 |
+
|
115 |
+
if md.make_vocab_size_divisible_by is not None:
|
116 |
+
sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
|
117 |
+
if md.params_dtype == torch.float16:
|
118 |
+
sys.argv.append('--fp16')
|
119 |
+
elif md.params_dtype == torch.bfloat16:
|
120 |
+
sys.argv.append('--bf16')
|
121 |
+
|
122 |
+
if md.model_type == 'BERT' and not md.bert_binary_head:
|
123 |
+
sys.argv.append('--bert-no-binary-head')
|
124 |
+
|
125 |
+
margs = parse_args()
|
126 |
+
validate_args(margs)
|
127 |
+
set_global_variables(margs)
|
128 |
+
|
129 |
+
# margs = megatron args
|
130 |
+
margs = get_args()
|
131 |
+
|
132 |
+
if hasattr(md, 'consumed_train_samples'):
|
133 |
+
margs.consumed_train_samples = md.consumed_train_samples
|
134 |
+
margs.consumed_valid_samples = md.consumed_valid_samples
|
135 |
+
print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
|
136 |
+
f" and consumed_valid_samples to {margs.consumed_valid_samples}")
|
137 |
+
else:
|
138 |
+
print("consumed_train_samples not provided.")
|
139 |
+
|
140 |
+
# Determine how to make our models
|
141 |
+
if md.model_type == 'GPT':
|
142 |
+
from pretrain_gpt import model_provider
|
143 |
+
margs.model_type = ModelType.encoder_or_decoder
|
144 |
+
elif md.model_type == 'BERT':
|
145 |
+
from pretrain_bert import model_provider
|
146 |
+
margs.model_type = ModelType.encoder_or_decoder
|
147 |
+
else:
|
148 |
+
raise Exception(f'unrecognized model type: {args.model_type}')
|
149 |
+
|
150 |
+
def get_models(count, dtype, pre_process, post_process):
|
151 |
+
models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
|
152 |
+
return models
|
153 |
+
|
154 |
+
# fake initializing distributed
|
155 |
+
mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
|
156 |
+
mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
|
157 |
+
mpu.initialize.set_tensor_model_parallel_rank(0)
|
158 |
+
mpu.initialize.set_pipeline_model_parallel_rank(0)
|
159 |
+
fused_kernels.load(margs)
|
160 |
+
|
161 |
+
# Embeddings
|
162 |
+
#-----------
|
163 |
+
embeddings_msg = queue_get("embeddings")
|
164 |
+
|
165 |
+
pos_embed = embeddings_msg.pop("position embeddings")
|
166 |
+
orig_word_embed = embeddings_msg.pop("word embeddings")
|
167 |
+
check_message(embeddings_msg)
|
168 |
+
|
169 |
+
# Deal with padding
|
170 |
+
if md.true_vocab_size is not None:
|
171 |
+
# figure out what our padded vocab size is
|
172 |
+
orig_vocab_size = orig_word_embed.shape[0]
|
173 |
+
margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
|
174 |
+
|
175 |
+
# Cut out extra padding we don't need
|
176 |
+
if orig_vocab_size > margs.padded_vocab_size:
|
177 |
+
full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]
|
178 |
+
|
179 |
+
# Expanding embedding to larger size by replicating final entry
|
180 |
+
elif orig_vocab_size < margs.padded_vocab_size:
|
181 |
+
padding_size = margs.padded_vocab_size - orig_vocab_size
|
182 |
+
|
183 |
+
full_word_embed = torch.cat((
|
184 |
+
orig_word_embed,
|
185 |
+
orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
|
186 |
+
|
187 |
+
# Same size!
|
188 |
+
else:
|
189 |
+
full_word_embed = orig_word_embed
|
190 |
+
else:
|
191 |
+
print("Original vocab size not specified, leaving embedding table as-is. "
|
192 |
+
"If you've changed the tensor parallel size this could cause problems.")
|
193 |
+
margs.padded_vocab_size = orig_word_embed.shape[0]
|
194 |
+
full_word_embed = orig_word_embed
|
195 |
+
|
196 |
+
# Split into new tensor model parallel sizes
|
197 |
+
out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
|
198 |
+
|
199 |
+
# Make models for first pipeline stage and fill in embeddings
|
200 |
+
mpu.initialize.set_pipeline_model_parallel_rank(0)
|
201 |
+
post_process = args.target_pipeline_parallel_size == 1
|
202 |
+
models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
|
203 |
+
for tp_rank, model in enumerate(models):
|
204 |
+
print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
|
205 |
+
model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
|
206 |
+
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
|
207 |
+
|
208 |
+
# Transformer layers
|
209 |
+
#-------------------
|
210 |
+
total_layer_num = 0
|
211 |
+
for pp_rank in range(args.target_pipeline_parallel_size):
|
212 |
+
# For later pipeline parallel ranks, make the new models
|
213 |
+
if pp_rank > 0:
|
214 |
+
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
|
215 |
+
post_process = pp_rank == args.target_pipeline_parallel_size - 1
|
216 |
+
models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
|
217 |
+
|
218 |
+
for layer in range(len(models[0].language_model.encoder.layers)):
|
219 |
+
msg = queue_get(f"transformer layer {total_layer_num}")
|
220 |
+
|
221 |
+
# duplicated tensors
|
222 |
+
input_layernorm_weight = msg.pop("input layernorm weight")
|
223 |
+
input_layernorm_bias = msg.pop("input layernorm bias")
|
224 |
+
dense_bias = msg.pop("dense bias")
|
225 |
+
post_layernorm_weight = msg.pop("post layernorm weight")
|
226 |
+
post_layernorm_bias = msg.pop("post layernorm bias")
|
227 |
+
mlp_l1_bias = msg.pop("mlp l1 bias")
|
228 |
+
|
229 |
+
# Split up the parallel tensors
|
230 |
+
qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
|
231 |
+
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
|
232 |
+
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
|
233 |
+
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
|
234 |
+
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
|
235 |
+
mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
|
236 |
+
|
237 |
+
# Save them to the model
|
238 |
+
for tp_rank in range(args.target_tensor_parallel_size):
|
239 |
+
l = models[tp_rank].language_model.encoder.layers[layer]
|
240 |
+
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
|
241 |
+
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
|
242 |
+
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
|
243 |
+
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
|
244 |
+
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
|
245 |
+
l.self_attention.dense.bias.data.copy_(dense_bias)
|
246 |
+
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
|
247 |
+
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
|
248 |
+
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
|
249 |
+
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
|
250 |
+
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
|
251 |
+
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
|
252 |
+
total_layer_num = total_layer_num + 1
|
253 |
+
check_message(msg)
|
254 |
+
|
255 |
+
|
256 |
+
if post_process:
|
257 |
+
msg = queue_get("final layernorm")
|
258 |
+
final_layernorm_weight = msg.pop("weight")
|
259 |
+
final_layernorm_bias = msg.pop("bias")
|
260 |
+
for tp_rank in range(args.target_tensor_parallel_size):
|
261 |
+
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
|
262 |
+
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
|
263 |
+
if pp_rank != 0:
|
264 |
+
# Copy word embeddings to final pipeline rank
|
265 |
+
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
|
266 |
+
del final_layernorm_weight
|
267 |
+
del final_layernorm_bias
|
268 |
+
check_message(msg)
|
269 |
+
|
270 |
+
msg = queue_get()
|
271 |
+
if msg != "done" and msg["name"] == "pooler":
|
272 |
+
if not hasattr(models[0].language_model, 'pooler'):
|
273 |
+
print("ERROR: got a pooler, but model does not have one")
|
274 |
+
exit(1)
|
275 |
+
print("received pooler")
|
276 |
+
pooler_weight = msg.pop("weight")
|
277 |
+
pooler_bias = msg.pop("bias")
|
278 |
+
for tp_rank in range(args.target_tensor_parallel_size):
|
279 |
+
models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
|
280 |
+
models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
|
281 |
+
del pooler_weight
|
282 |
+
del pooler_bias
|
283 |
+
check_message(msg)
|
284 |
+
msg = queue_get()
|
285 |
+
|
286 |
+
if msg != "done" and msg["name"] == "lm head":
|
287 |
+
if not hasattr(models[0], 'lm_head'):
|
288 |
+
print("ERROR: got an lm head, but model does not have one")
|
289 |
+
exit(1)
|
290 |
+
print("received lm head")
|
291 |
+
lm_head_dense_weight = msg.pop("dense weight")
|
292 |
+
lm_head_dense_bias = msg.pop("dense bias")
|
293 |
+
lm_head_layernorm_weight = msg.pop("layernorm weight")
|
294 |
+
lm_head_layernorm_bias = msg.pop("layernorm bias")
|
295 |
+
for tp_rank in range(args.target_tensor_parallel_size):
|
296 |
+
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
|
297 |
+
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
|
298 |
+
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
|
299 |
+
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
|
300 |
+
check_message(msg)
|
301 |
+
msg = queue_get()
|
302 |
+
|
303 |
+
if msg != "done" and msg["name"] == "binary head":
|
304 |
+
if not hasattr(models[0], 'binary_head'):
|
305 |
+
print("ERROR: got a binary head, but model does not have one")
|
306 |
+
exit(1)
|
307 |
+
print("received binary head")
|
308 |
+
binary_head_weight = msg.pop("weight")
|
309 |
+
binary_head_bias = msg.pop("bias")
|
310 |
+
for tp_rank in range(args.target_tensor_parallel_size):
|
311 |
+
models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
|
312 |
+
models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
|
313 |
+
check_message(msg)
|
314 |
+
msg = queue_get()
|
315 |
+
|
316 |
+
if msg != "done":
|
317 |
+
print("ERROR: got some more data but was expecting to be done")
|
318 |
+
|
319 |
+
for tp_rank in range(args.target_tensor_parallel_size):
|
320 |
+
mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
|
321 |
+
save_checkpoint(md.iteration, [models[tp_rank]], None, None)
|
322 |
+
print("Done!")
|
tools/checkpoint_util.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import importlib
|
3 |
+
import torch.multiprocessing as mp
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
# A loader is a python file with at least two functions
|
8 |
+
# - add_arguments - takes in a parser and adds any arguments needed
|
9 |
+
# - load_checkpoint - takes in the queue and parsed arguments
|
10 |
+
|
11 |
+
# A saver is similar but has save_checkpoint instead of
|
12 |
+
# load_checkpoint
|
13 |
+
|
14 |
+
# The loader and saver process are each given a queue, the loader
|
15 |
+
# should load the checkpoint and send the weights in messages in the
|
16 |
+
# following order, the saver should receive them in this order and
|
17 |
+
# save the checkpoints. A message consists of a python dictionary with
|
18 |
+
# a "name" for error checking and an entry for each tensor as
|
19 |
+
# indicated below. Note that the weight sent over the queue are the
|
20 |
+
# full model weights, nothing split.
|
21 |
+
|
22 |
+
# If the loader ever sends "exit" to the queue, that means something
|
23 |
+
# went wrong and it is exiting.
|
24 |
+
|
25 |
+
# - Metadata Namespace with the following attributes:
|
26 |
+
# model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
|
27 |
+
# num_layers - Number of transformer layers
|
28 |
+
# hidden_size
|
29 |
+
# seq_length
|
30 |
+
# num_attention_heads
|
31 |
+
# max_position_embeddings
|
32 |
+
# tokenizer_type
|
33 |
+
# iteration
|
34 |
+
# params_dtype
|
35 |
+
# bert_binary_head - Used only if model_type is BERT
|
36 |
+
# previous_tensor_parallel_size - Optional
|
37 |
+
# previous_pipeline_parallel_size - Optional
|
38 |
+
# true_vocab_size
|
39 |
+
# make_vocab_size_divisble_by
|
40 |
+
# consumed_train_samples
|
41 |
+
# consumed_valid_samples
|
42 |
+
# messages
|
43 |
+
# {
|
44 |
+
# "name": "embeddings"
|
45 |
+
# "position embeddings"
|
46 |
+
# "word embeddings"
|
47 |
+
# }
|
48 |
+
# (for each transformer layer):
|
49 |
+
# {
|
50 |
+
# "name": "transformer layer N"
|
51 |
+
# "input layernorm weight"
|
52 |
+
# "input layernorm bias"
|
53 |
+
# "qkv weight"
|
54 |
+
# "qkv bias"
|
55 |
+
# "dense weight"
|
56 |
+
# "dense bias"
|
57 |
+
# "post layernorm weight"
|
58 |
+
# "post layernorm bias"
|
59 |
+
# "mlp l0 weight"
|
60 |
+
# "mlp l0 bias"
|
61 |
+
# "mlp l1 weight"
|
62 |
+
# "mlp l1 bias"
|
63 |
+
# }
|
64 |
+
# {
|
65 |
+
# "name": "final layer norm"
|
66 |
+
# "weight"
|
67 |
+
# "bias"
|
68 |
+
# }
|
69 |
+
# if present (i.e. for BERT):
|
70 |
+
# {
|
71 |
+
# "name": "pooler"
|
72 |
+
# "weight"
|
73 |
+
# "bias"
|
74 |
+
# }
|
75 |
+
# {
|
76 |
+
# "name": "lm head"
|
77 |
+
# "dense weight"
|
78 |
+
# "dense bias"
|
79 |
+
# "layernorm weight"
|
80 |
+
# "layernorm bias"
|
81 |
+
# }
|
82 |
+
# {
|
83 |
+
# "name": "binary head"
|
84 |
+
# "weight"
|
85 |
+
# "bias"
|
86 |
+
# }
|
87 |
+
# - "done"
|
88 |
+
|
89 |
+
def load_plugin(plugin_type, name):
|
90 |
+
module_name = f"checkpoint_{plugin_type}_{name}"
|
91 |
+
try:
|
92 |
+
plugin = importlib.import_module(module_name)
|
93 |
+
except ModuleNotFoundError:
|
94 |
+
module_name = name
|
95 |
+
try:
|
96 |
+
plugin = importlib.import_module(module_name)
|
97 |
+
except ModuleNotFoundError:
|
98 |
+
sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")
|
99 |
+
|
100 |
+
if not hasattr(plugin, 'add_arguments'):
|
101 |
+
sys.exit(f"{module_name} module is not a plugin. Exiting.")
|
102 |
+
|
103 |
+
print(f"Loaded {module_name} as the {plugin_type}.")
|
104 |
+
return plugin
|
105 |
+
|
106 |
+
def main():
|
107 |
+
import argparse
|
108 |
+
parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
|
109 |
+
allow_abbrev=False, conflict_handler='resolve')
|
110 |
+
|
111 |
+
parser.add_argument('--model-type', type=str, required=True,
|
112 |
+
choices=['GPT', 'BERT'],
|
113 |
+
help='Type of the model')
|
114 |
+
parser.add_argument('--loader', type=str, default='megatron',
|
115 |
+
help='Module name to load checkpoint, should be on python path')
|
116 |
+
parser.add_argument('--saver', type=str, default='megatron',
|
117 |
+
help='Module name to save checkpoint, shdoul be on python path')
|
118 |
+
parser.add_argument('--load-dir', type=str, required=True,
|
119 |
+
help='Directory to load model checkpoint from')
|
120 |
+
parser.add_argument('--save-dir', type=str, required=True,
|
121 |
+
help='Directory to save model checkpoint to')
|
122 |
+
parser.add_argument('--max-queue-size', type=int, default=50,
|
123 |
+
help='Maximum number of tensors in the queue')
|
124 |
+
parser.add_argument('--no-checking', action='store_false',
|
125 |
+
help='Do not perform checking on the name and ordering of weights',
|
126 |
+
dest='checking')
|
127 |
+
|
128 |
+
known_args, _ = parser.parse_known_args()
|
129 |
+
loader = load_plugin('loader', known_args.loader)
|
130 |
+
saver = load_plugin('saver', known_args.saver)
|
131 |
+
|
132 |
+
loader.add_arguments(parser)
|
133 |
+
saver.add_arguments(parser)
|
134 |
+
|
135 |
+
args = parser.parse_args()
|
136 |
+
|
137 |
+
queue = mp.Queue(maxsize=args.max_queue_size)
|
138 |
+
|
139 |
+
print("Starting saver...")
|
140 |
+
saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
|
141 |
+
saver_proc.start()
|
142 |
+
|
143 |
+
print("Starting loader...")
|
144 |
+
loader.load_checkpoint(queue, args)
|
145 |
+
|
146 |
+
print("Waiting for saver to complete...")
|
147 |
+
saver_proc.join()
|
148 |
+
|
149 |
+
|
150 |
+
if __name__ == '__main__':
|
151 |
+
main()
|
tools/linter.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import pathlib
|
4 |
+
import subprocess
|
5 |
+
|
6 |
+
|
7 |
+
def recursively_lint_files():
|
8 |
+
"""Recursively lint all python files in chosen subdirectories of megatron-lm"""
|
9 |
+
|
10 |
+
try:
|
11 |
+
import autopep8
|
12 |
+
except ModuleNotFoundError:
|
13 |
+
print("Please first install autopep8 via `pip install autopep8`")
|
14 |
+
return
|
15 |
+
|
16 |
+
# get all python file paths from top level directory
|
17 |
+
file_dir = str(pathlib.Path(__file__).parent.absolute())
|
18 |
+
working_dir = osp.join(file_dir, os.pardir)
|
19 |
+
all_py_paths = set(os.path.join(working_dir, fname)
|
20 |
+
for fname in os.listdir(working_dir) if ".py" in fname)
|
21 |
+
|
22 |
+
# get all python file paths from chosen subdirectories
|
23 |
+
check_dirs = ['docker', 'megatron', 'openwebtext', 'scripts', 'tasks']
|
24 |
+
for sub_dir in check_dirs:
|
25 |
+
for path, _, fnames in os.walk(osp.join(working_dir, sub_dir)):
|
26 |
+
all_py_paths.update(set(osp.join(path, fname) for fname in fnames if ".py" in fname))
|
27 |
+
|
28 |
+
print("Linting the following: ")
|
29 |
+
for py_path in all_py_paths:
|
30 |
+
print(py_path)
|
31 |
+
command = 'autopep8 --max-line-length 100 --aggressive --in-place {}'.format(py_path)
|
32 |
+
subprocess.check_call(command)
|
33 |
+
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
recursively_lint_files()
|
tools/merge_datasets.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
6 |
+
os.path.pardir)))
|
7 |
+
|
8 |
+
from megatron.data import indexed_dataset
|
9 |
+
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
|
13 |
+
prefixes = set()
|
14 |
+
for basename in os.listdir(args.input):
|
15 |
+
prefix, ext = os.path.splitext(basename)
|
16 |
+
|
17 |
+
if prefix in prefixes:
|
18 |
+
continue
|
19 |
+
|
20 |
+
if not os.path.isfile(os.path.join(args.input, basename)):
|
21 |
+
continue
|
22 |
+
|
23 |
+
ext_pair = '.bin' if ext == '.idx' else '.idx'
|
24 |
+
assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \
|
25 |
+
f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}'
|
26 |
+
|
27 |
+
prefixes.add(prefix)
|
28 |
+
|
29 |
+
builder = None
|
30 |
+
for prefix in sorted(prefixes):
|
31 |
+
if builder is None:
|
32 |
+
dataset = indexed_dataset.make_dataset(os.path.join(args.input, prefix), 'infer')
|
33 |
+
|
34 |
+
if isinstance(dataset, indexed_dataset.MMapIndexedDataset):
|
35 |
+
builder = indexed_dataset.MMapIndexedDatasetBuilder(args.output_prefix + '.bin', dtype=dataset._index.dtype)
|
36 |
+
else:
|
37 |
+
builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin')
|
38 |
+
|
39 |
+
del dataset
|
40 |
+
|
41 |
+
builder.merge_file_(os.path.join(args.input, prefix))
|
42 |
+
|
43 |
+
builder.finalize(args.output_prefix + '.idx')
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
parser = argparse.ArgumentParser()
|
48 |
+
|
49 |
+
group = parser.add_argument_group(title='input data')
|
50 |
+
group.add_argument('--input', type=str, required=True,
|
51 |
+
help='Path to directory containing all document files to merge')
|
52 |
+
|
53 |
+
group = parser.add_argument_group(title='output data')
|
54 |
+
group.add_argument('--output-prefix', type=str, required=True,
|
55 |
+
help='Path to binary output file without suffix')
|
56 |
+
|
57 |
+
args = parser.parse_args()
|
58 |
+
|
59 |
+
assert os.path.isdir(args.input), \
|
60 |
+
f'ERROR: {args.input} is not a directory or does not exist'
|
61 |
+
|
62 |
+
assert os.path.isdir(os.path.dirname(args.output_prefix)), \
|
63 |
+
f'ERROR: {os.path.dirname(args.output_prefix)} is not a directory or does not exist'
|
64 |
+
|
65 |
+
main(args)
|
66 |
+
|
tools/merge_mp_partitions.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Merge model parallel partitions."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
import re
|
20 |
+
import sys
|
21 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
22 |
+
os.path.pardir)))
|
23 |
+
|
24 |
+
import torch
|
25 |
+
|
26 |
+
from megatron import mpu
|
27 |
+
from megatron.checkpointing import load_checkpoint, save_checkpoint
|
28 |
+
from megatron.checkpointing import ensure_directory_exists
|
29 |
+
from megatron.checkpointing import get_checkpoint_name
|
30 |
+
from megatron.checkpointing import get_checkpoint_version
|
31 |
+
from megatron.checkpointing import get_checkpoint_tracker_filename
|
32 |
+
from megatron.global_vars import set_global_variables, get_args
|
33 |
+
from megatron.global_vars import rebuild_tokenizer
|
34 |
+
|
35 |
+
|
36 |
+
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
|
37 |
+
|
38 |
+
per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
|
39 |
+
num_partitions)
|
40 |
+
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
|
41 |
+
|
42 |
+
partitions_list = torch.split(tensor,
|
43 |
+
per_partition_per_stride_size,
|
44 |
+
dim=partition_dim)
|
45 |
+
|
46 |
+
partitions = []
|
47 |
+
for i in range(num_partitions):
|
48 |
+
partition = torch.cat(partitions_list[i::num_partitions],
|
49 |
+
dim=partition_dim)
|
50 |
+
partitions.append(partition)
|
51 |
+
|
52 |
+
return partitions
|
53 |
+
|
54 |
+
|
55 |
+
def merge_partitions(merged, partitions, partition_dim, stride):
|
56 |
+
|
57 |
+
# Number and size of each partition.
|
58 |
+
num_partitions = len(partitions)
|
59 |
+
per_partition_size = None
|
60 |
+
for partition in partitions:
|
61 |
+
if per_partition_size is None:
|
62 |
+
per_partition_size = partition.size(partition_dim)
|
63 |
+
else:
|
64 |
+
assert per_partition_size == partition.size(partition_dim)
|
65 |
+
|
66 |
+
def concat_partitions(partitions_):
|
67 |
+
with torch.no_grad():
|
68 |
+
if (per_partition_size * num_partitions) == merged.size(
|
69 |
+
partition_dim):
|
70 |
+
torch.cat(partitions_, dim=partition_dim, out=merged)
|
71 |
+
else:
|
72 |
+
print(' ***WARNING*** sizes do not match. Will cut '
|
73 |
+
'the merged partitions by {} along dimension {} '
|
74 |
+
'to reduce the size from {} to {} ...'.format(
|
75 |
+
(per_partition_size * num_partitions) - \
|
76 |
+
merged.size(partition_dim), partition_dim,
|
77 |
+
per_partition_size * num_partitions,
|
78 |
+
merged.size(partition_dim)))
|
79 |
+
merged_ = torch.cat(partitions_, dim=partition_dim)
|
80 |
+
merged_split = torch.split(merged_, merged.size(partition_dim),
|
81 |
+
dim=partition_dim)
|
82 |
+
merged_ = merged_split[0]
|
83 |
+
assert merged_.size(partition_dim) == merged.size(partition_dim)
|
84 |
+
merged.data.copy_(merged_.data)
|
85 |
+
|
86 |
+
# If stride is 1, then do simple concatination.
|
87 |
+
if stride == 1:
|
88 |
+
concat_partitions(partitions)
|
89 |
+
return
|
90 |
+
|
91 |
+
# For none unity strides, first split based on stride and then group.
|
92 |
+
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
|
93 |
+
# Chunk and build a list.
|
94 |
+
chunks = None
|
95 |
+
for i, partition in enumerate(partitions):
|
96 |
+
chunk = torch.split(partition,
|
97 |
+
per_partition_per_stride_size,
|
98 |
+
dim=partition_dim)
|
99 |
+
|
100 |
+
if chunks is None:
|
101 |
+
chunks = [0]*(num_partitions*len(chunk))
|
102 |
+
chunks[i::num_partitions] = chunk
|
103 |
+
|
104 |
+
# Concatinate.
|
105 |
+
concat_partitions(chunks)
|
106 |
+
|
107 |
+
return
|
108 |
+
|
109 |
+
|
110 |
+
def get_model(model_type):
|
111 |
+
|
112 |
+
if model_type == 'BERT':
|
113 |
+
from pretrain_bert import model_provider
|
114 |
+
elif model_type == 'GPT':
|
115 |
+
from pretrain_gpt import model_provider
|
116 |
+
elif model_type == 'RACE':
|
117 |
+
from tasks.race.finetune import model_provider
|
118 |
+
elif model_type == ['MNLI', 'QQP']:
|
119 |
+
num_classes = 2
|
120 |
+
if model_type == 'MNLI':
|
121 |
+
num_classes = 3
|
122 |
+
from megatron.model.classification import Classification
|
123 |
+
def model_provider():
|
124 |
+
return Classification(num_classes=num_classes, num_tokentypes=2)
|
125 |
+
else:
|
126 |
+
raise Exception('unrecognized model type: {}'.format(model_type))
|
127 |
+
|
128 |
+
model = model_provider()
|
129 |
+
model = model.half()
|
130 |
+
|
131 |
+
return model
|
132 |
+
|
133 |
+
|
134 |
+
def get_parallel_checkpoint_name(path):
|
135 |
+
|
136 |
+
tracker_filename = get_checkpoint_tracker_filename(path)
|
137 |
+
iteration = 0
|
138 |
+
with open(tracker_filename, 'r') as f:
|
139 |
+
metastring = f.read().strip()
|
140 |
+
iteration = int(metastring)
|
141 |
+
assert iteration > 0
|
142 |
+
checkpoint_name = get_checkpoint_name(path, iteration)
|
143 |
+
|
144 |
+
return checkpoint_name, iteration
|
145 |
+
|
146 |
+
|
147 |
+
def test_split_merge():
|
148 |
+
|
149 |
+
print('testing split and merge ...')
|
150 |
+
|
151 |
+
#[QKV.ROW-COL]
|
152 |
+
tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
|
153 |
+
[1.21, 1.22, 1.23, 1.24, 1.25],
|
154 |
+
[1.31, 1.32, 1.33, 1.34, 1.35],
|
155 |
+
[1.41, 1.42, 1.43, 1.44, 1.45],
|
156 |
+
[2.11, 2.12, 2.13, 2.14, 2.15],
|
157 |
+
[2.21, 2.22, 2.23, 2.24, 2.25],
|
158 |
+
[2.31, 2.32, 2.33, 2.34, 2.35],
|
159 |
+
[2.41, 2.42, 2.43, 2.44, 2.45],
|
160 |
+
[3.11, 3.12, 3.13, 3.14, 3.15],
|
161 |
+
[3.21, 3.22, 3.23, 3.24, 3.25],
|
162 |
+
[3.31, 3.32, 3.33, 3.34, 3.35],
|
163 |
+
[3.41, 3.42, 3.43, 3.44, 3.45]])
|
164 |
+
|
165 |
+
num_partitions = 2
|
166 |
+
partition_dim = 0
|
167 |
+
stride = 3
|
168 |
+
partitions = split_into_partitions(tensor, num_partitions,
|
169 |
+
partition_dim, stride)
|
170 |
+
|
171 |
+
merged = torch.zeros_like(tensor)
|
172 |
+
merge_partitions(merged, partitions, partition_dim, stride)
|
173 |
+
|
174 |
+
max_error = (merged - tensor).abs().max()
|
175 |
+
print(' > max error (should be zero): {}'.format(max_error))
|
176 |
+
|
177 |
+
|
178 |
+
def get_mp_merge_args(parser):
|
179 |
+
"""Provide extra arguments required for merging."""
|
180 |
+
group = parser.add_argument_group(title='mp merge')
|
181 |
+
|
182 |
+
group.add_argument('--model-type', type=str, required=True,
|
183 |
+
choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
|
184 |
+
help='Type of the mdoel.')
|
185 |
+
group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
|
186 |
+
help='Degree of pipeline model parallelism in output model.')
|
187 |
+
|
188 |
+
return parser
|
189 |
+
|
190 |
+
|
191 |
+
def main():
|
192 |
+
|
193 |
+
# Arguments do sanity checks on the world size, but we don't care,
|
194 |
+
# so trick it into thinking we are plenty of processes
|
195 |
+
os.environ["WORLD_SIZE"] = f'{2**31}'
|
196 |
+
|
197 |
+
# Args
|
198 |
+
set_global_variables(extra_args_provider=get_mp_merge_args,
|
199 |
+
args_defaults = {'use_cpu_initialization': True,
|
200 |
+
'micro_batch_size': 1,
|
201 |
+
'no_load_optim': True,
|
202 |
+
'no_load_rng': True,
|
203 |
+
'no_save_optim': True,
|
204 |
+
'no_save_rng': True,
|
205 |
+
'save_interval': 1})
|
206 |
+
args = get_args()
|
207 |
+
|
208 |
+
if args.pipeline_model_parallel_size > 1:
|
209 |
+
print("Checkpoints with pipeline model parallelism are not currently supported.")
|
210 |
+
exit()
|
211 |
+
|
212 |
+
model_type = args.model_type
|
213 |
+
orig_tensor_model_parallel_size = args.tensor_model_parallel_size
|
214 |
+
args.tensor_model_parallel_size = 1
|
215 |
+
tokenizer = rebuild_tokenizer(args)
|
216 |
+
|
217 |
+
print('\n merging model parallel partitions ...')
|
218 |
+
print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
|
219 |
+
print(' > checkpoint path: {}'.format(args.load))
|
220 |
+
print(' > model parameters:')
|
221 |
+
print(' number of tokens ................ {} '.format(
|
222 |
+
tokenizer.vocab_size))
|
223 |
+
print(' number of layers ................ {}'.format(args.num_layers))
|
224 |
+
print(' hidden size ..................... {}'.format(args.hidden_size))
|
225 |
+
print(' number of attention heads ....... {}'.format(
|
226 |
+
args.num_attention_heads))
|
227 |
+
print(' maximum position embeddings ..... {}'.format(
|
228 |
+
args.max_position_embeddings))
|
229 |
+
|
230 |
+
# Full model.
|
231 |
+
print('> building the full model ...')
|
232 |
+
mpu.initialize.set_tensor_model_parallel_world_size(1)
|
233 |
+
mpu.initialize.set_tensor_model_parallel_rank(0)
|
234 |
+
mpu.initialize.set_pipeline_model_parallel_world_size(1)
|
235 |
+
mpu.initialize.set_pipeline_model_parallel_rank(0)
|
236 |
+
merged_model = get_model(model_type)
|
237 |
+
|
238 |
+
# Build and load partitions.
|
239 |
+
partitions = []
|
240 |
+
iteration = 0
|
241 |
+
args.tensor_model_parallel_size = orig_tensor_model_parallel_size
|
242 |
+
tokenizer = rebuild_tokenizer(args)
|
243 |
+
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
|
244 |
+
for rank in range(args.tensor_model_parallel_size):
|
245 |
+
# Reset these since load_checkpoint asserts they are 0, but we are loading
|
246 |
+
# multiple checkpoints in the same process and they get set each time
|
247 |
+
args.consumed_train_samples = 0
|
248 |
+
args.consumed_valid_samples = 0
|
249 |
+
|
250 |
+
mpu.initialize.set_tensor_model_parallel_rank(rank)
|
251 |
+
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
|
252 |
+
model_ = get_model(model_type)
|
253 |
+
print(f'> loading {checkpoint_name} ...')
|
254 |
+
load_checkpoint(model_, None, None)
|
255 |
+
print(f'> checkpoint version {get_checkpoint_version()}')
|
256 |
+
partitions.append(model_)
|
257 |
+
|
258 |
+
# Parameter generators so we can loop through them semiltaneouly.
|
259 |
+
merged_params_gen = merged_model.named_parameters()
|
260 |
+
partitions_params_gen = [partition.named_parameters()
|
261 |
+
for partition in partitions]
|
262 |
+
while True:
|
263 |
+
try:
|
264 |
+
|
265 |
+
# Get the params and check names.
|
266 |
+
name, merged_param = next(merged_params_gen)
|
267 |
+
print(' > working on {} ...'.format(name))
|
268 |
+
print(' merged type: {}, size: {}'.format(
|
269 |
+
merged_param.dtype, list(merged_param.size())))
|
270 |
+
partitions_param = []
|
271 |
+
for rank, partition_params_gen in enumerate(partitions_params_gen):
|
272 |
+
partition_name, partition_param = next(partition_params_gen)
|
273 |
+
assert partition_name == name
|
274 |
+
partitions_param.append(partition_param)
|
275 |
+
print(' partition {} type: {}, size: {}'.format(
|
276 |
+
rank, partition_param.dtype, list(partition_param.size())))
|
277 |
+
|
278 |
+
# For the non-parallel parameters, simply copy the rank 0 values.
|
279 |
+
if not hasattr(merged_param, 'tensor_model_parallel'):
|
280 |
+
print(' none-parallel parameter, simple copy from rank 0')
|
281 |
+
with torch.no_grad():
|
282 |
+
merged_param.data.copy_(partitions_param[0].data)
|
283 |
+
# For parallel parameters, merge the values
|
284 |
+
else:
|
285 |
+
dim = merged_param.partition_dim
|
286 |
+
stride = merged_param.partition_stride
|
287 |
+
print(f' parallel parameter merge with stride {stride} along '
|
288 |
+
f'dimention {dim}')
|
289 |
+
merge_partitions(merged_param,
|
290 |
+
partitions_param,
|
291 |
+
dim,
|
292 |
+
stride)
|
293 |
+
|
294 |
+
except StopIteration:
|
295 |
+
break
|
296 |
+
|
297 |
+
partitions = []
|
298 |
+
args.tensor_model_parallel_size = 1
|
299 |
+
args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
|
300 |
+
|
301 |
+
assert args.num_layers % args.pipeline_model_parallel_size == 0, \
|
302 |
+
'num_layers must be divisible by target pipeline model parallel size'
|
303 |
+
layers_per_part = args.num_layers // args.pipeline_model_parallel_size
|
304 |
+
|
305 |
+
tokenizer = rebuild_tokenizer(args)
|
306 |
+
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
|
307 |
+
mpu.initialize.set_tensor_model_parallel_rank(0)
|
308 |
+
mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
|
309 |
+
|
310 |
+
# regex to parse out layer number from param name
|
311 |
+
layer_re = re.compile('layers\.([0-9]+)')
|
312 |
+
|
313 |
+
if args.pipeline_model_parallel_size > 1:
|
314 |
+
merged_params = {}
|
315 |
+
for name, merged_param in merged_model.named_parameters():
|
316 |
+
merged_params[name] = merged_param
|
317 |
+
|
318 |
+
for rank in range(args.pipeline_model_parallel_size):
|
319 |
+
mpu.initialize.set_pipeline_model_parallel_rank(rank)
|
320 |
+
model = get_model(model_type)
|
321 |
+
def update_layer_num(m):
|
322 |
+
# TODO! This assumes no interleaved pipeline execution
|
323 |
+
layer = int(m.group(1))
|
324 |
+
layer += rank * layers_per_part
|
325 |
+
return f'layers.{layer}'
|
326 |
+
|
327 |
+
for dst_name, partition_param in model.named_parameters():
|
328 |
+
if dst_name == "word_embeddings.weight":
|
329 |
+
# See comment in MegatronModule.initialize_word_embeddings()
|
330 |
+
src_name = "language_model.embedding.word_embeddings.weight"
|
331 |
+
else:
|
332 |
+
# Translate destination layer number (0-N for each partition)
|
333 |
+
# to source layer number (single-model layer number)
|
334 |
+
src_name = re.sub(layer_re, update_layer_num, dst_name)
|
335 |
+
print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
|
336 |
+
partition_param.data.copy_(merged_params[src_name].data)
|
337 |
+
|
338 |
+
partitions.append(model)
|
339 |
+
else:
|
340 |
+
partitions = [merged_model]
|
341 |
+
|
342 |
+
for rank, model in enumerate(partitions):
|
343 |
+
mpu.initialize.set_pipeline_model_parallel_rank(rank)
|
344 |
+
print(f"> saving rank {rank}'s model")
|
345 |
+
save_checkpoint(iteration, model, None, None)
|
346 |
+
|
347 |
+
print('done :-)')
|
348 |
+
|
349 |
+
|
350 |
+
if __name__ == '__main__':
|
351 |
+
|
352 |
+
main()
|
tools/openwebtext/README.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The following steps show how to prepare training dataset to train the mode.
|
2 |
+
|
3 |
+
# Libraries to install
|
4 |
+
|
5 |
+
```
|
6 |
+
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
|
7 |
+
git clone https://github.com/mattilyra/LSH
|
8 |
+
cd LSH
|
9 |
+
python setup.py install
|
10 |
+
```
|
11 |
+
|
12 |
+
# Download the dataset
|
13 |
+
|
14 |
+
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
|
15 |
+
2. Remove blacklisted URLs.
|
16 |
+
```
|
17 |
+
python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
|
18 |
+
```
|
19 |
+
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
|
20 |
+
|
21 |
+
4. Merge the contents into one loose json file with 1 json per newline of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique.
|
22 |
+
|
23 |
+
# Prepare the data for GPT training:
|
24 |
+
|
25 |
+
1. Perform ftfy, english detection and remove documents with less than 128 tokens. This step can be sharded and run on shards.
|
26 |
+
```
|
27 |
+
python cleanup_dataset.py <input data file> <output cleaned data filename>
|
28 |
+
```
|
29 |
+
Additional cleanup (e.g. remove documents less than 512 characters or dataset specific cleaning like stories, realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
|
30 |
+
2. Using LSH, find possible duplicates and store then in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications, and is also multithreaded for faster processing. More details are can be found by `python find_duplicate.py --help`.
|
31 |
+
```
|
32 |
+
python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
|
33 |
+
```
|
34 |
+
3. Based on similarity measure defind inside function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, only one url we should keep and remove the rest.
|
35 |
+
```
|
36 |
+
python group_duplicate_urls.py <possible duplicate urls file> <output file containing similar urls>
|
37 |
+
```
|
38 |
+
4. Remove similar documents that were detected in the last step.
|
39 |
+
```
|
40 |
+
python remove_group_duplicates.py <file containing simialr documents> <cleaned data file> <outputfile containing deduplicate data>
|
41 |
+
```
|
42 |
+
|
43 |
+
5. Shuffle the dataset.
|
44 |
+
```
|
45 |
+
shuf <cleaned deduped data file> -o train_data.json
|
46 |
+
```
|
47 |
+
|
48 |
+
# Deduplicating ngrams
|
49 |
+
|
50 |
+
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
|
51 |
+
|
52 |
+
```
|
53 |
+
python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
|
54 |
+
```
|
55 |
+
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from the both side of the 13-gram. We also remove any splitted document with less than 200 characters or if a document got splitted more than 10 times. These parameters can be changed using corresponding arguments.
|
56 |
+
|
57 |
+
Only for the lambada task, we need to provide the path, `--lambada-path <path of the lambada test data>`.
|
58 |
+
|
59 |
+
Several other features (e.g. save and load dictionary) have been added, look at `python filter_ngrams.py --help` for details.
|
tools/openwebtext/add_id.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import time
|
20 |
+
|
21 |
+
"""
|
22 |
+
This code adds id to each json object in a json file. User can add prefix
|
23 |
+
to the ids.
|
24 |
+
"""
|
25 |
+
|
26 |
+
if __name__ == '__main__':
|
27 |
+
|
28 |
+
print('parsing the arguments ...')
|
29 |
+
|
30 |
+
parser = argparse.ArgumentParser()
|
31 |
+
parser.add_argument('--input-file', type=str, default=None, help='Input'\
|
32 |
+
' json file where id needs to be added')
|
33 |
+
parser.add_argument('--output-file', type=str, default=None, help=\
|
34 |
+
'Output file name with id')
|
35 |
+
parser.add_argument('--id-prefix', type=str, default=None, help=\
|
36 |
+
'Id prefix')
|
37 |
+
parser.add_argument('--log-interval', type=int, default=100,
|
38 |
+
help='Log interval')
|
39 |
+
args = parser.parse_args()
|
40 |
+
|
41 |
+
print('Adding ids to dataset ...')
|
42 |
+
|
43 |
+
f_input = open(args.input_file, 'r', encoding='utf-8')
|
44 |
+
f_output = open(args.output_file, 'wb')
|
45 |
+
|
46 |
+
unique_ids = 1
|
47 |
+
start_time = time.time()
|
48 |
+
for row in f_input:
|
49 |
+
each_row = json.loads(row)
|
50 |
+
adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
|
51 |
+
each_row['adlr_id'] = adlr_id_string
|
52 |
+
myjson = json.dumps(each_row, ensure_ascii=False)
|
53 |
+
|
54 |
+
f_output.write(myjson.encode('utf-8'))
|
55 |
+
f_output.write('\n'.encode('utf-8'))
|
56 |
+
|
57 |
+
if unique_ids % args.log_interval == 0:
|
58 |
+
print(' processed {:9d} documents in {:.2f} seconds ...'.format( \
|
59 |
+
unique_ids, time.time() - start_time), flush=True)
|
60 |
+
|
61 |
+
unique_ids += 1
|
62 |
+
|
63 |
+
# Close the file.
|
64 |
+
f_input.close()
|
65 |
+
f_output.close()
|
66 |
+
|
67 |
+
print('done :-)', flush=True)
|
tools/openwebtext/blacklist_urls.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import glob
|
18 |
+
import re
|
19 |
+
import time
|
20 |
+
import tldextract
|
21 |
+
import sys
|
22 |
+
|
23 |
+
|
24 |
+
# List of the domains to blacklist.
|
25 |
+
domain_blacklist = set([
|
26 |
+
'500px',
|
27 |
+
'aapks',
|
28 |
+
'akamaihd',
|
29 |
+
'amazon',
|
30 |
+
'apple',
|
31 |
+
'artifactfire',
|
32 |
+
'artstation',
|
33 |
+
'awwni',
|
34 |
+
'bandcamp',
|
35 |
+
'battleforthenet',
|
36 |
+
'coinscalendar',
|
37 |
+
'dailymotion',
|
38 |
+
'deviantart',
|
39 |
+
'discord',
|
40 |
+
'discordapp',
|
41 |
+
'dlapkandroid',
|
42 |
+
'dropbox',
|
43 |
+
'e621',
|
44 |
+
'ebay',
|
45 |
+
'edealinfo',
|
46 |
+
'erome',
|
47 |
+
'eroshare',
|
48 |
+
'explosm',
|
49 |
+
'facebook',
|
50 |
+
'fbcdn',
|
51 |
+
'flickr',
|
52 |
+
'furaffinity',
|
53 |
+
'futhead',
|
54 |
+
'gatopardo',
|
55 |
+
'gfycat',
|
56 |
+
'gifsound',
|
57 |
+
'gifsoup',
|
58 |
+
'giphy',
|
59 |
+
'github',
|
60 |
+
'google',
|
61 |
+
'gunprime',
|
62 |
+
'gyazo',
|
63 |
+
'hotdealstar',
|
64 |
+
'imagefap',
|
65 |
+
'imageshack',
|
66 |
+
'imgflip',
|
67 |
+
'imgur',
|
68 |
+
'instagram',
|
69 |
+
'karmadecay',
|
70 |
+
'kryptocal',
|
71 |
+
'kym-cdn',
|
72 |
+
'liveleak',
|
73 |
+
'livememe',
|
74 |
+
'lmgtfy',
|
75 |
+
'magaimg',
|
76 |
+
'memegenerator',
|
77 |
+
'minorplanetcenter',
|
78 |
+
'minus',
|
79 |
+
'mobafire',
|
80 |
+
'morejpeg',
|
81 |
+
'nocookie',
|
82 |
+
'pcpartpicker',
|
83 |
+
'photobucket',
|
84 |
+
'pinimg',
|
85 |
+
'pinterest',
|
86 |
+
'pixiv',
|
87 |
+
'pornhub',
|
88 |
+
'prntscr',
|
89 |
+
'puu',
|
90 |
+
'qkme',
|
91 |
+
'quickmeme',
|
92 |
+
'radd',
|
93 |
+
'redd',
|
94 |
+
'reddit',
|
95 |
+
'reddit-stream',
|
96 |
+
'redditlog',
|
97 |
+
'redditmedia',
|
98 |
+
'reddituploads',
|
99 |
+
'redtube',
|
100 |
+
'reupp',
|
101 |
+
'reverb',
|
102 |
+
'roanoke',
|
103 |
+
'rollingstone',
|
104 |
+
'sli',
|
105 |
+
'soundcloud',
|
106 |
+
'soundgasm',
|
107 |
+
'spankbang',
|
108 |
+
'spotify',
|
109 |
+
'strawpoll',
|
110 |
+
'streamable',
|
111 |
+
'timeanddate',
|
112 |
+
'tinypic',
|
113 |
+
'touhouradio',
|
114 |
+
'tumblr',
|
115 |
+
'twimg',
|
116 |
+
'twitch',
|
117 |
+
'twitter',
|
118 |
+
'vid',
|
119 |
+
'vimeo',
|
120 |
+
'vine',
|
121 |
+
'vkaao',
|
122 |
+
'vocaroo',
|
123 |
+
'voyagefusion',
|
124 |
+
'walmart',
|
125 |
+
'wciu',
|
126 |
+
'wikimedia',
|
127 |
+
'wikipedia',
|
128 |
+
'xhamster',
|
129 |
+
'xkcd',
|
130 |
+
'xvideos',
|
131 |
+
'youtu',
|
132 |
+
'youtube',
|
133 |
+
'youtubedoubler',
|
134 |
+
'ytimg',
|
135 |
+
'zillexplorer',
|
136 |
+
])
|
137 |
+
|
138 |
+
def domain_is_in_blacklist(url):
|
139 |
+
domain = tldextract.extract(url).domain
|
140 |
+
return domain in domain_blacklist
|
141 |
+
|
142 |
+
|
143 |
+
# List of extentions to blacklist.
|
144 |
+
extentions_blacklist = (
|
145 |
+
'.3gp',
|
146 |
+
'.7z'
|
147 |
+
'.ai',
|
148 |
+
'.aif',
|
149 |
+
'.apk',
|
150 |
+
'.app',
|
151 |
+
'.avi',
|
152 |
+
'.bin',
|
153 |
+
'.bmp',
|
154 |
+
'.bz2',
|
155 |
+
'.css',
|
156 |
+
'.csv',
|
157 |
+
'.dat',
|
158 |
+
'.deb',
|
159 |
+
'.dmg',
|
160 |
+
'.doc',
|
161 |
+
'.docx',
|
162 |
+
'.exe',
|
163 |
+
'.gif',
|
164 |
+
'.gifv',
|
165 |
+
'.gz',
|
166 |
+
'.iso',
|
167 |
+
'.jar',
|
168 |
+
'.jpeg',
|
169 |
+
'.jpg',
|
170 |
+
'.js',
|
171 |
+
'.log',
|
172 |
+
'.mid',
|
173 |
+
'.midi',
|
174 |
+
'.mkv',
|
175 |
+
'.mov',
|
176 |
+
'.mp3',
|
177 |
+
'.mp4',
|
178 |
+
'.mpeg',
|
179 |
+
'.mpg',
|
180 |
+
'.ogg',
|
181 |
+
'.ogv',
|
182 |
+
'.otf',
|
183 |
+
'.pdf',
|
184 |
+
'.pkg',
|
185 |
+
'.png',
|
186 |
+
'.pps',
|
187 |
+
'.ppt',
|
188 |
+
'.pptx',
|
189 |
+
'.psd',
|
190 |
+
'.py',
|
191 |
+
'.qt',
|
192 |
+
'.ram',
|
193 |
+
'.rar',
|
194 |
+
'.sql',
|
195 |
+
'.svg',
|
196 |
+
'.swf',
|
197 |
+
'.tar.gz',
|
198 |
+
'.tar',
|
199 |
+
'.tgz',
|
200 |
+
'.tiff',
|
201 |
+
'.ttf',
|
202 |
+
'.txt',
|
203 |
+
'.wav',
|
204 |
+
'.webm',
|
205 |
+
'.wma',
|
206 |
+
'.wmv',
|
207 |
+
'.xls',
|
208 |
+
'.xlsx',
|
209 |
+
'.xml',
|
210 |
+
'.xz',
|
211 |
+
'.zip',
|
212 |
+
)
|
213 |
+
|
214 |
+
def extention_is_in_blacklist(url):
|
215 |
+
if url.split('?')[0].lower().endswith(extentions_blacklist):
|
216 |
+
return True
|
217 |
+
return False
|
218 |
+
|
219 |
+
|
220 |
+
# Malformed urls.
|
221 |
+
# This function is adapted from:
|
222 |
+
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
|
223 |
+
url_regex = re.compile(
|
224 |
+
r'^(?:http)s?://' # http:// or https://
|
225 |
+
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
|
226 |
+
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
|
227 |
+
r'(?::\d+)?' # optional port
|
228 |
+
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
|
229 |
+
def url_is_malformed(url):
|
230 |
+
return re.match(url_regex, url) is None
|
231 |
+
|
232 |
+
|
233 |
+
def print_progress(prefix, start_time, urls_counter,
|
234 |
+
domain_blacklist_counter,
|
235 |
+
extention_blacklist_counter,
|
236 |
+
short_url_counter, malformed_url_counter,
|
237 |
+
duplicate_url_counter):
|
238 |
+
string = prefix + ' | '
|
239 |
+
string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
|
240 |
+
string += 'number of urls: {} | '.format(urls_counter)
|
241 |
+
string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
|
242 |
+
string += 'extention blacklisted: {} | '.format(extention_blacklist_counter)
|
243 |
+
string += 'short urls (<=8): {} | '.format(short_url_counter)
|
244 |
+
string += 'malformed urls: {} | '.format(malformed_url_counter)
|
245 |
+
string += 'duplicate urls: {}'.format(duplicate_url_counter)
|
246 |
+
print(string, flush=True)
|
247 |
+
|
248 |
+
|
249 |
+
if __name__ == '__main__':
|
250 |
+
|
251 |
+
|
252 |
+
print('remove blacklisted urls ..')
|
253 |
+
|
254 |
+
# Path to the url files.
|
255 |
+
path = sys.argv[1]
|
256 |
+
# Output url file.
|
257 |
+
output = sys.argv[2]
|
258 |
+
|
259 |
+
# Get the list of url files.
|
260 |
+
files = glob.glob(path + '/*.txt')
|
261 |
+
print('> found {} files'.format(len(files)))
|
262 |
+
|
263 |
+
urls = set()
|
264 |
+
urls_counter = 0
|
265 |
+
domain_blacklist_counter = 0
|
266 |
+
extention_blacklist_counter = 0
|
267 |
+
short_url_counter = 0
|
268 |
+
malformed_url_counter = 0
|
269 |
+
duplicate_url_counter = 0
|
270 |
+
start_time = time.time()
|
271 |
+
for filename in files:
|
272 |
+
with open(filename, 'r') as f:
|
273 |
+
for line in f:
|
274 |
+
url = line.strip()
|
275 |
+
urls_counter += 1
|
276 |
+
if domain_is_in_blacklist(url):
|
277 |
+
print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
|
278 |
+
domain_blacklist_counter += 1
|
279 |
+
elif extention_is_in_blacklist(url):
|
280 |
+
print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
|
281 |
+
extention_blacklist_counter += 1
|
282 |
+
elif len(url) <= 8:
|
283 |
+
print('[SHORT URL]: {}'.format(url), flush=True)
|
284 |
+
short_url_counter += 1
|
285 |
+
elif url_is_malformed(url):
|
286 |
+
print('[MALFORMED URL]: {}'.format(url), flush=True)
|
287 |
+
malformed_url_counter += 1
|
288 |
+
elif url in urls:
|
289 |
+
print('[DUPLICATE URL]: {}'.format(url), flush=True)
|
290 |
+
duplicate_url_counter += 1
|
291 |
+
else:
|
292 |
+
urls.add(url)
|
293 |
+
if urls_counter % 100000 == 0:
|
294 |
+
print_progress('PROGRESS', start_time, urls_counter,
|
295 |
+
domain_blacklist_counter,
|
296 |
+
extention_blacklist_counter,
|
297 |
+
short_url_counter, malformed_url_counter,
|
298 |
+
duplicate_url_counter)
|
299 |
+
|
300 |
+
print_progress('FINAL', start_time, urls_counter,
|
301 |
+
domain_blacklist_counter,
|
302 |
+
extention_blacklist_counter,
|
303 |
+
short_url_counter, malformed_url_counter,
|
304 |
+
duplicate_url_counter)
|
305 |
+
|
306 |
+
# Write the final set of urls.
|
307 |
+
print('> writing cleaned up url list to {}'.format(output))
|
308 |
+
with open(output, 'w') as f:
|
309 |
+
for url in urls:
|
310 |
+
f.write(url + '\n')
|
311 |
+
|
312 |
+
print('done :-)')
|
tools/openwebtext/cleanup_dataset.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import ftfy
|
18 |
+
import json
|
19 |
+
from langdetect import detect
|
20 |
+
import numpy as np
|
21 |
+
import time
|
22 |
+
import os
|
23 |
+
import sys
|
24 |
+
|
25 |
+
from tokenizer import Tokenizer
|
26 |
+
|
27 |
+
MIN_DOCUMENT_LENGHT = 128
|
28 |
+
|
29 |
+
|
30 |
+
def print_progress(prefix, start_time, num_docs, num_fixed_text,
|
31 |
+
num_non_english_docs, chars_non_english_docs,
|
32 |
+
num_small_docs, chars_small_docs):
|
33 |
+
|
34 |
+
string = prefix + ' | '
|
35 |
+
string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
|
36 |
+
string += 'documents: {} | '.format(num_docs)
|
37 |
+
string += 'fixed text: {} | '.format(num_fixed_text)
|
38 |
+
string += 'non-english: {} | '.format(num_non_english_docs)
|
39 |
+
string += 'non-english chars: {} | '.format(chars_non_english_docs)
|
40 |
+
string += 'small docs: {} | '.format(num_small_docs)
|
41 |
+
string += 'small docs chars: {}'.format(chars_small_docs)
|
42 |
+
print(string, flush=True)
|
43 |
+
|
44 |
+
|
45 |
+
def filter_corpus(filename, out_filename, print_interval=10000):
|
46 |
+
|
47 |
+
print(' > filtering {}'.format(filename))
|
48 |
+
|
49 |
+
tokenizer = Tokenizer(cache_dir='./cache')
|
50 |
+
|
51 |
+
num_docs = 0
|
52 |
+
num_written_docs = 0
|
53 |
+
num_small_docs = 0
|
54 |
+
num_fixed_text = 0
|
55 |
+
num_non_english_docs = 0
|
56 |
+
chars_non_english_docs = 0
|
57 |
+
chars_small_docs = 0
|
58 |
+
start_time = time.time()
|
59 |
+
with open(out_filename, 'wb') as f:
|
60 |
+
with open(filename, 'r') as fin:
|
61 |
+
for line in fin:
|
62 |
+
try:
|
63 |
+
num_docs += 1
|
64 |
+
myjson = json.loads(line)
|
65 |
+
# Fix text
|
66 |
+
text = ftfy.fix_text(myjson['text'])
|
67 |
+
if text != myjson['text']:
|
68 |
+
num_fixed_text += 1
|
69 |
+
myjson['text'] = text
|
70 |
+
# Detect language.
|
71 |
+
if detect(text) != 'en':
|
72 |
+
print('[non-english text]', myjson)
|
73 |
+
num_non_english_docs += 1
|
74 |
+
chars_non_english_docs += len(text)
|
75 |
+
continue
|
76 |
+
# On average each token is 5 characters so 8 is an
|
77 |
+
# upper bound.
|
78 |
+
if len(text) < (8 * MIN_DOCUMENT_LENGHT):
|
79 |
+
tokens = tokenizer.tokenize_document(text)
|
80 |
+
if len(tokens) < MIN_DOCUMENT_LENGHT:
|
81 |
+
print('[small document, skipping]:', myjson)
|
82 |
+
num_small_docs += 1
|
83 |
+
chars_small_docs += len(text)
|
84 |
+
continue
|
85 |
+
myjson = json.dumps(myjson, ensure_ascii=False)
|
86 |
+
f.write(myjson.encode('utf-8'))
|
87 |
+
f.write('\n'.encode('utf-8'))
|
88 |
+
num_written_docs += 1
|
89 |
+
if num_docs % print_interval == 0:
|
90 |
+
print_progress('[PROGRESS]', start_time, num_docs,
|
91 |
+
num_fixed_text, num_non_english_docs,
|
92 |
+
chars_non_english_docs,
|
93 |
+
num_small_docs, chars_small_docs)
|
94 |
+
except Exception as e:
|
95 |
+
print(' skipping ', line, e)
|
96 |
+
|
97 |
+
print_progress('[FINAL]', start_time, num_docs,
|
98 |
+
num_fixed_text, num_non_english_docs,
|
99 |
+
chars_non_english_docs,
|
100 |
+
num_small_docs, chars_small_docs)
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == '__main__':
|
104 |
+
|
105 |
+
print('building gpt2 dataset ...')
|
106 |
+
|
107 |
+
input_filename = sys.argv[1]
|
108 |
+
output_filename = sys.argv[2]
|
109 |
+
|
110 |
+
print('will be reading {}'.format(input_filename))
|
111 |
+
print('and will write the results to {}'.format(output_filename))
|
112 |
+
|
113 |
+
filter_corpus(input_filename, output_filename)
|
114 |
+
|
115 |
+
|
tools/openwebtext/cleanup_fix_dataset.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""
|
17 |
+
Filter and clean documents:
|
18 |
+
Capable to clean docs with less than 512 characters, less than
|
19 |
+
256 characters and contains javascript, fix text and dataset specific
|
20 |
+
cleaning like stories and realnews datasets.
|
21 |
+
Program arguments have the details.
|
22 |
+
"""
|
23 |
+
|
24 |
+
import argparse
|
25 |
+
from functools import partial
|
26 |
+
import glob
|
27 |
+
import ftfy
|
28 |
+
import json
|
29 |
+
from langdetect import detect
|
30 |
+
import multiprocessing
|
31 |
+
import os
|
32 |
+
from pathlib import Path
|
33 |
+
import re
|
34 |
+
import time
|
35 |
+
|
36 |
+
def process_doc(json_line, args):
|
37 |
+
|
38 |
+
# Read the line.
|
39 |
+
document = json.loads(json_line)
|
40 |
+
text = document['text']
|
41 |
+
|
42 |
+
output = {'remove_512': False, 'remove_256_javascript': False, \
|
43 |
+
'remove_512_non_english': False, 'ftfy_fix_text': False, \
|
44 |
+
'general_cleaning': False}
|
45 |
+
|
46 |
+
try:
|
47 |
+
# Reomove all docs with less than 512 characters
|
48 |
+
if "remove_512" in args.tasks:
|
49 |
+
if len(text) < 512:
|
50 |
+
output['remove_512'] = True
|
51 |
+
return output, text, document, True
|
52 |
+
|
53 |
+
# Remove docs if less than 256 character length and contains Javascript
|
54 |
+
if "remove_256_javascript" in args.tasks:
|
55 |
+
if len(text) < 256 and 'javascript' in text.lower():
|
56 |
+
output['remove_256_javascript'] = True
|
57 |
+
return output, text, document, True
|
58 |
+
|
59 |
+
# Remove docs < 512 and nonenglish
|
60 |
+
if "remove_512_non_english" in args.tasks:
|
61 |
+
if len(text) < 512 and detect(text) != 'en':
|
62 |
+
output['remove_512_non_english'] = True
|
63 |
+
return output, text, document, True
|
64 |
+
|
65 |
+
# Fix the text using ftfy, don't remove the text, hence return False
|
66 |
+
if "ftfy_fix_text" in args.tasks:
|
67 |
+
fixed_text = ftfy.fix_text(text)
|
68 |
+
output['ftfy_fix_text'] = True
|
69 |
+
return output, fixed_text, document, False
|
70 |
+
|
71 |
+
# Cleaning extra spaces and newlines
|
72 |
+
if "general_cleaning" in args.tasks:
|
73 |
+
cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
|
74 |
+
#cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
|
75 |
+
#cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews
|
76 |
+
|
77 |
+
# stories datasets
|
78 |
+
#cleaned_text = re.sub(r" \'", "'", text)
|
79 |
+
#cleaned_text = re.sub(r" \!", "!", cleaned_text)
|
80 |
+
#cleaned_text = re.sub(r" \.", ".", cleaned_text)
|
81 |
+
#cleaned_text = re.sub(r" \?", "?", cleaned_text)
|
82 |
+
#cleaned_text = re.sub(r" - ", "-", cleaned_text)
|
83 |
+
##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
|
84 |
+
#cleaned_text = re.sub(r" @ ", "@", cleaned_text)
|
85 |
+
|
86 |
+
output['general_cleaning'] = True
|
87 |
+
return output, cleaned_text, document, False
|
88 |
+
|
89 |
+
except Exception as e:
|
90 |
+
print('Error: *************************\n{}\ntext: {}'.format(e, \
|
91 |
+
text), flush=True)
|
92 |
+
return output, text, document, True
|
93 |
+
|
94 |
+
# don't remove
|
95 |
+
return output, text, document, False
|
96 |
+
|
97 |
+
|
98 |
+
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
|
99 |
+
|
100 |
+
print(' > working on {} ...'.format(input_file), flush=True)
|
101 |
+
|
102 |
+
num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
|
103 |
+
= num_ftfy_fix_text = num_general_cleaning = 0
|
104 |
+
|
105 |
+
# Output file and counters.
|
106 |
+
output_cleaned = open(output_f_cleaned, 'wb')
|
107 |
+
output_filtered = open(output_f_filtered, 'wb')
|
108 |
+
|
109 |
+
start_time = time.time()
|
110 |
+
|
111 |
+
# Setup multi-processing.
|
112 |
+
num_workers = 40
|
113 |
+
fin = open(input_file, 'r', encoding='utf-8')
|
114 |
+
pool = multiprocessing.Pool(num_workers)
|
115 |
+
process_doc_partial = partial(process_doc, args=args)
|
116 |
+
processed_docs = pool.imap(process_doc_partial, fin, 500)
|
117 |
+
|
118 |
+
# Process documents.
|
119 |
+
for output, text, document, to_filter in processed_docs:
|
120 |
+
num_docs += 1
|
121 |
+
|
122 |
+
num_remove_512 += 1 if output['remove_512'] else 0
|
123 |
+
num_remove_java += 1 if output['remove_256_javascript'] else 0
|
124 |
+
num_remove_512_non_english += 1 if output['remove_512_non_english'] \
|
125 |
+
else 0
|
126 |
+
num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
|
127 |
+
num_general_cleaning += 1 if output['general_cleaning'] else 0
|
128 |
+
|
129 |
+
document['text'] = text
|
130 |
+
myjson = json.dumps(document, ensure_ascii=False)
|
131 |
+
|
132 |
+
if to_filter:
|
133 |
+
output_filtered.write(myjson.encode('utf-8'))
|
134 |
+
output_filtered.write('\n'.encode('utf-8'))
|
135 |
+
else:
|
136 |
+
output_cleaned.write(myjson.encode('utf-8'))
|
137 |
+
output_cleaned.write('\n'.encode('utf-8'))
|
138 |
+
|
139 |
+
if num_docs % args.log_interval == 0:
|
140 |
+
print(' processed {:9d} documents in {:.2f} seconds ...'.format(
|
141 |
+
num_docs, time.time() - start_time), flush=True)
|
142 |
+
|
143 |
+
# Close the file.
|
144 |
+
output_cleaned.close()
|
145 |
+
output_filtered.close()
|
146 |
+
fin.close()
|
147 |
+
|
148 |
+
# Print stats.
|
149 |
+
print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
|
150 |
+
'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
|
151 |
+
format(num_docs, num_remove_512, num_remove_java,\
|
152 |
+
num_remove_512_non_english, num_ftfy_fix_text, \
|
153 |
+
num_general_cleaning), flush=True)
|
154 |
+
|
155 |
+
if __name__ == '__main__':
|
156 |
+
|
157 |
+
|
158 |
+
print('parsing the arguments ...')
|
159 |
+
|
160 |
+
parser = argparse.ArgumentParser()
|
161 |
+
parser.add_argument('--input-files', nargs = '*', required=True, default=\
|
162 |
+
None, help = 'Input json files that needs to be'\
|
163 |
+
' cleaned')
|
164 |
+
parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
|
165 |
+
help = 'Tasks to perform on the input files, ' \
|
166 |
+
'such as remove_512, remove_256_javascript, ' \
|
167 |
+
'remove_512_non_english, ftfy_fix_text, and ' \
|
168 |
+
'general_cleaning. 256 or 512 means the number' \
|
169 |
+
' of characters.')
|
170 |
+
|
171 |
+
parser.add_argument('--output-path', type=str, default=None,
|
172 |
+
help='Directory where the output should go')
|
173 |
+
parser.add_argument('--log-interval', type=int, default=100,
|
174 |
+
help='Log interval')
|
175 |
+
|
176 |
+
args = parser.parse_args()
|
177 |
+
|
178 |
+
print('cleanup dataset ...')
|
179 |
+
|
180 |
+
for input_file in args.input_files:
|
181 |
+
input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
|
182 |
+
.name)
|
183 |
+
|
184 |
+
output_f_cleaned = os.path.join(args.output_path, input_filename + \
|
185 |
+
"_cleaned" + input_filename_ext)
|
186 |
+
output_f_filtered = os.path.join(args.output_path, input_filename + \
|
187 |
+
"_filtered" + input_filename_ext)
|
188 |
+
|
189 |
+
process_set(args, input_file, output_f_cleaned, output_f_filtered)
|
190 |
+
|
191 |
+
print('done :-)', flush=True)
|
tools/openwebtext/filter_ngrams.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""
|
17 |
+
Deduplicate downstream tasks from training dataset. 13-grams have been used.
|
18 |
+
All split documents with less than 200 characters got filtered. Any document
|
19 |
+
with more than 10 splits got filtered as well.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import argparse
|
23 |
+
from functools import partial
|
24 |
+
import json
|
25 |
+
import multiprocessing
|
26 |
+
import nltk
|
27 |
+
import pickle
|
28 |
+
import re
|
29 |
+
import string
|
30 |
+
import sys
|
31 |
+
import time
|
32 |
+
|
33 |
+
def get_words(text):
|
34 |
+
# get all the lowercase words from text
|
35 |
+
words, positions = [], []
|
36 |
+
for match in re.finditer(r'\w+', text.lower()):
|
37 |
+
words.append(match.group(0))
|
38 |
+
positions.append(match.start())
|
39 |
+
return words, positions
|
40 |
+
|
41 |
+
# splits the text
|
42 |
+
def split_text(text, start_position, remove_char_each_side, seq):
|
43 |
+
# first part of the text
|
44 |
+
punctuations = ".!?"
|
45 |
+
pos = start_position - remove_char_each_side
|
46 |
+
text_first = ""
|
47 |
+
while pos > 0 and not text[pos] in punctuations:
|
48 |
+
pos -= 1
|
49 |
+
if pos > 0:
|
50 |
+
text_first = text[0:pos+1]
|
51 |
+
|
52 |
+
# add length of seq and remove_char_each_side
|
53 |
+
pos = start_position + len(seq) + remove_char_each_side
|
54 |
+
|
55 |
+
# last part of the text
|
56 |
+
text_second = ""
|
57 |
+
while pos < len(text) and not text[pos] in punctuations:
|
58 |
+
pos += 1
|
59 |
+
if pos + 1 < len(text):
|
60 |
+
text_second = text[pos+1:len(text)]
|
61 |
+
|
62 |
+
return text_first, text_second
|
63 |
+
|
64 |
+
def check_and_clean_text(args, words, ngrams, text, start_position, \
|
65 |
+
text_buf_ngram_free, text_buf, local_ngram):
|
66 |
+
|
67 |
+
seq = " ".join(words)
|
68 |
+
if seq in ngrams:
|
69 |
+
print(" [matched]: {}".format(seq), flush=True)
|
70 |
+
|
71 |
+
if args.get_ngram_freq_only:
|
72 |
+
# increase freq of this seq and then only consider the later part
|
73 |
+
# of the text for further processing
|
74 |
+
if seq in local_ngram:
|
75 |
+
local_ngram[seq] += 1
|
76 |
+
else:
|
77 |
+
local_ngram[seq] = 1
|
78 |
+
#print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
|
79 |
+
if (start_position + len(seq) + 1) < len(text):
|
80 |
+
text_buf.append(text[start_position + len(seq) + 1:len(text)])
|
81 |
+
return False
|
82 |
+
|
83 |
+
# split the text
|
84 |
+
text_first, text_second = split_text(text, start_position, \
|
85 |
+
args.remove_char_each_side, seq)
|
86 |
+
|
87 |
+
# first part of ngrams free
|
88 |
+
if len(text_first) > args.filter_text_char_len:
|
89 |
+
text_buf_ngram_free.append(text_first)
|
90 |
+
|
91 |
+
# add second part for further processing
|
92 |
+
if len(text_second) > args.filter_text_char_len:
|
93 |
+
text_buf.append(text_second)
|
94 |
+
|
95 |
+
return False # not ngram free
|
96 |
+
|
97 |
+
# ngram free
|
98 |
+
return True
|
99 |
+
|
100 |
+
|
101 |
+
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
|
102 |
+
# remove all the ngrams
|
103 |
+
|
104 |
+
try:
|
105 |
+
myjson = json.loads(line)
|
106 |
+
text_buf = [myjson[key]]
|
107 |
+
except Exception as e:
|
108 |
+
print("Error: {}".format(e), flush=True)
|
109 |
+
text_buf = []
|
110 |
+
|
111 |
+
text_buf_ngram_free = []
|
112 |
+
local_ngram = {}
|
113 |
+
while len(text_buf) > 0:
|
114 |
+
|
115 |
+
# get the first one from the buffer
|
116 |
+
text = text_buf.pop(0)
|
117 |
+
words, positions = get_words(text)
|
118 |
+
|
119 |
+
ngram_free = True
|
120 |
+
# find each max n-grams and check dictionary
|
121 |
+
for i in range(len(words) - args.max_ngram_size + 1):
|
122 |
+
check_ngram_free = check_and_clean_text(args, words[i:\
|
123 |
+
i+args.max_ngram_size], ngrams, text, positions[i], \
|
124 |
+
text_buf_ngram_free, text_buf, local_ngram)
|
125 |
+
|
126 |
+
# the seq is ngram free? if yes, break
|
127 |
+
if not check_ngram_free:
|
128 |
+
ngram_free = False
|
129 |
+
break
|
130 |
+
|
131 |
+
# if max ngrams doesn't match, check if any other lower n-grams
|
132 |
+
# within max ngram macthes
|
133 |
+
for ngram_len, _ in ngrams_freq_sorted:
|
134 |
+
check_ngram_free = check_and_clean_text(args, words[i:\
|
135 |
+
i+ngram_len], ngrams, text, positions[i], \
|
136 |
+
text_buf_ngram_free, text_buf, local_ngram)
|
137 |
+
|
138 |
+
# same check as above
|
139 |
+
if not check_ngram_free:
|
140 |
+
ngram_free = False
|
141 |
+
break
|
142 |
+
|
143 |
+
# check break from lower than max ngram loop above
|
144 |
+
if not ngram_free:
|
145 |
+
break
|
146 |
+
|
147 |
+
# for the last max n-gram, check all the lower ngrams in it
|
148 |
+
if ngram_free and len(words) - args.max_ngram_size > 0:
|
149 |
+
# get the last words of the lax max ngram
|
150 |
+
last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
|
151 |
+
last_seq_start_position = len(words) - args.max_ngram_size
|
152 |
+
|
153 |
+
# check all n-grams lower than the max
|
154 |
+
for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
|
155 |
+
|
156 |
+
# ignore the max ngram as has been considered already
|
157 |
+
if ngram_len == args.max_ngram_size:
|
158 |
+
continue
|
159 |
+
|
160 |
+
# find each ngram of ngram_len in max n-grams and check
|
161 |
+
for i in range(len(last_seq_words) - ngram_len + 1):
|
162 |
+
check_ngram_free = check_and_clean_text(args, \
|
163 |
+
last_seq_words[i:i+ngram_len], ngrams, text,\
|
164 |
+
positions[last_seq_start_position+i], \
|
165 |
+
text_buf_ngram_free, text_buf, local_ngram)
|
166 |
+
|
167 |
+
if not check_ngram_free:
|
168 |
+
ngram_free = False
|
169 |
+
break
|
170 |
+
|
171 |
+
if not ngram_free:
|
172 |
+
break
|
173 |
+
|
174 |
+
# texts are ngram free
|
175 |
+
if ngram_free and not args.get_ngram_freq_only:
|
176 |
+
text_buf_ngram_free.append(text)
|
177 |
+
|
178 |
+
# check if the text has only been trimmed
|
179 |
+
trimmed = 0
|
180 |
+
if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
|
181 |
+
len(text_buf_ngram_free[0]) < len(myjson[key]):
|
182 |
+
trimmed = 1
|
183 |
+
|
184 |
+
return text_buf_ngram_free, trimmed, myjson, local_ngram
|
185 |
+
|
186 |
+
# insert word sequence into dictionary
|
187 |
+
def insert_dict(words, ngrams, pos):
|
188 |
+
seq = " ".join(words)
|
189 |
+
if seq not in ngrams:
|
190 |
+
ngrams[seq] = 0
|
191 |
+
#ngrams[seq] = pos
|
192 |
+
|
193 |
+
# insert each ngram from text into the ngrams dictionary
|
194 |
+
def compute_ngrams_insert_dict(args, text, ngrams):
|
195 |
+
words, positions = get_words(text)
|
196 |
+
if len(words) < args.min_ngram_size:
|
197 |
+
return
|
198 |
+
|
199 |
+
if len(words) < args.max_ngram_size:
|
200 |
+
insert_dict(words, ngrams, positions[0])
|
201 |
+
|
202 |
+
for i in range(len(words) - args.max_ngram_size+1):
|
203 |
+
insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])
|
204 |
+
|
205 |
+
|
206 |
+
# Build ngrams for the lambada dataset
|
207 |
+
def process_task_lambda(args, task_file, ngrams):
|
208 |
+
print(' reading from {} and computing ngrams'.format(task_file))
|
209 |
+
with open(task_file, 'r') as f:
|
210 |
+
for line in f:
|
211 |
+
try:
|
212 |
+
myjson = json.loads(line)
|
213 |
+
text = myjson['text']
|
214 |
+
compute_ngrams_insert_dict(args, text, ngrams)
|
215 |
+
except Exception as e:
|
216 |
+
print('Error:', e)
|
217 |
+
print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
|
218 |
+
|
219 |
+
|
220 |
+
# Build ngrams for the dataset of the given task
|
221 |
+
def process_task(args, task_name, ngrams):
|
222 |
+
|
223 |
+
print(' reading from {} and computing ngrams'.format('import datasets'))
|
224 |
+
print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
|
225 |
+
# using validation/test data from datasets
|
226 |
+
from datasets import load_dataset
|
227 |
+
|
228 |
+
entities_in_ngrams = len(ngrams)
|
229 |
+
|
230 |
+
# load the dataset
|
231 |
+
if task_name == 'squad':
|
232 |
+
dataset = load_dataset('squad_v2', split='validation')
|
233 |
+
elif task_name == 'natural_questions':
|
234 |
+
dataset = load_dataset('natural_questions', split='validation')
|
235 |
+
elif task_name == 'triviaqa':
|
236 |
+
dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
|
237 |
+
elif task_name == 'webqa':
|
238 |
+
dataset = load_dataset('web_questions', split='test')
|
239 |
+
elif task_name == 'race':
|
240 |
+
dataset = load_dataset('race', 'all', split='test')
|
241 |
+
elif task_name == 'drop':
|
242 |
+
dataset = load_dataset('drop', split='validation')
|
243 |
+
elif task_name == 'coqa':
|
244 |
+
dataset = load_dataset('coqa', split='validation')
|
245 |
+
elif task_name == 'piqa':
|
246 |
+
dataset = load_dataset('piqa', split='test')
|
247 |
+
else:
|
248 |
+
print("Invalid task name: {}".format(task_name), flush=True)
|
249 |
+
return
|
250 |
+
|
251 |
+
# read the dataset and add to ngrams
|
252 |
+
for line in dataset:
|
253 |
+
try:
|
254 |
+
if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
|
255 |
+
text = line['question']
|
256 |
+
compute_ngrams_insert_dict(args, text, ngrams)
|
257 |
+
elif task_name == 'natural_questions':
|
258 |
+
text = line['question']['text']
|
259 |
+
compute_ngrams_insert_dict(args, text, ngrams)
|
260 |
+
elif task_name == 'coqa':
|
261 |
+
all_questions = line['questions']
|
262 |
+
for question in all_questions:
|
263 |
+
compute_ngrams_insert_dict(args, question, ngrams)
|
264 |
+
elif task_name == 'piqa':
|
265 |
+
text = line['goal']
|
266 |
+
compute_ngrams_insert_dict(args, text, ngrams)
|
267 |
+
except Exception as e:
|
268 |
+
print('Error:', e)
|
269 |
+
|
270 |
+
print(" After task {} entities in ngrams {}, added {}".format(task_name, \
|
271 |
+
len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
|
272 |
+
|
273 |
+
def compute_tasks_ngrams(args, ngrams):
|
274 |
+
start_time = time.time()
|
275 |
+
for _, task_name in enumerate(args.tasks):
|
276 |
+
print('Task: {}'.format(task_name), flush=True)
|
277 |
+
if task_name == 'lambada':
|
278 |
+
assert args.lambada_path is not None
|
279 |
+
process_task_lambda(args, args.lambada_path, ngrams)
|
280 |
+
else:
|
281 |
+
process_task(args, task_name, ngrams)
|
282 |
+
print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
|
283 |
+
start_time), flush=True)
|
284 |
+
|
285 |
+
def compute_ngram_freq_sorted(args, ngrams):
|
286 |
+
ngrams_freq = {}
|
287 |
+
for ngram_key in ngrams.keys():
|
288 |
+
length = len(ngram_key.split())
|
289 |
+
ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
|
290 |
+
ngrams_freq else 1
|
291 |
+
|
292 |
+
ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
|
293 |
+
print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
|
294 |
+
print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
|
295 |
+
len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
|
296 |
+
ngrams_freq_sorted) -1 ][0]), flush=True)
|
297 |
+
return ngrams_freq_sorted
|
298 |
+
|
299 |
+
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
|
300 |
+
dedup_file, dedup_key, ngrams_freq_sorted):
|
301 |
+
|
302 |
+
start_time = time.time()
|
303 |
+
# get the ngrams frequency
|
304 |
+
args.get_ngram_freq_only = True
|
305 |
+
|
306 |
+
# Open the large file to process in parallel
|
307 |
+
num_workers = args.num_threads
|
308 |
+
pool = multiprocessing.Pool(num_workers)
|
309 |
+
fin = open(dedup_file, 'r', encoding='utf-8')
|
310 |
+
free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
|
311 |
+
ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
|
312 |
+
free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
|
313 |
+
|
314 |
+
counter = 0
|
315 |
+
for _, _, _, local_ngram in free_ngrams_abt:
|
316 |
+
counter += 1
|
317 |
+
if counter % 1000 == 0:
|
318 |
+
print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
|
319 |
+
format(counter, time.time() - start_time), flush=True)
|
320 |
+
for local_key in local_ngram:
|
321 |
+
if local_key in ngrams:
|
322 |
+
ngrams[local_key] += 1
|
323 |
+
local_ngram = {}
|
324 |
+
|
325 |
+
print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
|
326 |
+
start_time), flush=True)
|
327 |
+
pool.close()
|
328 |
+
pool.join()
|
329 |
+
|
330 |
+
start_time = time.time()
|
331 |
+
counter_threshold = 0
|
332 |
+
# Get ngram below theadhold
|
333 |
+
for local_key, local_val in ngrams.items():
|
334 |
+
if ngrams[local_key] < args.key_threshold:
|
335 |
+
print(" [threshold] {} {}".format(local_key, local_val), flush=True)
|
336 |
+
counter_threshold += 1
|
337 |
+
ngrams_below_threshold[local_key] = 1
|
338 |
+
|
339 |
+
print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
|
340 |
+
fin.close()
|
341 |
+
|
342 |
+
def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file, \
|
343 |
+
dedup_key):
|
344 |
+
|
345 |
+
start_time = time.time()
|
346 |
+
# Now actually filter the dataset
|
347 |
+
args.get_ngram_freq_only = False
|
348 |
+
#id_prefix = '-'.join(args.tasks[::2])
|
349 |
+
id_prefix = '-'.join(args.tasks[::1])
|
350 |
+
|
351 |
+
# get the range of the size of the ngrams
|
352 |
+
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)
|
353 |
+
|
354 |
+
# Open the large file to process in parallel
|
355 |
+
counter = splitted = ignored = split_mt_thld = trimmed_count = 0
|
356 |
+
num_workers = args.num_threads
|
357 |
+
pool = multiprocessing.Pool(num_workers)
|
358 |
+
fin = open(dedup_file, 'r', encoding='utf-8')
|
359 |
+
free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
|
360 |
+
ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
|
361 |
+
free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
|
362 |
+
|
363 |
+
out_f = open(args.output, 'wb')
|
364 |
+
|
365 |
+
for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
|
366 |
+
counter += 1
|
367 |
+
try:
|
368 |
+
|
369 |
+
trimmed_count += trimmed
|
370 |
+
|
371 |
+
if len(text_buf_ngram_free) > 1:
|
372 |
+
splitted += 1
|
373 |
+
if len(text_buf_ngram_free) == 0:
|
374 |
+
ignored += 1
|
375 |
+
# more than 10 splits ignored
|
376 |
+
if len(text_buf_ngram_free) > args.splits_count:
|
377 |
+
text_buf_ngram_free = []
|
378 |
+
split_mt_thld += 1
|
379 |
+
|
380 |
+
if args.output is not None:
|
381 |
+
if "split_id" in myjson:
|
382 |
+
use_prefix = myjson["split_id"] + "-"
|
383 |
+
else:
|
384 |
+
use_prefix = ""
|
385 |
+
|
386 |
+
for i in range(len(text_buf_ngram_free)):
|
387 |
+
split_id_string = id_prefix + '-{:010d}'.format(int(\
|
388 |
+
counter)) + '-{:04d}'.format(int(i))
|
389 |
+
myjson[dedup_key] = text_buf_ngram_free[i]
|
390 |
+
myjson["split_id"] = use_prefix + split_id_string
|
391 |
+
outjson = json.dumps(myjson, ensure_ascii=False)
|
392 |
+
#outjson = json.dumps({"text":text_buf_ngram_free[i],
|
393 |
+
# id_prefix+"_split_id":split_id_string},
|
394 |
+
# ensure_ascii=False)
|
395 |
+
out_f.write(outjson.encode('utf-8'))
|
396 |
+
out_f.write('\n'.encode('utf-8'))
|
397 |
+
|
398 |
+
if counter % 1000 == 0:
|
399 |
+
print(' [final]> processed {} documents in {:.2f} seconds ...'.
|
400 |
+
format(counter, time.time() - start_time), flush=True)
|
401 |
+
except Exception as e:
|
402 |
+
print('Error:', e)
|
403 |
+
|
404 |
+
print(' [final]> processed {} documents in {:.2f} seconds ...'.
|
405 |
+
format(counter, time.time() - start_time), flush=True)
|
406 |
+
|
407 |
+
print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'\
|
408 |
+
' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
|
409 |
+
, flush=True)
|
410 |
+
|
411 |
+
pool.close()
|
412 |
+
pool.join()
|
413 |
+
|
414 |
+
out_f.close()
|
415 |
+
fin.close()
|
416 |
+
|
417 |
+
if __name__ == '__main__':
|
418 |
+
|
419 |
+
# we use 13-grams, any text less than 200 characters got removed
|
420 |
+
# any text splitted more than 10 got removed as well
|
421 |
+
|
422 |
+
print('parsing the arguments ...')
|
423 |
+
|
424 |
+
parser = argparse.ArgumentParser()
|
425 |
+
parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
|
426 |
+
help = 'Tasks to use for deduplication: currently '
|
427 |
+
' suuport [lambada, squad, natural_questions,'
|
428 |
+
' triviaqa, webqa, race, drop, coqa, and piqa]')
|
429 |
+
parser.add_argument('--lambada-path', type=str, default=None,
|
430 |
+
help='Only Lambada task needs the path')
|
431 |
+
parser.add_argument('--dedup-dataset', nargs = '*', default=None,
|
432 |
+
help='Dataset to deduplicate with the key to use'
|
433 |
+
' e.g. cc.json text')
|
434 |
+
parser.add_argument('--output', type=str, default=None,
|
435 |
+
help='Output file name to save dedup dataset')
|
436 |
+
parser.add_argument('--num-threads', type=int, default=40,
|
437 |
+
help='Number of threads to use')
|
438 |
+
# Default dedup values
|
439 |
+
parser.add_argument('--max-ngram-size', type=int, default=13,
|
440 |
+
help='Maximum size of ngram to use.')
|
441 |
+
parser.add_argument('--min-ngram-size', type=int, default=8,
|
442 |
+
help='Minimum size of ngram to use.')
|
443 |
+
parser.add_argument('--filter-text-char-len', type=int, default=200,
|
444 |
+
help='Remove any text below this length.')
|
445 |
+
parser.add_argument('--key-threshold', type=int, default=10,
|
446 |
+
help='Number of keys to consider as threshold')
|
447 |
+
parser.add_argument('--save-dictionary', type=str, default=None,
|
448 |
+
help='Save the dictionary')
|
449 |
+
parser.add_argument('--load-dictionary', type=str, default=None,
|
450 |
+
help='Load the dictionary')
|
451 |
+
parser.add_argument('--splits-count', type=int, default=10,
|
452 |
+
help='Remove any documents more than this many splits')
|
453 |
+
parser.add_argument('--remove-char-each-side', type=int, default=200,
|
454 |
+
help='Maximum size of ngram to use.')
|
455 |
+
|
456 |
+
args = parser.parse_args()
|
457 |
+
|
458 |
+
assert len(args.dedup_dataset) == 2
|
459 |
+
dedup_file = args.dedup_dataset[0]
|
460 |
+
dedup_key = args.dedup_dataset[1]
|
461 |
+
|
462 |
+
# Setup multi-processing
|
463 |
+
num_workers = args.num_threads
|
464 |
+
if args.load_dictionary is None:
|
465 |
+
|
466 |
+
# Build ngrams
|
467 |
+
ngrams = {}
|
468 |
+
compute_tasks_ngrams(args, ngrams)
|
469 |
+
|
470 |
+
# get the range of the size of the ngrams
|
471 |
+
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)
|
472 |
+
|
473 |
+
# get ngram freq from large file in parallel
|
474 |
+
# get ngrams below threshold
|
475 |
+
ngrams_below_threshold = {}
|
476 |
+
get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
|
477 |
+
dedup_file, dedup_key, ngrams_freq_sorted)
|
478 |
+
|
479 |
+
# save the dictionary if needed
|
480 |
+
if args.save_dictionary is not None:
|
481 |
+
with open(args.save_dictionary, 'wb') as save_dict_handle:
|
482 |
+
pickle.dump(ngrams_below_threshold, save_dict_handle)
|
483 |
+
else:
|
484 |
+
with open(args.load_dictionary, 'rb') as load_dict_handle:
|
485 |
+
ngrams_below_threshold = pickle.load(load_dict_handle)
|
486 |
+
|
487 |
+
# filter the large file
|
488 |
+
if args.output is not None:
|
489 |
+
clean_ngrams_below_threshold(args, ngrams_below_threshold, \
|
490 |
+
dedup_file, dedup_key)
|
491 |
+
|
492 |
+
print('done :-)')
|
tools/openwebtext/find_duplicates.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
from functools import partial
|
18 |
+
import itertools
|
19 |
+
import json
|
20 |
+
from lsh import cache, minhash
|
21 |
+
import multiprocessing
|
22 |
+
import numpy as np
|
23 |
+
import time
|
24 |
+
import pickle
|
25 |
+
import sys
|
26 |
+
import os
|
27 |
+
|
28 |
+
# This function is adapted from:
|
29 |
+
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
|
30 |
+
def shingles(text, char_ngram=5):
|
31 |
+
return set(text[head:head + char_ngram]
|
32 |
+
for head in range(0, len(text) - char_ngram))
|
33 |
+
|
34 |
+
|
35 |
+
# This function is adapted from:
|
36 |
+
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
|
37 |
+
def jaccard(set_a, set_b, args):
|
38 |
+
if len(set_a) < 1 or len(set_b) < 1:
|
39 |
+
return 0.0
|
40 |
+
|
41 |
+
intersection = set_a & set_b
|
42 |
+
union = set_a | set_b
|
43 |
+
|
44 |
+
if args.jaccard == 'min':
|
45 |
+
return len(intersection) / min(len(set_a), len(set_b))
|
46 |
+
elif args.jaccard == 'max':
|
47 |
+
return len(intersection) / max(len(set_a), len(set_b))
|
48 |
+
else:
|
49 |
+
return len(intersection) / len(union)
|
50 |
+
|
51 |
+
def compute_fingerprint(line, key):
|
52 |
+
try:
|
53 |
+
myjson = json.loads(line)
|
54 |
+
url = myjson[key]
|
55 |
+
text = myjson['text']
|
56 |
+
fingerprint = hasher.fingerprint(text)
|
57 |
+
except Exception as e:
|
58 |
+
print('Error:', e)
|
59 |
+
return None, None, None, False
|
60 |
+
|
61 |
+
return url, text, fingerprint, True
|
62 |
+
|
63 |
+
def url_pairs_to_remove(args, bucket_urls, url_doc):
|
64 |
+
remove_urls_list = []
|
65 |
+
deduped_local, counter_local = 0, 0
|
66 |
+
iteration = 0
|
67 |
+
while len(bucket_urls) > 1:
|
68 |
+
if args.heuristic_iter != -1 and \
|
69 |
+
iteration == args.heuristic_iter:
|
70 |
+
break
|
71 |
+
|
72 |
+
items = list(bucket_urls)
|
73 |
+
remove_urls = []
|
74 |
+
main_url = items[np.random.randint(0, len(items))]
|
75 |
+
main_dhingles = shingles(url_doc[main_url])
|
76 |
+
|
77 |
+
for i in range(0, len(items)):
|
78 |
+
counter_local += 1
|
79 |
+
other_url = items[i]
|
80 |
+
if other_url == main_url:
|
81 |
+
continue
|
82 |
+
other_shingles = shingles(url_doc[other_url])
|
83 |
+
try:
|
84 |
+
jaccard_sim = jaccard(main_dhingles, other_shingles, args)
|
85 |
+
except Exception as e:
|
86 |
+
print('Error:', e)
|
87 |
+
jaccard_sim = 0.0
|
88 |
+
if jaccard_sim > 0.5:
|
89 |
+
remove_urls.append({other_url: jaccard_sim})
|
90 |
+
deduped_local += 1
|
91 |
+
bucket_urls.remove(other_url)
|
92 |
+
|
93 |
+
bucket_urls.remove(main_url)
|
94 |
+
if len(remove_urls) > 0:
|
95 |
+
remove_urls_list.append({main_url: remove_urls})
|
96 |
+
iteration += 1
|
97 |
+
return remove_urls_list, deduped_local, counter_local
|
98 |
+
|
99 |
+
def write_remove_urls_list(remove_urls_list, f_out):
|
100 |
+
if len(remove_urls_list) > 0:
|
101 |
+
for each_url_remove in remove_urls_list:
|
102 |
+
myjson = json.dumps(each_url_remove, ensure_ascii=False)
|
103 |
+
f_out.write(myjson.encode('utf-8'))
|
104 |
+
f_out.write('\n'.encode('utf-8'))
|
105 |
+
|
106 |
+
def compute_jaccard(each_bin, num_bins, start_time_local):
|
107 |
+
|
108 |
+
remove_urls_list = []
|
109 |
+
deduped_local, counter_local, bucket_local = 0, 0, 0
|
110 |
+
|
111 |
+
for bucket_id in each_bin:
|
112 |
+
bucket_local += 1
|
113 |
+
if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
|
114 |
+
print("Counter {}, progress {:.2f} time {:.2f}".\
|
115 |
+
format(bucket_local, float(bucket_local)/float(len(each_bin)),\
|
116 |
+
time.time() - start_time_local), flush=True)
|
117 |
+
|
118 |
+
if len(each_bin[bucket_id]) <= 1:
|
119 |
+
continue
|
120 |
+
|
121 |
+
bucket_urls = each_bin[bucket_id].copy()
|
122 |
+
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
|
123 |
+
url_pairs_to_remove(args, bucket_urls, url_doc)
|
124 |
+
|
125 |
+
deduped_local += deduped_local_sub
|
126 |
+
counter_local += counter_local_sub
|
127 |
+
if len(remove_urls_list_sub) > 0:
|
128 |
+
remove_urls_list.extend(remove_urls_list_sub)
|
129 |
+
|
130 |
+
return remove_urls_list, deduped_local, counter_local
|
131 |
+
|
132 |
+
def find_pair_urls_parallel(args, lshcache, url_doc):
|
133 |
+
start_time = time.time()
|
134 |
+
f_out = open(args.output, 'wb')
|
135 |
+
deduped, counter = 0, 0
|
136 |
+
|
137 |
+
# compute jaccards of buckets in bin in parallel (parallelism
|
138 |
+
# limited to # of bins)
|
139 |
+
num_bins = len(lshcache.bins)
|
140 |
+
pool = multiprocessing.Pool(num_bins)
|
141 |
+
compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
|
142 |
+
start_time_local=start_time)
|
143 |
+
# don't need to pass args and url_doc as they are already shared
|
144 |
+
compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)
|
145 |
+
|
146 |
+
print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
|
147 |
+
flush=True)
|
148 |
+
for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
|
149 |
+
deduped += deduped_local
|
150 |
+
counter += counter_local
|
151 |
+
write_remove_urls_list(remove_urls_list, f_out)
|
152 |
+
print(' [write]> processed {} documents in {:.2f} '
|
153 |
+
'seoncds and deduped {} documents ...'.format(counter, time.time()\
|
154 |
+
- start_time, deduped), flush=True)
|
155 |
+
|
156 |
+
pool.close()
|
157 |
+
pool.join()
|
158 |
+
f_out.close()
|
159 |
+
|
160 |
+
print(' Taken time for jaccard similariries {:.2f} seconds'.format(\
|
161 |
+
time.time() - start_time), flush=True)
|
162 |
+
|
163 |
+
def find_pair_urls_sequential(args, lshcache, url_doc):
|
164 |
+
start_time = time.time()
|
165 |
+
f_out = open(args.output, 'wb')
|
166 |
+
deduped, counter = 0, 0
|
167 |
+
for b in lshcache.bins:
|
168 |
+
for bucket_id in b:
|
169 |
+
if len(b[bucket_id]) <= 1:
|
170 |
+
continue
|
171 |
+
|
172 |
+
bucket_urls = b[bucket_id].copy()
|
173 |
+
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
|
174 |
+
url_pairs_to_remove(args, bucket_urls, url_doc)
|
175 |
+
|
176 |
+
deduped += deduped_local_sub
|
177 |
+
counter += counter_local_sub
|
178 |
+
write_remove_urls_list(remove_urls_list_sub, f_out)
|
179 |
+
if counter % 10000 == 0:
|
180 |
+
print(' [write]> processed {} documents in {:.2f} '
|
181 |
+
'seoncds and deduped {} documents ...'.
|
182 |
+
format(counter, time.time() - start_time,
|
183 |
+
deduped), flush=True)
|
184 |
+
f_out.close()
|
185 |
+
print(' [write]> processed {} documents in {:.2f} '
|
186 |
+
'seoncds and deduped {} documents ...'.
|
187 |
+
format(counter, time.time() - start_time,
|
188 |
+
deduped), flush=True)
|
189 |
+
|
190 |
+
if __name__ == '__main__':
|
191 |
+
|
192 |
+
print('parsing the arguments ...')
|
193 |
+
|
194 |
+
parser = argparse.ArgumentParser()
|
195 |
+
parser.add_argument('--seed', type=int, default=1234,
|
196 |
+
help='Random seed used for python, numpy')
|
197 |
+
parser.add_argument('--inputs', nargs = '*', default=None, help = \
|
198 |
+
'Pairwise list of the input files and keys, '
|
199 |
+
'e.g. --inputs cc.json cc_id news.json news_id')
|
200 |
+
parser.add_argument('--load-fingerprints', nargs = '*', default=None,
|
201 |
+
help='Load fingerprints from a list of pickle files,'
|
202 |
+
' e.g. cc.pkl news.pkl')
|
203 |
+
parser.add_argument('--save-fingerprints', type=str, default=None,
|
204 |
+
help='Save the fingerprints of the inputs.')
|
205 |
+
parser.add_argument('--output', type=str, default=None,
|
206 |
+
help='Output file name that consists of all ids'
|
207 |
+
' with matching similarities')
|
208 |
+
parser.add_argument('--jaccard', type=str, default='union',
|
209 |
+
choices=['union', 'min', 'max'], help='Jaccard'\
|
210 |
+
' similarity computation')
|
211 |
+
parser.add_argument('--heuristic-iter', type=int, default=1,
|
212 |
+
help='Number of iterations to run the heuristics'
|
213 |
+
': use -1 for exact')
|
214 |
+
parser.add_argument('--num-bands', type=int, default=10,
|
215 |
+
help='Number of bands to use in cache')
|
216 |
+
parser.add_argument('--num-seeds', type=int, default=100,
|
217 |
+
help='Number of seeds to use for minhash. Note that'
|
218 |
+
' this value should be divisible by num-bands')
|
219 |
+
parser.add_argument('--jaccard-parallel', action='store_true',
|
220 |
+
help='Use this to process large number of documents.')
|
221 |
+
args = parser.parse_args()
|
222 |
+
|
223 |
+
print('finding possible duplicate content ...')
|
224 |
+
|
225 |
+
# set seed and get an array of seeds of 100 integers
|
226 |
+
np.random.seed(args.seed)
|
227 |
+
seeds = np.random.randint(0, 1e6, size=args.num_seeds)
|
228 |
+
|
229 |
+
# initialize minhash and lsh cache
|
230 |
+
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
|
231 |
+
lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
|
232 |
+
|
233 |
+
url_doc = {}
|
234 |
+
|
235 |
+
# load fingerprints from pickle file if needed
|
236 |
+
if args.load_fingerprints is not None:
|
237 |
+
for count_fp, fp_file_name in enumerate(args.load_fingerprints):
|
238 |
+
print("Loading fingerprints from pickle file {}".format(
|
239 |
+
fp_file_name), flush=True)
|
240 |
+
fp = open(fp_file_name, "rb")
|
241 |
+
if count_fp == 0:
|
242 |
+
# assign directory for the first pkl
|
243 |
+
lshcache = pickle.load(fp)
|
244 |
+
url_doc = pickle.load(fp)
|
245 |
+
else:
|
246 |
+
# append these to lshcache and url_doc
|
247 |
+
local_lshcache = pickle.load(fp)
|
248 |
+
local_url_doc = pickle.load(fp)
|
249 |
+
for url in local_lshcache.fingerprints.keys():
|
250 |
+
url_doc[url] = local_url_doc[url]
|
251 |
+
lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
|
252 |
+
fp.close()
|
253 |
+
|
254 |
+
counter = 0
|
255 |
+
start_time = time.time()
|
256 |
+
|
257 |
+
# compute finger prints of the inputs if any
|
258 |
+
# input file and the key to use as id
|
259 |
+
if args.inputs is not None:
|
260 |
+
print("Computing fingerprints", flush=True)
|
261 |
+
assert len(args.inputs) % 2 == 0
|
262 |
+
for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
|
263 |
+
print(' document processing {} with key {}'.format(input_file, key),
|
264 |
+
flush=True)
|
265 |
+
|
266 |
+
# compute fingerprints in parallel
|
267 |
+
num_workers = 40
|
268 |
+
pool = multiprocessing.Pool(num_workers)
|
269 |
+
fin = open(input_file, 'r', encoding='utf-8')
|
270 |
+
compute_fingerprint_partial = partial(compute_fingerprint, key=key)
|
271 |
+
compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
|
272 |
+
fin, 512)
|
273 |
+
# traverse all the texts and add fingerprints
|
274 |
+
for url, text, fingerprint, flag in compute_fingerprint_iter:
|
275 |
+
counter += 1
|
276 |
+
if flag:
|
277 |
+
url_doc[url] = text
|
278 |
+
lshcache.add_fingerprint(fingerprint, url)
|
279 |
+
if counter % 10000 == 0:
|
280 |
+
print(' [read]> processed {} documents in {:.2f} '
|
281 |
+
'seconds ...'.format(counter, time.time() - \
|
282 |
+
start_time), flush=True)
|
283 |
+
|
284 |
+
fin.close()
|
285 |
+
pool.close()
|
286 |
+
pool.join()
|
287 |
+
|
288 |
+
# Save the fingerprints if needed
|
289 |
+
if args.save_fingerprints is not None:
|
290 |
+
print("Saving fingerprints to pickle file {}".format(
|
291 |
+
args.save_fingerprints), flush=True)
|
292 |
+
with open(args.save_fingerprints, 'wb') as f_save:
|
293 |
+
pickle.dump(lshcache, f_save)
|
294 |
+
pickle.dump(url_doc, f_save)
|
295 |
+
|
296 |
+
# compute jaccard index of the input texts and write to file if needed
|
297 |
+
if args.output is not None:
|
298 |
+
print("Compute jaccard similarity", flush=True)
|
299 |
+
if args.jaccard_parallel:
|
300 |
+
find_pair_urls_parallel(args, lshcache, url_doc)
|
301 |
+
else:
|
302 |
+
find_pair_urls_sequential(args, lshcache, url_doc)
|
303 |
+
|
304 |
+
print('done :-)')
|
305 |
+
|
tools/openwebtext/group_duplicate_url.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import json
|
17 |
+
import time
|
18 |
+
import sys
|
19 |
+
|
20 |
+
|
21 |
+
if __name__ == '__main__':
|
22 |
+
|
23 |
+
|
24 |
+
print('grouping duplicate urls ...')
|
25 |
+
|
26 |
+
input = sys.argv[1]
|
27 |
+
output = sys.argv[2]
|
28 |
+
if len(sys.argv) > 3:
|
29 |
+
jaccard_similarity_threshold = float(sys.argv[3])
|
30 |
+
else:
|
31 |
+
jaccard_similarity_threshold = 0.7
|
32 |
+
|
33 |
+
url_to_index = {}
|
34 |
+
index_to_urls = []
|
35 |
+
counter = 0
|
36 |
+
start_time = time.time()
|
37 |
+
with open(input, 'r') as f:
|
38 |
+
for line in f:
|
39 |
+
counter += 1
|
40 |
+
myjson = json.loads(line)
|
41 |
+
urls = []
|
42 |
+
for main_url in myjson.keys():
|
43 |
+
urls.append(main_url)
|
44 |
+
for value in myjson[main_url]:
|
45 |
+
for other_url, js in value.items():
|
46 |
+
if js >= jaccard_similarity_threshold:
|
47 |
+
urls.append(other_url)
|
48 |
+
current_index = -1
|
49 |
+
other_indices = set()
|
50 |
+
for url in urls:
|
51 |
+
if url in url_to_index:
|
52 |
+
if current_index == -1:
|
53 |
+
current_index = url_to_index[url]
|
54 |
+
elif current_index != url_to_index[url]:
|
55 |
+
other_indices.add(url_to_index[url])
|
56 |
+
if current_index == -1:
|
57 |
+
current_index = len(index_to_urls)
|
58 |
+
index_to_urls.append(set())
|
59 |
+
for url in urls:
|
60 |
+
url_to_index[url] = current_index
|
61 |
+
index_to_urls[current_index].add(url)
|
62 |
+
for index in other_indices:
|
63 |
+
for url in index_to_urls[index]:
|
64 |
+
index_to_urls[current_index].add(url)
|
65 |
+
url_to_index[url] = current_index
|
66 |
+
index_to_urls[index] = None
|
67 |
+
|
68 |
+
if counter % 100000 == 0:
|
69 |
+
print(' > processed {} lines in {} seconds ...'.format(
|
70 |
+
counter, time.time() - start_time))
|
71 |
+
|
72 |
+
|
73 |
+
total_remove = 0
|
74 |
+
total_remain = 0
|
75 |
+
for urls in index_to_urls:
|
76 |
+
if urls is not None:
|
77 |
+
if len(urls) > 1:
|
78 |
+
total_remove += (len(urls) - 1)
|
79 |
+
total_remain += 1
|
80 |
+
print('out of {} urls, only {} are unique and {} should be removed'.format(
|
81 |
+
total_remove+total_remain, total_remain, total_remove))
|
82 |
+
|
83 |
+
with open(output, 'wb') as f:
|
84 |
+
for i, urls in enumerate(index_to_urls):
|
85 |
+
if urls is not None:
|
86 |
+
if len(urls) > 1:
|
87 |
+
myjson = json.dumps({str(i): list(urls)},
|
88 |
+
ensure_ascii=False)
|
89 |
+
f.write(myjson.encode('utf-8'))
|
90 |
+
f.write('\n'.encode('utf-8'))
|
tools/openwebtext/merge_jsons.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import glob
|
18 |
+
import sys
|
19 |
+
import json
|
20 |
+
import argparse
|
21 |
+
|
22 |
+
if __name__ == '__main__':
|
23 |
+
|
24 |
+
parser = argparse.ArgumentParser()
|
25 |
+
parser.add_argument("--json_path", type=str, default=".",
|
26 |
+
help="path where all the json files are located")
|
27 |
+
|
28 |
+
parser.add_argument("--output_file", type=str, default="merged_output.json",
|
29 |
+
help="filename where the merged json should go")
|
30 |
+
|
31 |
+
args = parser.parse_args()
|
32 |
+
|
33 |
+
json_path = args.json_path
|
34 |
+
out_file = args.output_file
|
35 |
+
|
36 |
+
json_files = glob.glob(json_path + '/*.json')
|
37 |
+
|
38 |
+
counter = 0
|
39 |
+
|
40 |
+
with open(out_file, 'w') as outfile:
|
41 |
+
for fname in json_files:
|
42 |
+
counter += 1
|
43 |
+
|
44 |
+
if counter % 1024 == 0:
|
45 |
+
print("Merging at ", counter, flush=True)
|
46 |
+
|
47 |
+
with open(fname, 'r') as infile:
|
48 |
+
for row in infile:
|
49 |
+
each_row = json.loads(row)
|
50 |
+
outfile.write(row)
|
51 |
+
|
52 |
+
|
53 |
+
print("Merged file", out_file, flush=True)
|
54 |
+
|
55 |
+
|
tools/openwebtext/remove_group_duplicates.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import json
|
18 |
+
import time
|
19 |
+
import sys
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == '__main__':
|
23 |
+
|
24 |
+
url_filename = sys.argv[1]
|
25 |
+
data_filename = sys.argv[2]
|
26 |
+
output_filename = sys.argv[3]
|
27 |
+
|
28 |
+
urls = set()
|
29 |
+
with open(url_filename, 'r') as f:
|
30 |
+
for line in f:
|
31 |
+
myjson = json.loads(line)
|
32 |
+
for key in myjson:
|
33 |
+
this_urls = myjson[key]
|
34 |
+
for i in range(1, len(this_urls)):
|
35 |
+
urls.add(this_urls[i])
|
36 |
+
print('will be removing {} urls'.format(len(urls)), flush=True)
|
37 |
+
|
38 |
+
written_docs = 0
|
39 |
+
removed_docs = 0
|
40 |
+
removed_chars = 0
|
41 |
+
start_time = time.time()
|
42 |
+
with open(output_filename, 'wb') as fout:
|
43 |
+
with open(data_filename, 'r') as fin:
|
44 |
+
for line in fin:
|
45 |
+
try:
|
46 |
+
myjson = json.loads(line)
|
47 |
+
url = myjson['url']
|
48 |
+
if url in urls:
|
49 |
+
print('removing', myjson)
|
50 |
+
removed_docs += 1
|
51 |
+
removed_chars += len(myjson['text'])
|
52 |
+
continue
|
53 |
+
myjson = json.dumps(myjson, ensure_ascii=False)
|
54 |
+
fout.write(myjson.encode('utf-8'))
|
55 |
+
fout.write('\n'.encode('utf-8'))
|
56 |
+
written_docs += 1
|
57 |
+
if written_docs % 10000 == 0:
|
58 |
+
print(' [PROCESSED] time (s): {:.2f} | written: {} '
|
59 |
+
'| removed: {} (char: {})'.format(
|
60 |
+
time.time() - start_time,
|
61 |
+
written_docs, removed_docs, removed_chars))
|
62 |
+
except Exception as e:
|
63 |
+
print('[SKIPPING]', line, e)
|
64 |
+
|
65 |
+
print(' [PROCESSED] time (s): {:.2f} | written: {} '
|
66 |
+
'| removed: {} (char: {})'.format(
|
67 |
+
time.time() - start_time,
|
68 |
+
written_docs, removed_docs, removed_chars))
|
69 |
+
print('done :-)')
|
tools/preprocess_data.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Processing data for pretraining."""
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import json
|
20 |
+
import multiprocessing
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
24 |
+
os.path.pardir)))
|
25 |
+
import time
|
26 |
+
|
27 |
+
import torch
|
28 |
+
try:
|
29 |
+
import nltk
|
30 |
+
nltk_available = True
|
31 |
+
except ImportError:
|
32 |
+
nltk_available = False
|
33 |
+
|
34 |
+
from megatron.tokenizer import build_tokenizer
|
35 |
+
from megatron.data import indexed_dataset
|
36 |
+
|
37 |
+
|
38 |
+
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
|
39 |
+
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
|
40 |
+
|
41 |
+
_period_context_fmt = r"""
|
42 |
+
\S* # some word material
|
43 |
+
%(SentEndChars)s # a potential sentence ending
|
44 |
+
\s* # <-- THIS is what I changed
|
45 |
+
(?=(?P<after_tok>
|
46 |
+
%(NonWord)s # either other punctuation
|
47 |
+
|
|
48 |
+
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
|
49 |
+
))"""
|
50 |
+
|
51 |
+
class IdentitySplitter(object):
|
52 |
+
def tokenize(self, *text):
|
53 |
+
return text
|
54 |
+
|
55 |
+
class Encoder(object):
|
56 |
+
def __init__(self, args):
|
57 |
+
self.args = args
|
58 |
+
|
59 |
+
def initializer(self):
|
60 |
+
# Use Encoder class as a container for global data
|
61 |
+
Encoder.tokenizer = build_tokenizer(self.args)
|
62 |
+
if self.args.split_sentences:
|
63 |
+
if not nltk_available:
|
64 |
+
print("NLTK is not available to split sentences.")
|
65 |
+
exit()
|
66 |
+
splitter = nltk.load("tokenizers/punkt/english.pickle")
|
67 |
+
if self.args.keep_newlines:
|
68 |
+
# this prevents punkt from eating newlines after sentences
|
69 |
+
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
|
70 |
+
train_text = splitter._params,
|
71 |
+
lang_vars = CustomLanguageVars())
|
72 |
+
else:
|
73 |
+
Encoder.splitter = splitter
|
74 |
+
|
75 |
+
else:
|
76 |
+
Encoder.splitter = IdentitySplitter()
|
77 |
+
|
78 |
+
def encode(self, json_line):
|
79 |
+
data = json.loads(json_line)
|
80 |
+
ids = {}
|
81 |
+
for key in self.args.json_keys:
|
82 |
+
text = data[key]
|
83 |
+
doc_ids = []
|
84 |
+
for sentence in Encoder.splitter.tokenize(text):
|
85 |
+
sentence_ids = Encoder.tokenizer.tokenize(sentence)
|
86 |
+
if len(sentence_ids) > 0:
|
87 |
+
doc_ids.append(sentence_ids)
|
88 |
+
if len(doc_ids) > 0 and self.args.append_eod:
|
89 |
+
doc_ids[-1].append(Encoder.tokenizer.eod)
|
90 |
+
ids[key] = doc_ids
|
91 |
+
return ids, len(json_line)
|
92 |
+
|
93 |
+
def get_args():
|
94 |
+
parser = argparse.ArgumentParser()
|
95 |
+
group = parser.add_argument_group(title='input data')
|
96 |
+
group.add_argument('--input', type=str, required=True,
|
97 |
+
help='Path to input JSON')
|
98 |
+
group.add_argument('--json-keys', nargs='+', default=['text'],
|
99 |
+
help='space separate listed of keys to extract from json')
|
100 |
+
group.add_argument('--split-sentences', action='store_true',
|
101 |
+
help='Split documents into sentences.')
|
102 |
+
group.add_argument('--keep-newlines', action='store_true',
|
103 |
+
help='Keep newlines between sentences when splitting.')
|
104 |
+
|
105 |
+
group = parser.add_argument_group(title='tokenizer')
|
106 |
+
group.add_argument('--tokenizer-type', type=str, required=True,
|
107 |
+
choices=['BertWordPieceLowerCase','BertWordPieceCase',
|
108 |
+
'GPT2BPETokenizer'],
|
109 |
+
help='What type of tokenizer to use.')
|
110 |
+
group.add_argument('--vocab-file', type=str, default=None,
|
111 |
+
help='Path to the vocab file')
|
112 |
+
group.add_argument('--merge-file', type=str, default=None,
|
113 |
+
help='Path to the BPE merge file (if necessary).')
|
114 |
+
group.add_argument('--append-eod', action='store_true',
|
115 |
+
help='Append an <eod> token to the end of a document.')
|
116 |
+
|
117 |
+
|
118 |
+
group = parser.add_argument_group(title='output data')
|
119 |
+
group.add_argument('--output-prefix', type=str, required=True,
|
120 |
+
help='Path to binary output file without suffix')
|
121 |
+
group.add_argument('--dataset-impl', type=str, default='mmap',
|
122 |
+
choices=['lazy', 'cached', 'mmap'])
|
123 |
+
|
124 |
+
group = parser.add_argument_group(title='runtime')
|
125 |
+
group.add_argument('--workers', type=int, required=True,
|
126 |
+
help='Number of worker processes to launch')
|
127 |
+
group.add_argument('--chunk-size', type=int, required=True,
|
128 |
+
help='Chunk size assigned to each worker process')
|
129 |
+
group.add_argument('--log-interval', type=int, default=100,
|
130 |
+
help='Interval between progress updates')
|
131 |
+
args = parser.parse_args()
|
132 |
+
args.keep_empty = False
|
133 |
+
|
134 |
+
if args.tokenizer_type.lower().startswith('bert'):
|
135 |
+
if not args.split_sentences:
|
136 |
+
print("Bert tokenizer detected, are you sure you don't want to split sentences?")
|
137 |
+
|
138 |
+
# some default/dummy values for the tokenizer
|
139 |
+
args.rank = 0
|
140 |
+
args.make_vocab_size_divisible_by = 128
|
141 |
+
args.tensor_model_parallel_size = 1
|
142 |
+
args.vocab_extra_ids = 0
|
143 |
+
|
144 |
+
return args
|
145 |
+
|
146 |
+
def main():
|
147 |
+
args = get_args()
|
148 |
+
startup_start = time.time()
|
149 |
+
|
150 |
+
print("Opening", args.input)
|
151 |
+
fin = open(args.input, 'r', encoding='utf-8')
|
152 |
+
|
153 |
+
if nltk_available and args.split_sentences:
|
154 |
+
nltk.download("punkt", quiet=True)
|
155 |
+
|
156 |
+
encoder = Encoder(args)
|
157 |
+
tokenizer = build_tokenizer(args)
|
158 |
+
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
|
159 |
+
encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
|
160 |
+
#encoded_docs = map(encoder.encode, fin)
|
161 |
+
|
162 |
+
level = "document"
|
163 |
+
if args.split_sentences:
|
164 |
+
level = "sentence"
|
165 |
+
|
166 |
+
print(f"Vocab size: {tokenizer.vocab_size}")
|
167 |
+
print(f"Output prefix: {args.output_prefix}")
|
168 |
+
output_bin_files = {}
|
169 |
+
output_idx_files = {}
|
170 |
+
builders = {}
|
171 |
+
for key in args.json_keys:
|
172 |
+
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
|
173 |
+
key, level)
|
174 |
+
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
|
175 |
+
key, level)
|
176 |
+
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
|
177 |
+
impl=args.dataset_impl,
|
178 |
+
vocab_size=tokenizer.vocab_size)
|
179 |
+
|
180 |
+
startup_end = time.time()
|
181 |
+
proc_start = time.time()
|
182 |
+
total_bytes_processed = 0
|
183 |
+
print("Time to startup:", startup_end - startup_start)
|
184 |
+
|
185 |
+
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
|
186 |
+
total_bytes_processed += bytes_processed
|
187 |
+
for key, sentences in doc.items():
|
188 |
+
if len(sentences) == 0:
|
189 |
+
continue
|
190 |
+
for sentence in sentences:
|
191 |
+
builders[key].add_item(torch.IntTensor(sentence))
|
192 |
+
builders[key].end_document()
|
193 |
+
if i % args.log_interval == 0:
|
194 |
+
current = time.time()
|
195 |
+
elapsed = current - proc_start
|
196 |
+
mbs = total_bytes_processed/elapsed/1024/1024
|
197 |
+
print(f"Processed {i} documents",
|
198 |
+
f"({i/elapsed} docs/s, {mbs} MB/s).",
|
199 |
+
file=sys.stderr)
|
200 |
+
|
201 |
+
for key in args.json_keys:
|
202 |
+
builders[key].finalize(output_idx_files[key])
|
203 |
+
|
204 |
+
if __name__ == '__main__':
|
205 |
+
main()
|
tools/run_build_data.sh
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
#INPUT="roberta_train_data_raw/valid.json"
|
3 |
+
INPUT="/mnt/nvme0/ouyangxuan/project_pretrain/make_pretrain_data/roberta_train_data_raw/valid.json"
|
4 |
+
python preprocess_data.py \
|
5 |
+
--input ${INPUT} \
|
6 |
+
--output-prefix my-bert \
|
7 |
+
--vocab bert-vocab.txt \
|
8 |
+
--dataset-impl mmap \
|
9 |
+
--worker 1 \
|
10 |
+
--chunk-size 1 \
|
11 |
+
--tokenizer-type BertWordPieceLowerCase \
|
12 |
+
--split-sentences
|
13 |
+
|
14 |
+
|
15 |
+
#--input /mnt/nvme1/ouyangxuan/project_pretrain/find_framework/tmp_data/data.json \
|
16 |
+
#--input roberta_train_data_raw/train_1g.json \
|
tools/run_text_generation_server.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Sample Generate GPT"""
|
17 |
+
import os
|
18 |
+
import sys
|
19 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
|
20 |
+
os.path.pardir)))
|
21 |
+
import socket
|
22 |
+
from megatron import get_args
|
23 |
+
from megatron import print_rank_0
|
24 |
+
from megatron import mpu
|
25 |
+
from megatron.checkpointing import load_checkpoint
|
26 |
+
from megatron.initialize import initialize_megatron
|
27 |
+
from megatron.model import GPTModel
|
28 |
+
from megatron.training import get_model
|
29 |
+
from megatron.text_generation_server import MegatronServer
|
30 |
+
from megatron.text_generation import generate_and_post_process
|
31 |
+
from megatron.text_generation import beam_search_and_post_process
|
32 |
+
import torch
|
33 |
+
|
34 |
+
def model_provider(pre_process=True, post_process=True):
|
35 |
+
"""Build the model."""
|
36 |
+
|
37 |
+
print_rank_0('building GPT model ...')
|
38 |
+
model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
|
39 |
+
|
40 |
+
return model
|
41 |
+
|
42 |
+
def add_text_generate_args(parser):
|
43 |
+
group = parser.add_argument_group(title='text generation')
|
44 |
+
|
45 |
+
group.add_argument("--temperature", type=float, default=1.0,
|
46 |
+
help='Sampling temperature.')
|
47 |
+
group.add_argument("--top_p", type=float, default=0.0,
|
48 |
+
help='Top p sampling.')
|
49 |
+
group.add_argument("--top_k", type=int, default=0,
|
50 |
+
help='Top k sampling.')
|
51 |
+
group.add_argument("--out-seq-length", type=int, default=1024,
|
52 |
+
help='Size of the output generated text.')
|
53 |
+
return parser
|
54 |
+
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
initialize_megatron(extra_args_provider=add_text_generate_args,
|
58 |
+
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
|
59 |
+
'no_load_rng': True,
|
60 |
+
'no_load_optim': True})
|
61 |
+
|
62 |
+
args = get_args()
|
63 |
+
if args.num_layers_per_virtual_pipeline_stage is not None:
|
64 |
+
print("Interleaved pipeline schedule is not yet supported for text generation.")
|
65 |
+
exit()
|
66 |
+
# Set up model and load checkpoint
|
67 |
+
model = get_model(model_provider, wrap_with_ddp=False)
|
68 |
+
|
69 |
+
if args.load is not None:
|
70 |
+
_ = load_checkpoint(model, None, None)
|
71 |
+
|
72 |
+
assert len(model) == 1, "Above condition should have caught this"
|
73 |
+
model = model[0]
|
74 |
+
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
|
75 |
+
server = MegatronServer(model)
|
76 |
+
server.run("0.0.0.0")
|
77 |
+
|
78 |
+
while True:
|
79 |
+
choice = torch.cuda.LongTensor(1)
|
80 |
+
torch.distributed.broadcast(choice, 0)
|
81 |
+
if choice[0].item() == 0:
|
82 |
+
try:
|
83 |
+
generate_and_post_process(model)
|
84 |
+
except ValueError as ve:
|
85 |
+
pass
|
86 |
+
elif choice[0].item() == 1:
|
87 |
+
try:
|
88 |
+
beam_search_and_post_process(model)
|
89 |
+
except ValueError as ve:
|
90 |
+
pass
|
tools/text_generation_cli.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import json
|
16 |
+
import sys
|
17 |
+
import urllib2
|
18 |
+
class PutRequest(urllib2.Request):
|
19 |
+
'''class to handling putting with urllib2'''
|
20 |
+
|
21 |
+
def get_method(self, *args, **kwargs):
|
22 |
+
return 'PUT'
|
23 |
+
|
24 |
+
if __name__ == "__main__":
|
25 |
+
url = sys.argv[1]
|
26 |
+
while True:
|
27 |
+
sentence = raw_input("Enter prompt: ")
|
28 |
+
tokens_to_generate = int(input("Enter number of tokens to generate: "))
|
29 |
+
data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
|
30 |
+
req = PutRequest(url, data, {'Content-Type': 'application/json'})
|
31 |
+
response = urllib2.urlopen(req)
|
32 |
+
resp_sentences = json.load(response)
|
33 |
+
print("Megatron Response: ")
|
34 |
+
print(resp_sentences["text"][0])
|