yuyan-10b / tools /checkpoint_util.py
Shawn001's picture
Upload 21 files
1101a21
raw
history blame contribute delete
No virus
4.77 kB
import argparse
import importlib
import torch.multiprocessing as mp
import os
import sys
# A loader is a Python module with at least two functions:
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and the parsed arguments
# A saver is similar but has save_checkpoint instead of
# load_checkpoint.
# The loader and saver processes are each given the same queue. The loader
# should load the checkpoint and send the weights as messages in the
# following order; the saver should receive them in this order and
# save the checkpoints. A message consists of a Python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full, unsplit model weights.
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
# - Metadata Namespace with the following attributes:
# model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
# num_layers - Number of transformer layers
# hidden_size
# seq_length
# num_attention_heads
# max_position_embeddings
# tokenizer_type
# iteration
# params_dtype
# bert_binary_head - Used only if model_type is BERT
# previous_tensor_parallel_size - Optional
# previous_pipeline_parallel_size - Optional
# true_vocab_size
# make_vocab_size_divisible_by
# consumed_train_samples
# consumed_valid_samples
# messages
# {
# "name": "embeddings"
# "position embeddings"
# "word embeddings"
# }
# (for each transformer layer):
# {
# "name": "transformer layer N"
# "input layernorm weight"
# "input layernorm bias"
# "qkv weight"
# "qkv bias"
# "dense weight"
# "dense bias"
# "post layernorm weight"
# "post layernorm bias"
# "mlp l0 weight"
# "mlp l0 bias"
# "mlp l1 weight"
# "mlp l1 bias"
# }
# {
# "name": "final layer norm"
# "weight"
# "bias"
# }
# if present (i.e. for BERT):
# {
# "name": "pooler"
# "weight"
# "bias"
# }
# {
# "name": "lm head"
# "dense weight"
# "dense bias"
# "layernorm weight"
# "layernorm bias"
# }
# {
# "name": "binary head"
# "weight"
# "bias"
# }
# - "done"
def load_plugin(plugin_type, name):
    """Import and return the loader/saver plugin module called *name*.

    The module ``checkpoint_{plugin_type}_{name}`` is tried first; if it does
    not exist, *name* itself is imported as a module name.

    Args:
        plugin_type: ``'loader'`` or ``'saver'``; used as part of the module
            name prefix and in messages.
        name: Short plugin name (e.g. ``'megatron'``) or a full module name.

    Returns:
        The imported plugin module.

    Raises:
        SystemExit: if neither candidate module exists, or the imported
            module has no ``add_arguments`` attribute.
        ModuleNotFoundError: if a candidate module exists but one of *its own*
            imports is missing. This is re-raised instead of being reported as
            "unable to load", so the real cause is visible.
    """
    for module_name in (f"checkpoint_{plugin_type}_{name}", name):
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            # Fall back only when the candidate itself (or one of its parent
            # packages) is what is missing; a missing dependency inside an
            # existing plugin is a genuine error and must not be masked.
            missing = err.name
            if missing is not None and not (
                    module_name == missing
                    or module_name.startswith(missing + ".")):
                raise
            continue
        break
    else:
        sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
def main():
    """Parse arguments and stream a checkpoint from a loader to a saver plugin.

    The saver's ``save_checkpoint`` runs in a child process and the loader's
    ``load_checkpoint`` runs in this process; both receive the same bounded
    queue and the fully parsed arguments.

    Exits with a non-zero status if the loader raises (the saver child is
    stopped first so the program cannot hang) or if the saver process does
    not finish with exit code 0.
    """
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    # First pass: the plugin-specific options are not registered yet, so
    # tolerate unknown arguments just long enough to learn which plugins to load.
    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    loader.add_arguments(parser)
    saver.add_arguments(parser)

    args = parser.parse_args()

    queue = mp.Queue(maxsize=args.max_queue_size)

    print("Starting saver...")
    saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
    saver_proc.start()

    print("Starting loader...")
    try:
        loader.load_checkpoint(queue, args)
    except BaseException:
        # The saver presumably blocks reading the queue until the stream ends;
        # if the loader dies mid-stream, the non-daemon child would wait forever
        # and the interpreter would hang joining it at exit. Stop it explicitly.
        saver_proc.terminate()
        saver_proc.join()
        # Nothing will drain the queue any more, so don't let its feeder
        # thread block interpreter shutdown while flushing buffered items.
        queue.cancel_join_thread()
        raise

    print("Waiting for saver to complete...")
    saver_proc.join()
    if saver_proc.exitcode != 0:
        # Propagate saver failure instead of silently exiting with status 0.
        sys.exit(f"Saver process failed with exit code {saver_proc.exitcode}.")
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()