Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn committed on
Commit cbfed57
1 Parent(s): d12d71b

Update modeling_hf_nomic_bert.py

Files changed (1)
  1. modeling_hf_nomic_bert.py +145 -149
modeling_hf_nomic_bert.py CHANGED
@@ -3,39 +3,34 @@
3
  # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
  # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
 
 
 
6
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
7
  import os
8
- import logging
 
9
  from functools import partial
10
- from typing import Optional, List, Tuple, Union
11
 
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
  from einops import rearrange, repeat
 
16
  from transformers import GPT2Config, PreTrainedModel
17
  from transformers.models.bert.modeling_bert import (
18
  BaseModelOutputWithPoolingAndCrossAttentions,
19
  MaskedLMOutput,
20
- SequenceClassifierOutput
21
- )
22
-
23
- import re
24
- from collections import OrderedDict
25
- from safetensors.torch import load_file as safe_load_file
26
- from transformers.utils import (
27
- SAFE_WEIGHTS_INDEX_NAME,
28
- SAFE_WEIGHTS_NAME,
29
- WEIGHTS_INDEX_NAME,
30
- WEIGHTS_NAME,
31
  )
 
32
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
33
 
34
-
35
  from .configuration_hf_nomic_bert import NomicBertConfig
36
 
37
  logger = logging.getLogger(__name__)
38
 
 
39
  # adapted from flash attention, added safe serialization option for hf models
40
  def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
41
  # If not fp32, then we don't want to load directly to the GPU
@@ -50,18 +45,12 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
50
  safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
51
 
52
  if os.path.isfile(weights_path):
53
- resolved_archive_file = cached_file(
54
- model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
55
- )
56
  elif os.path.isfile(weights_index_path):
57
- resolved_archive_file = cached_file(
58
- model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
59
- )
60
  is_sharded = True
61
  elif os.path.isfile(safe_weights_path):
62
- resolved_archive_file = cached_file(
63
- model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
64
- )
65
  load_safe = True
66
  elif os.path.isfile(safe_weights_index_path):
67
  resolved_archive_file = cached_file(
@@ -74,8 +63,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
74
  resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
75
  if resolved_archive_file is None:
76
  weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
77
- resolved_archive_file = cached_file(model_name, weight_index,
78
- _raise_exceptions_for_missing_entries=False)
79
  if resolved_archive_file is not None:
80
  is_sharded = True
81
 
@@ -92,9 +80,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
92
  if is_sharded:
93
  # resolved_archive_file becomes a list of files that point to the different
94
  # checkpoint shards in this case.
95
- resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
96
- model_name, resolved_archive_file
97
- )
98
  state_dict = {}
99
  for sharded_file in resolved_archive_file:
100
  state_dict.update(loader(sharded_file))
@@ -106,7 +92,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
106
  state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
107
  return state_dict
108
 
109
-
110
  def filter_shapes(state_dict, model):
111
  """
112
  Filters the state dict to match the current model shape.
@@ -118,11 +104,18 @@ def filter_shapes(state_dict, model):
118
  filtered_state_dict[key] = value
119
  return filtered_state_dict
120
 
121
-
122
- def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
123
  """
124
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
125
  """
 
126
  def add_bert_prefix(key):
127
  # prepend bert. to the key
128
  if key.startswith("bert.") or key.startswith("cls."):
@@ -130,7 +123,7 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
130
  return f"bert.{key}"
131
 
132
  state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
133
-
134
  # LayerNorm
135
  def key_mapping_ln_gamma_beta(key):
136
  key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -195,9 +188,7 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
195
  bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
196
  bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
197
  if not (last_layer_subset and d == config.num_hidden_layers - 1):
198
- state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat(
199
- [Wq, Wk, Wv], dim=0
200
- )
201
  state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
202
  else:
203
  state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
@@ -217,7 +208,6 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
217
  def key_mapping_decoder_bias(key):
218
  return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
219
 
220
-
221
  # remove nsp weights, we don't use
222
  state_dict.pop("cls.seq_relationship.weight", None)
223
  state_dict.pop("cls.seq_relationship.bias", None)
@@ -226,12 +216,14 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
226
  state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
227
 
228
  if remove_cls_weights:
229
- cls_weights = ["cls.predictions.decoder.bias",
230
- "cls.predictions.transform.dense.weight",
231
- "cls.predictions.transform.dense.bias",
232
- "cls.predictions.transform.layer_norm.weight",
233
- "cls.predictions.transform.layer_norm.bias",
234
- "cls.predictions.decoder.weight"]
235
  for weight in cls_weights:
236
  state_dict.pop(weight, None)
237
 
@@ -257,20 +249,21 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
257
  )
258
 
259
  if add_pooling_layer is False:
260
- pooler_weights = ["bert.pooler.dense.weight",
261
- "bert.pooler.dense.bias",
262
- ]
 
263
  for key in pooler_weights:
264
  state_dict.pop(key, None)
265
 
266
  if remove_bert:
 
267
  def remove_bert_prefix(key):
268
  key = re.sub(r"^bert.", "", key)
269
  return key
270
 
271
  state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
272
 
273
-
274
  return state_dict
275
 
276
 
@@ -278,6 +271,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
278
  """An abstract class to handle weights initialization and
279
  a simple interface for dowloading and loading pretrained models.
280
  """
 
281
  config_class = NomicBertConfig
282
  base_model_prefix = "model"
283
  supports_gradient_checkpointing = True
@@ -317,14 +311,13 @@ class NomicBertPreTrainedModel(PreTrainedModel):
317
  if config is None:
318
  config = cls.config_class.from_pretrained(model_name)
319
  remove_cls = cls != NomicBertForPreTraining
320
- remove_bert_prefix = cls != NomicBertForPreTraining
321
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
322
  num_labels = kwargs.pop("num_labels", None)
323
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
324
  if rotary_scaling_factor:
325
  config.rotary_scaling_factor = rotary_scaling_factor
326
- else:
327
- config.rotary_scaling_factor = None
328
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
329
  config.n_positions = 2048
330
  if num_labels:
@@ -341,26 +334,34 @@ class NomicBertPreTrainedModel(PreTrainedModel):
341
  # Assuming we know what we're doing when loading from disk
342
  # Prob a bad assumption but i'm tired and want to train this asap
343
  if os.path.exists(model_name):
344
- state_dict = torch.load(f"{model_name}/pytorch_model.bin")
345
  if ignore_mismatched_shapes:
346
  state_dict = filter_shapes(state_dict, model)
347
  load_return = model.load_state_dict(state_dict, strict=False)
348
  else:
349
  # TODO: can probably check config class and see if we need to remap from a bert model
350
- state_dict = state_dict_from_pretrained(model_name)
351
- state_dict = remap_bert_state_dict(state_dict,
352
- config,
353
- remove_bert=remove_bert_prefix,
354
- remove_cls_weights=remove_cls,
355
- add_pooling_layer=getattr(config, "add_pooling_layer", False)
356
- )
357
  if ignore_mismatched_shapes:
358
  state_dict = filter_shapes(state_dict, model)
359
 
360
- load_return = model.load_state_dict(
361
- state_dict,
362
- strict=True
363
- )
364
  logger.warning(load_return)
365
  return model
366
 
@@ -380,25 +381,21 @@ def _init_weights(module, initializer_range=0.02):
380
  if module.padding_idx is not None:
381
  nn.init.zeros_(module.weight[module.padding_idx])
382
 
383
-
384
  class NomicBertEmbeddings(nn.Module):
385
- def __init__(
386
- self,
387
- config
388
- ):
389
  """
390
  If max_position_embeddings <= 0, there's no position embeddings
391
  If type_vocab_size <= 0, there's no token type embeddings
392
  """
393
  super().__init__()
394
- self.word_embeddings = nn.Embedding(
395
- config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
396
- )
397
  self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
398
  self.type_vocab_size = config.type_vocab_size
399
  if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
400
  self.position_embeddings = nn.Embedding(
401
- config.max_position_embeddings, config.hidden_size,
 
402
  )
403
  if self.type_vocab_size > 0:
404
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
@@ -425,6 +422,7 @@ class NomicBertEmbeddings(nn.Module):
425
  embeddings = embeddings + position_embeddings
426
  return embeddings
427
 
 
428
  class NomicBertMLP(nn.Module):
429
  def __init__(
430
  self,
@@ -442,11 +440,7 @@ class NomicBertMLP(nn.Module):
442
  hidden_features = hidden_features if hidden_features is not None else in_features * 4
443
  self.return_residual = return_residual
444
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
445
- approximate = (
446
- "tanh"
447
- if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
448
- else "none"
449
- )
450
  self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
451
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
452
 
@@ -456,7 +450,7 @@ class NomicBertMLP(nn.Module):
456
  y = self.fc2(y)
457
  return y if not self.return_residual else (y, x)
458
 
459
-
460
  class NomciBertGatedMLP(nn.Module):
461
  def __init__(
462
  self,
@@ -474,9 +468,7 @@ class NomciBertGatedMLP(nn.Module):
474
  ):
475
  super().__init__()
476
  out_features = out_features if out_features is not None else in_features
477
- hidden_features = (
478
- hidden_features if hidden_features is not None else int(8 * in_features / 3)
479
- )
480
  hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
481
  self.return_residual = return_residual
482
 
@@ -513,8 +505,8 @@ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
513
  ro_dim = cos.shape[-1] * 2
514
  assert ro_dim <= x.shape[-1]
515
  cos, sin = (
516
- cos[offset: offset + x.shape[1]],
517
- sin[offset: offset + x.shape[1]],
518
  )
519
  cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
520
  sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
@@ -571,10 +563,7 @@ class NomicBertRotaryEmbedding(nn.Module):
571
  self._sin_k_cached = None
572
 
573
  def _compute_inv_freq(self, device=None):
574
- return 1.0 / (
575
- self.base
576
- ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
577
- )
578
 
579
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
580
  # Reset the tables if the sequence length has changed,
@@ -646,14 +635,10 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
646
  self.rotary_scaling_factor = rotary_scaling_factor
647
  self.max_position_embeddings = max_position_embeddings
648
 
649
-
650
  def _compute_inv_freq(self, base=None, device=None):
651
  if base is None:
652
  base = self.base
653
- return 1.0 / (
654
- base
655
- ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
656
- )
657
 
658
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
659
  # Reset the tables if the sequence length has changed,
@@ -704,8 +689,7 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
704
  self._sin_cached = torch.sin(freqs).to(dtype)
705
  else:
706
  power = (
707
- torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
708
- - seqlen // 2
709
  ) / self.scale_base
710
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
711
  # We want the multiplication by scale to happen in fp32
@@ -714,6 +698,7 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
714
  self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
715
  self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
716
 
 
717
  class NomicBertAttention(nn.Module):
718
  """Multi-head self-attention and cross-attention"""
719
 
@@ -754,8 +739,8 @@ class NomicBertAttention(nn.Module):
754
  scale_base=config.rotary_emb_scale_base,
755
  interleaved=config.rotary_emb_interleaved,
756
  rotary_scaling_factor=config.rotary_scaling_factor,
757
- max_position_embeddings=config.n_positions,
758
- )
759
  else:
760
  self.rotary_emb = NomicBertRotaryEmbedding(
761
  dim=self.rotary_emb_dim,
@@ -826,7 +811,7 @@ class NomicBertAttention(nn.Module):
826
  attn_output = self.out_proj(attn_output)
827
 
828
  return attn_output
829
-
830
 
831
  class NomicBertBlock(nn.Module):
832
  def __init__(
@@ -836,17 +821,31 @@ class NomicBertBlock(nn.Module):
836
  super().__init__()
837
  self.prenorm = config.prenorm
838
  self.fused_dropout_add_ln = config.fused_dropout_add_ln
839
-
840
- self.attn = NomicBertAttention(config)
841
  activation = (
842
- F.sigmoid
843
- if config.activation_function == "glu"
844
- else (F.silu if config.activation_function == "swiglu" else F.gelu)
845
  )
846
  if config.activation_function in ["glu", "swiglu", "geglu"]:
847
- self.mlp = NomciBertGatedMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
848
  else:
849
- self.mlp = NomicBertMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
850
 
851
  self.dropout1 = nn.Dropout(config.resid_pdrop)
852
  self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -880,7 +879,13 @@ class NomicBertBlock(nn.Module):
880
  dropped = self.dropout1(hidden_states)
881
  residual = (dropped + residual) if residual is not None else dropped
882
  hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
883
- hidden_states = self.attn(hidden_states, attention_mask=attention_mask, is_padded_inputs=is_padded_inputs, cu_seqlens=cu_seqlens, max_seq_len=max_seq_len)
884
 
885
  dropped = self.dropout2(hidden_states)
886
  residual = (dropped + residual) if residual is not None else dropped
@@ -890,36 +895,29 @@ class NomicBertBlock(nn.Module):
890
  return hidden_states, None, residual
891
  else:
892
  assert residual is None
893
- attn_outputs = self.attn(hidden_states,
894
- attention_mask=attention_mask,
895
- is_padded_inputs=is_padded_inputs,
896
- cu_seqlens=cu_seqlens,
897
- max_seq_len=max_seq_len)
898
- hidden_states = self.norm1(
899
- (self.dropout1(attn_outputs) + hidden_states).to(
900
- dtype=self.norm1.weight.dtype
901
- )
902
  )
 
903
  mlp_out = self.mlp(hidden_states)
904
 
905
- hidden_states = self.norm2(
906
- (self.dropout2(mlp_out) + hidden_states).to(
907
- dtype=self.norm2.weight.dtype
908
- )
909
- )
910
  return hidden_states, None, None
911
 
912
 
913
  class NomicBertEncoder(nn.Module):
914
  def __init__(self, config: GPT2Config):
915
  super().__init__()
916
- self.layers = nn.ModuleList(
917
- [NomicBertBlock(config) for _ in range(config.n_layer)]
918
- )
919
  self.gradient_checkpointing = False
920
  self.config = config
921
 
922
- def forward(self,
 
923
  hidden_states: torch.LongTensor = None,
924
  attention_mask: Optional[torch.Tensor] = None,
925
  position_ids: Optional[torch.LongTensor] = None,
@@ -929,8 +927,8 @@ class NomicBertEncoder(nn.Module):
929
  output_attentions: Optional[bool] = None,
930
  output_hidden_states: Optional[bool] = None,
931
  return_dict: Optional[bool] = None,
932
- is_padded_inputs: Optional[bool] = True,):
933
-
934
  """If subset_mask is not None, we only want output for the subset of the sequence.
935
  This means that we only compute the last layer output for these tokens.
936
  subset_mask: (batch, seqlen), dtype=torch.bool
@@ -938,7 +936,6 @@ class NomicBertEncoder(nn.Module):
938
  hidden_states2 = None
939
  residual = None
940
 
941
-
942
  for _, layer in enumerate(self.layers):
943
  if self.gradient_checkpointing and self.training:
944
 
@@ -998,11 +995,7 @@ class NomicBertPredictionHeadTransform(nn.Module):
998
  def __init__(self, config):
999
  super().__init__()
1000
  self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
1001
- approximate = (
1002
- "tanh"
1003
- if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
1004
- else "none"
1005
- )
1006
  if config.activation_function == "swiglu":
1007
  self.transform_act_fn = F.silu
1008
  else:
@@ -1047,15 +1040,19 @@ class NomicBertModel(NomicBertPreTrainedModel):
1047
  super().__init__(config)
1048
  self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1049
  if config.vocab_size % self.pad_vocab_size_multiple != 0:
1050
- config.vocab_size += self.pad_vocab_size_multiple - (
1051
- config.vocab_size % self.pad_vocab_size_multiple
1052
- )
1053
-
1054
- assert config.activation_function in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu", "geglu", "glu"]
1055
-
1056
- self.embeddings = NomicBertEmbeddings(
1057
- config
1058
- )
1059
  self.emb_drop = nn.Dropout(config.resid_pdrop)
1060
  self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1061
  self.encoder = NomicBertEncoder(config)
@@ -1069,22 +1066,23 @@ class NomicBertModel(NomicBertPreTrainedModel):
1069
  position_ids=None,
1070
  token_type_ids=None,
1071
  attention_mask=None,
1072
  ):
1073
  if token_type_ids is None:
1074
  token_type_ids = torch.zeros_like(input_ids)
1075
- hidden_states = self.embeddings(
1076
- input_ids, position_ids=position_ids, token_type_ids=token_type_ids
1077
- )
1078
  hidden_states = self.emb_ln(hidden_states)
1079
  hidden_states = self.emb_drop(hidden_states)
1080
 
1081
  attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1082
- sequence_output = self.encoder(
1083
- hidden_states, attention_mask=attention_mask
1084
- )
1085
 
1086
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1087
 
1088
  return BaseModelOutputWithPoolingAndCrossAttentions(
1089
  last_hidden_state=sequence_output,
1090
  pooler_output=pooled_output,
@@ -1151,10 +1149,10 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
1151
  loss=total_loss,
1152
  logits=prediction_scores,
1153
  hidden_states=outputs.hidden_states,
1154
- attentions=None,
1155
  )
1156
 
1157
-
1158
  class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1159
  def __init__(self, config):
1160
  super().__init__(config)
@@ -1162,9 +1160,7 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1162
  self.config = config
1163
 
1164
  self.bert = NomicBertModel(config)
1165
- classifier_dropout = (
1166
- getattr(config, "classifier_dropout", config.embd_pdrop)
1167
- )
1168
  self.dropout = nn.Dropout(classifier_dropout)
1169
  self.classifier = nn.Linear(config.n_embd, config.num_labels)
1170
 
 
3
  # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
  # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
 
6
+ import logging
7
+
8
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
  import os
10
+ import re
11
+ from collections import OrderedDict
12
  from functools import partial
13
+ from typing import List, Optional, Tuple, Union
14
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
18
  from einops import rearrange, repeat
19
+ from safetensors.torch import load_file as safe_load_file
20
  from transformers import GPT2Config, PreTrainedModel
21
  from transformers.models.bert.modeling_bert import (
22
  BaseModelOutputWithPoolingAndCrossAttentions,
23
  MaskedLMOutput,
24
+ SequenceClassifierOutput,
25
  )
26
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
27
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
28
 
 
29
  from .configuration_hf_nomic_bert import NomicBertConfig
30
 
31
  logger = logging.getLogger(__name__)
32
 
33
+
34
  # adapted from flash attention, added safe serialization option for hf models
35
  def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
36
  # If not fp32, then we don't want to load directly to the GPU
 
45
  safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
46
 
47
  if os.path.isfile(weights_path):
48
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
49
  elif os.path.isfile(weights_index_path):
50
+ resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
51
  is_sharded = True
52
  elif os.path.isfile(safe_weights_path):
53
+ resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
54
  load_safe = True
55
  elif os.path.isfile(safe_weights_index_path):
56
  resolved_archive_file = cached_file(
 
63
  resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
64
  if resolved_archive_file is None:
65
  weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
66
+ resolved_archive_file = cached_file(model_name, weight_index, _raise_exceptions_for_missing_entries=False)
 
67
  if resolved_archive_file is not None:
68
  is_sharded = True
69
 
 
80
  if is_sharded:
81
  # resolved_archive_file becomes a list of files that point to the different
82
  # checkpoint shards in this case.
83
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
84
  state_dict = {}
85
  for sharded_file in resolved_archive_file:
86
  state_dict.update(loader(sharded_file))
 
92
  state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
93
  return state_dict
94
 
95
+
96
  def filter_shapes(state_dict, model):
97
  """
98
  Filters the state dict to match the current model shape.
 
104
  filtered_state_dict[key] = value
105
  return filtered_state_dict
106
 
107
+
108
+ def remap_bert_state_dict(
109
+ state_dict,
110
+ config,
111
+ remove_bert=False,
112
+ remove_cls_weights=False,
113
+ add_pooling_layer=False,
114
+ ):
115
  """
116
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
117
  """
118
+
119
  def add_bert_prefix(key):
120
  # prepend bert. to the key
121
  if key.startswith("bert.") or key.startswith("cls."):
 
123
  return f"bert.{key}"
124
 
125
  state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
126
+
127
  # LayerNorm
128
  def key_mapping_ln_gamma_beta(key):
129
  key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
 
188
  bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
189
  bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
190
  if not (last_layer_subset and d == config.num_hidden_layers - 1):
191
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
192
  state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
193
  else:
194
  state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
 
208
  def key_mapping_decoder_bias(key):
209
  return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
210
 
 
211
  # remove nsp weights, we don't use
212
  state_dict.pop("cls.seq_relationship.weight", None)
213
  state_dict.pop("cls.seq_relationship.bias", None)
 
216
  state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
217
 
218
  if remove_cls_weights:
219
+ cls_weights = [
220
+ "cls.predictions.decoder.bias",
221
+ "cls.predictions.transform.dense.weight",
222
+ "cls.predictions.transform.dense.bias",
223
+ "cls.predictions.transform.layer_norm.weight",
224
+ "cls.predictions.transform.layer_norm.bias",
225
+ "cls.predictions.decoder.weight",
226
+ ]
227
  for weight in cls_weights:
228
  state_dict.pop(weight, None)
229
 
 
249
  )
250
 
251
  if add_pooling_layer is False:
252
+ pooler_weights = [
253
+ "bert.pooler.dense.weight",
254
+ "bert.pooler.dense.bias",
255
+ ]
256
  for key in pooler_weights:
257
  state_dict.pop(key, None)
258
 
259
  if remove_bert:
260
+
261
  def remove_bert_prefix(key):
262
  key = re.sub(r"^bert.", "", key)
263
  return key
264
 
265
  state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
266
 
 
267
  return state_dict
268
 
269
 
 
271
  """An abstract class to handle weights initialization and
272
  a simple interface for dowloading and loading pretrained models.
273
  """
274
+
275
  config_class = NomicBertConfig
276
  base_model_prefix = "model"
277
  supports_gradient_checkpointing = True
 
311
  if config is None:
312
  config = cls.config_class.from_pretrained(model_name)
313
  remove_cls = cls != NomicBertForPreTraining
314
+ remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
315
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
316
  num_labels = kwargs.pop("num_labels", None)
317
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
318
  if rotary_scaling_factor:
319
  config.rotary_scaling_factor = rotary_scaling_factor
320
+
 
321
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
322
  config.n_positions = 2048
323
  if num_labels:
 
334
  # Assuming we know what we're doing when loading from disk
335
  # Prob a bad assumption but i'm tired and want to train this asap
336
  if os.path.exists(model_name):
337
+ model_path = f"{model_name}/pytorch_model.bin"
338
+ if os.path.exists(model_path):
339
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
340
+ else:
341
+ model_path = f"{model_name}/model.safetensors"
342
+ if not os.path.exists(model_path):
343
+ raise ValueError(f"Model path {model_path} not found")
344
+ state_dict = safe_load_file(model_path)
345
+
346
  if ignore_mismatched_shapes:
347
  state_dict = filter_shapes(state_dict, model)
348
  load_return = model.load_state_dict(state_dict, strict=False)
349
  else:
350
  # TODO: can probably check config class and see if we need to remap from a bert model
351
+ state_dict = state_dict_from_pretrained(
352
+ model_name, safe_serialization=kwargs.get("safe_serialization", False)
353
+ )
354
+ state_dict = remap_bert_state_dict(
355
+ state_dict,
356
+ config,
357
+ remove_bert=remove_bert_prefix,
358
+ remove_cls_weights=remove_cls,
359
+ add_pooling_layer=getattr(config, "add_pooling_layer", False),
360
+ )
361
  if ignore_mismatched_shapes:
362
  state_dict = filter_shapes(state_dict, model)
363
 
364
+ load_return = model.load_state_dict(state_dict, strict=True)
365
  logger.warning(load_return)
366
  return model
367
 
 
381
  if module.padding_idx is not None:
382
  nn.init.zeros_(module.weight[module.padding_idx])
383
 
384
+
385
  class NomicBertEmbeddings(nn.Module):
386
+ def __init__(self, config):
387
  """
388
  If max_position_embeddings <= 0, there's no position embeddings
389
  If type_vocab_size <= 0, there's no token type embeddings
390
  """
391
  super().__init__()
392
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
393
  self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
394
  self.type_vocab_size = config.type_vocab_size
395
  if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
396
  self.position_embeddings = nn.Embedding(
397
+ config.max_position_embeddings,
398
+ config.hidden_size,
399
  )
400
  if self.type_vocab_size > 0:
401
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
 
422
  embeddings = embeddings + position_embeddings
423
  return embeddings
424
 
425
+
426
  class NomicBertMLP(nn.Module):
427
  def __init__(
428
  self,
 
440
  hidden_features = hidden_features if hidden_features is not None else in_features * 4
441
  self.return_residual = return_residual
442
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
443
+ approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
444
  self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
445
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
446
 
 
450
  y = self.fc2(y)
451
  return y if not self.return_residual else (y, x)
452
 
453
+
454
  class NomciBertGatedMLP(nn.Module):
455
  def __init__(
456
  self,
 
468
  ):
469
  super().__init__()
470
  out_features = out_features if out_features is not None else in_features
471
+ hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
472
  hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
473
  self.return_residual = return_residual
474
 
 
505
  ro_dim = cos.shape[-1] * 2
506
  assert ro_dim <= x.shape[-1]
507
  cos, sin = (
508
+ cos[offset : offset + x.shape[1]],
509
+ sin[offset : offset + x.shape[1]],
510
  )
511
  cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
512
  sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
 
563
  self._sin_k_cached = None
564
 
565
  def _compute_inv_freq(self, device=None):
566
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
567
 
568
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
569
  # Reset the tables if the sequence length has changed,
 
635
  self.rotary_scaling_factor = rotary_scaling_factor
636
  self.max_position_embeddings = max_position_embeddings
637
 
 
638
  def _compute_inv_freq(self, base=None, device=None):
639
  if base is None:
640
  base = self.base
641
+ return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
642
 
643
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
644
  # Reset the tables if the sequence length has changed,
 
689
  self._sin_cached = torch.sin(freqs).to(dtype)
690
  else:
691
  power = (
692
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
 
693
  ) / self.scale_base
694
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
695
  # We want the multiplication by scale to happen in fp32
 
698
  self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
699
  self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
700
 
701
+
702
  class NomicBertAttention(nn.Module):
703
  """Multi-head self-attention and cross-attention"""
704
 
 
739
  scale_base=config.rotary_emb_scale_base,
740
  interleaved=config.rotary_emb_interleaved,
741
  rotary_scaling_factor=config.rotary_scaling_factor,
742
+ max_position_embeddings=config.max_trained_positions,
743
+ )
744
  else:
745
  self.rotary_emb = NomicBertRotaryEmbedding(
746
  dim=self.rotary_emb_dim,
 
811
  attn_output = self.out_proj(attn_output)
812
 
813
  return attn_output
814
+
815
 
816
  class NomicBertBlock(nn.Module):
817
  def __init__(
 
821
  super().__init__()
822
  self.prenorm = config.prenorm
823
  self.fused_dropout_add_ln = config.fused_dropout_add_ln
824
+
825
+ self.attn = NomicBertAttention(config)
826
  activation = (
827
+ F.sigmoid
828
+ if config.activation_function == "glu"
829
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
830
  )
831
  if config.activation_function in ["glu", "swiglu", "geglu"]:
832
+ self.mlp = NomciBertGatedMLP(
833
+ config.n_embd,
834
+ hidden_features=config.n_inner,
835
+ bias1=config.mlp_fc1_bias,
836
+ bias2=config.mlp_fc2_bias,
837
+ activation=activation,
838
+ fused_bias_fc=config.fused_bias_fc,
839
+ )
840
  else:
841
+ self.mlp = NomicBertMLP(
842
+ config.n_embd,
843
+ hidden_features=config.n_inner,
844
+ bias1=config.mlp_fc1_bias,
845
+ bias2=config.mlp_fc2_bias,
846
+ activation=activation,
847
+ fused_bias_fc=config.fused_bias_fc,
848
+ )
849
 
850
  self.dropout1 = nn.Dropout(config.resid_pdrop)
851
  self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
879
  dropped = self.dropout1(hidden_states)
880
  residual = (dropped + residual) if residual is not None else dropped
881
  hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
882
+ hidden_states = self.attn(
883
+ hidden_states,
884
+ attention_mask=attention_mask,
885
+ is_padded_inputs=is_padded_inputs,
886
+ cu_seqlens=cu_seqlens,
887
+ max_seq_len=max_seq_len,
888
+ )
889
 
890
  dropped = self.dropout2(hidden_states)
891
  residual = (dropped + residual) if residual is not None else dropped
 
895
  return hidden_states, None, residual
896
  else:
897
  assert residual is None
898
+ attn_outputs = self.attn(
899
+ hidden_states,
900
+ attention_mask=attention_mask,
901
+ is_padded_inputs=is_padded_inputs,
902
+ cu_seqlens=cu_seqlens,
903
+ max_seq_len=max_seq_len,
904
  )
905
+ hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
906
  mlp_out = self.mlp(hidden_states)
907
 
908
+ hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
909
  return hidden_states, None, None
910
 
911
 
912
  class NomicBertEncoder(nn.Module):
913
  def __init__(self, config: GPT2Config):
914
  super().__init__()
915
+ self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
916
  self.gradient_checkpointing = False
917
  self.config = config
918
 
919
+ def forward(
920
+ self,
921
  hidden_states: torch.LongTensor = None,
922
  attention_mask: Optional[torch.Tensor] = None,
923
  position_ids: Optional[torch.LongTensor] = None,
 
927
  output_attentions: Optional[bool] = None,
928
  output_hidden_states: Optional[bool] = None,
929
  return_dict: Optional[bool] = None,
930
+ is_padded_inputs: Optional[bool] = True,
931
+ ):
932
  """If subset_mask is not None, we only want output for the subset of the sequence.
933
  This means that we only compute the last layer output for these tokens.
934
  subset_mask: (batch, seqlen), dtype=torch.bool
 
936
  hidden_states2 = None
937
  residual = None
938
 
 
939
  for _, layer in enumerate(self.layers):
940
  if self.gradient_checkpointing and self.training:
941
 
 
995
  def __init__(self, config):
996
  super().__init__()
997
  self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
998
+ approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
999
  if config.activation_function == "swiglu":
1000
  self.transform_act_fn = F.silu
1001
  else:
 
1040
  super().__init__(config)
1041
  self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1042
  if config.vocab_size % self.pad_vocab_size_multiple != 0:
1043
+ config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
1044
+
1045
+ assert config.activation_function in [
1046
+ "gelu",
1047
+ "gelu_new",
1048
+ "gelu_fast",
1049
+ "gelu_pytorch_tanh",
1050
+ "swiglu",
1051
+ "geglu",
1052
+ "glu",
1053
+ ]
1054
+
1055
+ self.embeddings = NomicBertEmbeddings(config)
1056
  self.emb_drop = nn.Dropout(config.resid_pdrop)
1057
  self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1058
  self.encoder = NomicBertEncoder(config)
 
1066
  position_ids=None,
1067
  token_type_ids=None,
1068
  attention_mask=None,
1069
+ return_dict=None,
1070
+ matryoshka_dim=None,
1071
  ):
1072
  if token_type_ids is None:
1073
  token_type_ids = torch.zeros_like(input_ids)
1074
+ hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
1075
  hidden_states = self.emb_ln(hidden_states)
1076
  hidden_states = self.emb_drop(hidden_states)
1077
 
1078
  attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1079
+ sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
1080
 
1081
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1082
 
1083
+ if matryoshka_dim:
1084
+ sequence_output = sequence_output[:, :matryoshka_dim]
1085
+
1086
  return BaseModelOutputWithPoolingAndCrossAttentions(
1087
  last_hidden_state=sequence_output,
1088
  pooler_output=pooled_output,
 
1149
  loss=total_loss,
1150
  logits=prediction_scores,
1151
  hidden_states=outputs.hidden_states,
1152
+ attentions=None,
1153
  )
1154
 
1155
+
1156
  class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1157
  def __init__(self, config):
1158
  super().__init__(config)
 
1160
  self.config = config
1161
 
1162
  self.bert = NomicBertModel(config)
1163
+ classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
1164
  self.dropout = nn.Dropout(classifier_dropout)
1165
  self.classifier = nn.Linear(config.n_embd, config.num_labels)
1166