zhihan1996 guyaglionby commited on
Commit
2efd650
1 Parent(s): 1cdf84d

Fix error loading model with AutoModel (#1)

Browse files

- Fix error loading model with AutoModel (3263bb1e2dd26b172d2e21beba67da74818e61d9)


Co-authored-by: Guy Aglionby <guyaglionby@users.noreply.huggingface.co>

Files changed (1) hide show
  1. bert_layers.py +2 -0
bert_layers.py CHANGED
@@ -23,6 +23,7 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
23
  from .bert_padding import (index_first_axis,
24
  index_put_first_axis, pad_input,
25
  unpad_input, unpad_input_only)
 
26
 
27
  try:
28
  from .flash_attn_triton import flash_attn_qkvpacked_func
@@ -563,6 +564,7 @@ class BertModel(BertPreTrainedModel):
563
  all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
564
  ```
565
  """
 
566
 
567
  def __init__(self, config, add_pooling_layer=True):
568
  super(BertModel, self).__init__(config)
 
23
  from .bert_padding import (index_first_axis,
24
  index_put_first_axis, pad_input,
25
  unpad_input, unpad_input_only)
26
+ from .configuration_bert import BertConfig
27
 
28
  try:
29
  from .flash_attn_triton import flash_attn_qkvpacked_func
 
564
  all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
565
  ```
566
  """
567
+ config_class = BertConfig
568
 
569
  def __init__(self, config, add_pooling_layer=True):
570
  super(BertModel, self).__init__(config)