# Source: yuyan-10b/megatron/model/glm_model.py
# (uploaded by Shawn001, "Upload 131 files", commit 23bd7af, 13.1 kB)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLM model (encoder-decoder variant adapted from Megatron's BERT model)."""
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def bert_extended_attention_mask(attention_mask):
    """Expand a per-token mask into a boolean pairwise attention mask.

    The pairwise mask is the outer product of ``attention_mask`` with itself,
    so entry (i, j) is non-zero only when both token i and token j are
    non-zero. The result is then binarized and inverted: ``True`` marks pairs
    whose product is below 0.5.

    Args:
        attention_mask: 2D tensor, presumably [b, s] with 1 = keep and
            0 = padding (TODO confirm against the data pipeline).

    Returns:
        Boolean tensor of shape [b, 1, s, s]; ``True`` means "masked out".
    """
    # [b, 1, s] * [b, s, 1] broadcasts to the [b, s, s] outer product.
    pairwise = attention_mask[:, None, :] * attention_mask[:, :, None]
    # Insert a singleton dim -> [b, 1, s, s], then binarize (True = masked).
    return pairwise[:, None] < 0.5
def bert_position_ids(token_ids):
    """Build position ids ``0..s-1`` for every row of ``token_ids``.

    Args:
        token_ids: tensor whose dim 1 is the sequence length (presumably
            [b, s] token ids).

    Returns:
        Long tensor with the same shape and device as ``token_ids`` whose
        rows are ``[0, 1, ..., s - 1]``. It is an ``expand`` view, so all
        rows share storage and must not be modified in place.
    """
    positions = torch.arange(token_ids.size(1), dtype=torch.long,
                             device=token_ids.device)
    # [s] -> [1, s] -> broadcast (view) to token_ids' shape.
    return positions[None, :].expand_as(token_ids)
class BertLMHead(MegatronModule):
    """Masked-LM output head.

    Applies dense -> GeLU -> LayerNorm to the hidden states, then computes
    logits against the supplied word-embedding weight with
    ``parallel_lm_logits`` and adds a learned bias.

    Arguments:
        mpu_vocab_size: vocabulary size held on this model-parallel rank
            (length of the output bias).
        hidden_size: hidden size.
        init_method: init method for the dense layer's weight.
        layernorm_epsilon: epsilon passed to LayerNorm.
        parallel_output: forwarded to ``parallel_lm_logits``; whether the
            output logits stay distributed or not.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):
        super(BertLMHead, self).__init__()
        args = get_args()
        seq_parallel = args.sequence_parallel

        # Registration order (bias, dense, layernorm) is preserved so that
        # parameter order and state-dict key order stay the same.
        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        for param in (self.dense.weight, self.dense.bias):
            setattr(param, 'sequence_parallel', seq_parallel)

        self.layernorm = LayerNorm(hidden_size,
                                   eps=layernorm_epsilon,
                                   sequence_parallel=seq_parallel)

        # Activation: openai_gelu / erf_gelu when requested by args,
        # otherwise PyTorch's built-in gelu.
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu
        else:
            self.gelu = torch.nn.functional.gelu

    def forward(self, hidden_states, word_embeddings_weight):
        """Return LM logits for ``hidden_states`` using ``word_embeddings_weight``."""
        transformed = self.layernorm(self.gelu(self.dense(hidden_states)))
        return parallel_lm_logits(transformed,
                                  word_embeddings_weight,
                                  self.parallel_output,
                                  bias=self.bias)
class GlmModel(MegatronModule):
    """GLM-style encoder-decoder language model built on Megatron's BERT code.

    Wraps the language model returned by ``get_language_model``. When
    ``post_process`` is set, it also owns a masked-LM head (``lm_head``) for
    the encoder outputs and, if ``add_binary_head``, a 2-way sentence
    classification head (``binary_head``) applied to the pooled output.
    """

    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super(GlmModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder
        # Switches for the three loss terms computed in
        # post_language_model_processing().
        self.cal_encoder_loss = True
        self.cal_decoder_loss = True
        self.cal_sent_loss = True

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            add_encoder=self.add_encoder,
            add_decoder=self.add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self,
                encoder_input_ids,
                encoder_labels,
                decoder_input_ids,
                decoder_labels,
                encoder_attn_mask,
                decoder_attn_mask,
                encoder_decoder_attn_mask,
                tokentype_ids=None,
                lm_labels=None,
                sentence_labels=None,
                enc_hidden_states=None,
                ):
        """Run the language model and, when ``post_process``, compute losses.

        ``lm_labels`` is accepted for interface compatibility but unused.

        Returns:
            If ``post_process``: ``(encoder_loss, decoder_loss, sent_loss,
            describe_tensor)`` from ``post_language_model_processing``.
            Otherwise the raw language-model output, unchanged (presumably
            handed to the next pipeline stage).
        """
        encoder_position_ids = bert_position_ids(encoder_input_ids)
        decoder_position_ids = bert_position_ids(decoder_input_ids)
        extended_encoder_attn_mask = bert_extended_attention_mask(encoder_attn_mask)
        extended_decoder_attn_mask = bert_extended_attention_mask(decoder_attn_mask)
        extended_encoder_decoder_attn_mask = bert_extended_attention_mask(
            encoder_decoder_attn_mask)

        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        extended_encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        extended_decoder_attn_mask,
                                        extended_encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)

        # lm_head / binary_head are only created when post_process is set;
        # without this guard such models fail with AttributeError below.
        if not self.post_process:
            return lm_output

        encoder_outputs, decoder_outputs, pooled_outputs = lm_output
        return self.post_language_model_processing(encoder_input_ids,
                                                   encoder_outputs,
                                                   encoder_labels,
                                                   pooled_outputs,
                                                   self.lm_head,
                                                   self.binary_head,
                                                   self.word_embeddings_weight(),
                                                   self.fp16_lm_cross_entropy,
                                                   decoder_input_ids,
                                                   decoder_outputs,
                                                   decoder_labels,
                                                   decoder_attn_mask,
                                                   sentence_labels)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key.

        Returns a dict keyed by submodule key (language model, lm_head,
        binary_head, and - on stages that have post_process but not
        pre_process - the separate word-embedding copy used by the head).
        """
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings held by a head-only stage.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load, mirroring state_dict_for_save_checkpoint()."""
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings held by a head-only stage.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)

    @staticmethod
    def _vocab_parallel_lm_loss(logits, labels, fp16_lm_cross_entropy):
        """Cast ``logits`` as configured and return ``(logits, loss)``.

        With ``fp16_lm_cross_entropy`` the logits must already be half
        precision; otherwise they are upcast to float32. ``labels`` are
        transposed on dims 0/1 before the vocab-parallel cross entropy
        (presumably [b, s] -> [s, b] to match [s, b, v] logits; TODO confirm).
        """
        if fp16_lm_cross_entropy:
            assert logits.dtype == torch.half
        else:
            logits = logits.float()
        loss = mpu.vocab_parallel_cross_entropy(
            logits, labels.transpose(0, 1).contiguous())
        return logits, loss

    def post_language_model_processing(self, encoder_inputs,
                                       encoder_outputs,
                                       encoder_labels,
                                       pooled_outputs,
                                       lm_head,
                                       binary_head,
                                       logit_weights,
                                       fp16_lm_cross_entropy,
                                       decoder_inputs,
                                       decoder_outputs,
                                       decoder_labels,
                                       attention_mask,
                                       sentence_labels
                                       ):
        """Compute the sentence, encoder-LM and decoder-LM losses.

        Each term is computed only when its ``cal_*`` switch is set. The
        sentence term is also skipped when there is no binary head or no
        ``sentence_labels``. A skipped term is returned as ``None``.

        Returns:
            ``(encoder_loss, decoder_loss, sent_loss, describe_tensor)``.
            ``describe_tensor`` holds index-0 slices of the mask, inputs,
            logits and labels, presumably for logging/debugging.
        """
        describe_tensor = {
            "attention_mask": attention_mask[0]
        }

        # binary_head is None when built with add_binary_head=False, and
        # forward() defaults sentence_labels to None; calling into either
        # would crash, so the sentence loss is left as None instead.
        sent_loss = None
        if (self.cal_sent_loss and binary_head is not None
                and sentence_labels is not None):
            binary_logits = binary_head(pooled_outputs)
            sent_loss = F.cross_entropy(binary_logits.view(-1, 2).float(),
                                        sentence_labels.view(-1),
                                        ignore_index=-1)

        encoder_loss = None
        if self.cal_encoder_loss:
            encoder_logits, encoder_loss = self._vocab_parallel_lm_loss(
                lm_head(encoder_outputs, logit_weights), encoder_labels,
                fp16_lm_cross_entropy)
            describe_tensor["encoder_inputs"] = encoder_inputs[0]
            # NOTE(review): if the logits are [s, b, v], [0] selects sequence
            # position 0 across the batch rather than the first sample, which
            # does not line up with encoder_inputs[0]; confirm intent.
            describe_tensor["encoder_outputs"] = encoder_logits[0].transpose(0, 1)
            describe_tensor["encoder_labels"] = encoder_labels[0]

        decoder_loss = None
        if self.cal_decoder_loss:
            # decoder_outputs are used directly as logits (no lm_head call);
            # presumably the language model's decoder already projects to
            # the vocabulary - TODO confirm.
            decoder_logits, decoder_loss = self._vocab_parallel_lm_loss(
                decoder_outputs, decoder_labels, fp16_lm_cross_entropy)
            # Re-set so encoder_inputs is described even if the encoder
            # loss is disabled.
            describe_tensor["encoder_inputs"] = encoder_inputs[0]
            describe_tensor["decoder_inputs"] = decoder_inputs[0]
            describe_tensor["decoder_outputs"] = decoder_logits[0].transpose(0, 1)
            describe_tensor["decoder_labels"] = decoder_labels[0]

        return encoder_loss, decoder_loss, sent_loss, describe_tensor