"""BERT model.""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from megatron import get_args |
|
from megatron import mpu |
|
from megatron.model.enums import AttnMaskType |
|
from megatron.model.language_model import parallel_lm_logits |
|
from megatron.model.language_model import get_language_model |
|
from megatron.model import LayerNorm |
|
from megatron.model.utils import openai_gelu, erf_gelu |
|
from megatron.model.utils import get_linear_layer |
|
from megatron.model.utils import init_method_normal |
|
from megatron.model.utils import scaled_init_method_normal |
|
from .module import MegatronModule |
|
|
|
def bert_extended_attention_mask(attention_mask):
    # Build a 3D attention mask from the 2D padding mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert to a boolean mask: True marks positions to be masked out.
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask
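
# A hedged usage sketch (not part of the original module): given a padding mask of
# shape [batch, seq] with 1 for real tokens and 0 for padding, the helper above
# yields a boolean [batch, 1, seq, seq] mask in which True marks positions that
# attention should ignore:
#
#   mask = torch.tensor([[1, 1, 1, 0]])        # [b=1, s=4]
#   ext = bert_extended_attention_mask(mask)   # ext.shape == (1, 1, 4, 4)
#   ext[0, 0, 0]                               # tensor([False, False, False, True])
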
def bert_position_ids(token_ids):
    # Create position ids [0, 1, ..., seq_length - 1] for every sequence in the batch.
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids
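
# A hedged example (illustrative only): for token ids of shape [batch, seq],
# bert_position_ids simply counts positions per sequence, e.g.:
#
#   ids = torch.tensor([[5, 6, 7, 8]])
#   bert_position_ids(ids)   # tensor([[0, 1, 2, 3]])
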
class BertLMHead(MegatronModule):
    """Masked LM head for Bert.

    Arguments:
        mpu_vocab_size: size of the vocabulary partition on this tensor-parallel rank.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: epsilon used by the layer norm
        parallel_output: whether the output logits stay split across tensor-parallel
            ranks (True) or are gathered on every rank (False).
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):
        super(BertLMHead, self).__init__()

        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)

        self.layernorm = LayerNorm(hidden_size,
                                   eps=layernorm_epsilon,
                                   sequence_parallel=args.sequence_parallel)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output
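
# A hedged sketch (hypothetical sizes, not part of the original module): the head
# is normally built and driven by GlmModel below, but in isolation it would look
# roughly like this, assuming Megatron's global args and tensor-parallel state are
# already initialized:
#
#   head = BertLMHead(mpu_vocab_size=32000, hidden_size=1024,
#                     init_method=init_method_normal(0.02),
#                     layernorm_epsilon=1e-5, parallel_output=True)
#   # hidden_states: [s, b, h]; word_embeddings_weight: [vocab_partition, h]
#   logits = head(hidden_states, word_embeddings_weight)
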
class GlmModel(MegatronModule):
    """GLM language model: a Bert-style encoder-decoder with an optional binary head."""

    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super(GlmModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.cal_encoder_loss = True
        self.cal_decoder_loss = True
        self.cal_sent_loss = True

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            add_encoder=self.add_encoder,
            add_decoder=self.add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self,
                encoder_input_ids,
                encoder_labels,
                decoder_input_ids,
                decoder_labels,
                encoder_attn_mask,
                decoder_attn_mask,
                encoder_decoder_attn_mask,
                tokentype_ids=None,
                lm_labels=None,
                sentence_labels=None,
                enc_hidden_states=None,
                ):

        # Position ids and extended (4D, boolean) attention masks for both stacks.
        encoder_position_ids = bert_position_ids(encoder_input_ids)
        decoder_position_ids = bert_position_ids(decoder_input_ids)
        extended_encoder_attn_mask = bert_extended_attention_mask(encoder_attn_mask)
        extended_decoder_attn_mask = bert_extended_attention_mask(decoder_attn_mask)
        extended_encoder_decoder_attn_mask = bert_extended_attention_mask(encoder_decoder_attn_mask)

        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        extended_encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        extended_decoder_attn_mask,
                                        extended_encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)

        encoder_outputs, decoder_outputs, pooled_outputs = lm_output

        return self.post_language_model_processing(encoder_input_ids,
                                                   encoder_outputs,
                                                   encoder_labels,
                                                   pooled_outputs,
                                                   self.lm_head,
                                                   self.binary_head,
                                                   self.word_embeddings_weight(),
                                                   self.fp16_lm_cross_entropy,
                                                   decoder_input_ids,
                                                   decoder_outputs,
                                                   decoder_labels,
                                                   decoder_attn_mask,
                                                   sentence_labels)
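
    # A hedged usage note (shapes are assumptions, not enforced here): the input ids,
    # labels and attention masks are expected as [batch, seq] tensors, e.g.:
    #
    #   out = model(enc_ids, enc_labels, dec_ids, dec_labels,
    #               enc_mask, dec_mask, enc_dec_mask,
    #               tokentype_ids=tokentypes, sentence_labels=nsp_labels)
    #   encoder_loss, decoder_loss, sent_loss, describe_tensor = out
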
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save the separate word_embeddings copy held by the last pipeline stage,
        # which does not own the input embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load the word_embeddings copy on the last pipeline stage.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)

    def post_language_model_processing(self, encoder_inputs,
                                       encoder_outputs,
                                       encoder_labels,
                                       pooled_outputs,
                                       lm_head,
                                       binary_head,
                                       logit_weights,
                                       fp16_lm_cross_entropy,
                                       decoder_inputs,
                                       decoder_outputs,
                                       decoder_labels,
                                       attention_mask,
                                       sentence_labels
                                       ):
        # Collect a few example tensors for debugging/inspection.
        describe_tensor = {
            "attention_mask": attention_mask[0]
        }

        # Binary (sentence-level) classification loss from the pooled output.
        sent_loss = None
        if self.cal_sent_loss:
            binary_logits = binary_head(pooled_outputs)
            sent_loss = F.cross_entropy(binary_logits.view(-1, 2).float(),
                                        sentence_labels.view(-1),
                                        ignore_index=-1)

        # Masked-LM loss over the encoder outputs.
        encoder_loss = None
        if self.cal_encoder_loss:
            encoder_outputs_logits_ = lm_head(encoder_outputs, logit_weights)
            if fp16_lm_cross_entropy:
                assert encoder_outputs_logits_.dtype == torch.half
                encoder_outputs_logits = encoder_outputs_logits_
            else:
                encoder_outputs_logits = encoder_outputs_logits_.float()
            encoder_loss = mpu.vocab_parallel_cross_entropy(
                encoder_outputs_logits, encoder_labels.transpose(0, 1).contiguous())
            describe_tensor["encoder_inputs"] = encoder_inputs[0]
            describe_tensor["encoder_outputs"] = encoder_outputs_logits[0].transpose(0, 1)
            describe_tensor["encoder_labels"] = encoder_labels[0]

        # LM loss over the decoder outputs.
        decoder_loss = None
        if self.cal_decoder_loss:
            decoder_outputs_logits_ = decoder_outputs
            if fp16_lm_cross_entropy:
                assert decoder_outputs_logits_.dtype == torch.half
                decoder_outputs_logits = decoder_outputs_logits_
            else:
                decoder_outputs_logits = decoder_outputs_logits_.float()
            decoder_loss = mpu.vocab_parallel_cross_entropy(
                decoder_outputs_logits, decoder_labels.transpose(0, 1).contiguous())
            describe_tensor["encoder_inputs"] = encoder_inputs[0]
            describe_tensor["decoder_inputs"] = decoder_inputs[0]
            describe_tensor["decoder_outputs"] = decoder_outputs_logits[0].transpose(0, 1)
            describe_tensor["decoder_labels"] = decoder_labels[0]

        return encoder_loss, decoder_loss, sent_loss, describe_tensor
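
# A minimal, hedged sketch of how a training step might reduce the returned losses
# (illustrative only; a real Megatron pretraining script applies loss masks and
# averages across data-parallel ranks): encoder_loss and decoder_loss come from
# mpu.vocab_parallel_cross_entropy and are per-token, while sent_loss is already a
# scalar from F.cross_entropy.
#
#   encoder_loss, decoder_loss, sent_loss, _ = model(...)
#   lm_loss = encoder_loss.mean() + decoder_loss.mean()
#   total_loss = lm_loss + sent_loss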