zhihan1996 committed
Commit faf3caf
1 Parent(s): 55b04c5

Update bert_layers.py

Files changed (1)
bert_layers.py +3 -8
bert_layers.py CHANGED
@@ -18,7 +18,6 @@ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (MaskedLMOutput,
                                            SequenceClassifierOutput)
-from transformers.models.bert.modeling_bert import BertPreTrainedModel
 from transformers.modeling_utils import PreTrainedModel
 
 from .bert_padding import (index_first_axis,
@@ -522,7 +521,7 @@ class BertPredictionHeadTransform(nn.Module):
         return hidden_states
 
 
-class BertModel(BertPreTrainedModel):
+class BertModel(PreTrainedModel):
     """Overall BERT model.
 
     Args:
@@ -682,7 +681,7 @@ class BertOnlyNSPHead(nn.Module):
 
 
 
-class BertForMaskedLM(BertPreTrainedModel):
+class BertForMaskedLM(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
@@ -810,12 +809,8 @@ class BertForMaskedLM(BertPreTrainedModel):
         return {'input_ids': input_ids, 'attention_mask': attention_mask}
 
 
-class BertForNextSentencePrediction(BertPreTrainedModel):
-    #TBD: Push in future commit
-    pass
 
-
-class BertForSequenceClassification(BertPreTrainedModel):
+class BertForSequenceClassification(PreTrainedModel):
     """Bert Model transformer with a sequence classification/regression head.
 
     This head is just a linear layer on top of the pooled output. Used for,
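
The substance of the commit: the three model classes now inherit from transformers' generic PreTrainedModel (already imported) instead of the BERT-specific BertPreTrainedModel, whose import is removed along with the BertForNextSentencePrediction stub. Below is a minimal sketch of the resulting pattern; MyBertConfig, the base_model_prefix value, and the single linear "encoder" are illustrative assumptions, not code from bert_layers.py.

```python
# Minimal sketch of the pattern this commit moves to: subclassing the generic
# PreTrainedModel directly instead of the BERT-specific BertPreTrainedModel.
# MyBertConfig and the linear "encoder" are illustrative stand-ins, not code
# from bert_layers.py.
import torch.nn as nn
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


class MyBertConfig(PretrainedConfig):
    model_type = "bert"

    def __init__(self, hidden_size=768, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class BertModel(PreTrainedModel):
    # Without BertPreTrainedModel, the subclass must point from_pretrained()
    # at its config class itself.
    config_class = MyBertConfig
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()  # runs transformers' weight-init / backward-compat hooks

    def _init_weights(self, module):
        # BertPreTrainedModel previously supplied BERT's
        # normal(0, initializer_range) init; a direct PreTrainedModel
        # subclass defines its own.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, hidden_states):
        return self.encoder(hidden_states)


model = BertModel(MyBertConfig())  # from_pretrained(...) works the same way
```

The trade-off of the swap is that a direct PreTrainedModel subclass must supply its own config_class and weight initialization, since the BERT-specific base class no longer provides them.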