import argparse
import importlib
import sys

import torch.multiprocessing as mp

# A loader is a Python module with at least two functions:
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
# A saver is similar, but has save_checkpoint instead of
# load_checkpoint.
# The loader and saver processes are each given a queue. The loader
# should load the checkpoint and send the weights as messages in the
# order given below; the saver should receive them in that order and
# save the checkpoint. A message consists of a Python dictionary with
# a "name" for error checking and an entry for each tensor, as
# indicated below. Note that the weights sent over the queue are the
# full model weights, nothing split.
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
# - Metadata: a Namespace with the following attributes:
#     model_type - GPT, BERT, T5, etc. (part of the protocol so that
#                  this can be deduced later instead of given on the
#                  command line)
#     num_layers - number of transformer layers
#     hidden_size
#     seq_length
#     num_attention_heads
#     max_position_embeddings
#     tokenizer_type
#     iteration
#     params_dtype
#     bert_binary_head - used only if model_type is BERT
#     previous_tensor_parallel_size - optional
#     previous_pipeline_parallel_size - optional
#     true_vocab_size
#     make_vocab_size_divisible_by
#     consumed_train_samples
#     consumed_valid_samples
# - Messages, in order:
#   {
#     "name": "embeddings"
#     "position embeddings"
#     "word embeddings"
#   }
#   (for each transformer layer):
#   {
#     "name": "transformer layer N"
#     "input layernorm weight"
#     "input layernorm bias"
#     "qkv weight"
#     "qkv bias"
#     "dense weight"
#     "dense bias"
#     "post layernorm weight"
#     "post layernorm bias"
#     "mlp l0 weight"
#     "mlp l0 bias"
#     "mlp l1 weight"
#     "mlp l1 bias"
#   }
#   {
#     "name": "final layer norm"
#     "weight"
#     "bias"
#   }
#   if present (i.e. for BERT):
#   {
#     "name": "pooler"
#     "weight"
#     "bias"
#   }
#   {
#     "name": "lm head"
#     "dense weight"
#     "dense bias"
#     "layernorm weight"
#     "layernorm bias"
#   }
#   {
#     "name": "binary head"
#     "weight"
#     "bias"
#   }
# - "done"
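#
# As an illustration, a minimal loader plugin following this protocol
# could look like the sketch below. This is a hypothetical example:
# read_my_checkpoint is an assumed helper, the option name is made up,
# and a real loader (such as the default megatron loader) would fill
# in every metadata attribute and every tensor listed above.
#
#   import types
#
#   def add_arguments(parser):
#       group = parser.add_argument_group(title='example loader')
#       group.add_argument('--example-option', type=str, default=None,
#                          help='Hypothetical loader-specific option.')
#
#   def load_checkpoint(queue, args):
#       try:
#           state = read_my_checkpoint(args.load_dir)  # hypothetical helper
#       except Exception:
#           queue.put("exit")  # per the protocol: signal failure and stop
#           raise
#
#       # Metadata goes first, as a Namespace.
#       md = types.SimpleNamespace()
#       md.model_type = args.model_type
#       md.num_layers = state['num_layers']
#       # ... set the remaining metadata attributes listed above ...
#       queue.put(md)
#
#       # Then the weight messages, in protocol order.
#       queue.put({
#           "name": "embeddings",
#           "position embeddings": state['position_embeddings'],
#           "word embeddings": state['word_embeddings'],
#       })
#       for i in range(md.num_layers):
#           layer = state['layers'][i]
#           queue.put({
#               "name": f"transformer layer {i}",
#               "input layernorm weight": layer['input_layernorm_weight'],
#               # ... the remaining per-layer tensors listed above ...
#           })
#       queue.put({
#           "name": "final layer norm",
#           "weight": state['final_layernorm_weight'],
#           "bias": state['final_layernorm_bias'],
#       })
#       queue.put("done")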


def load_plugin(plugin_type, name):
    """Import a loader or saver plugin module by name.

    First tries the conventional module name
    checkpoint_{plugin_type}_{name}, then falls back to the bare name.
    """
    module_name = f"checkpoint_{plugin_type}_{name}"
    try:
        plugin = importlib.import_module(module_name)
    except ModuleNotFoundError:
        module_name = name
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError:
            sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
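
# For example, load_plugin('loader', 'megatron') first tries to import
# checkpoint_loader_megatron and, failing that, a module named megatron.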


def main():
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')
    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on the Python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on the Python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    # Parse only the shared arguments first so the plugins can be
    # imported; each plugin then registers its own arguments.
    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    loader.add_arguments(parser)
    saver.add_arguments(parser)
    args = parser.parse_args()

    queue = mp.Queue(maxsize=args.max_queue_size)
print("Starting saver...") | |
saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args)) | |
saver_proc.start() | |
print("Starting loader...") | |
loader.load_checkpoint(queue, args) | |
print("Waiting for saver to complete...") | |
saver_proc.join() | |


if __name__ == '__main__':
    main()
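
# Example invocation (the script name and paths are placeholders):
#   python checkpoint_util.py --model-type GPT \
#       --loader megatron --saver megatron \
#       --load-dir /path/to/source/checkpoint \
#       --save-dir /path/to/target/checkpoint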