czczup commited on
Commit
c3bb826
1 Parent(s): 3a71f41

Update modeling_intern_vit.py

Browse files
Files changed (1) hide show
  1. modeling_intern_vit.py +9 -2
modeling_intern_vit.py CHANGED
@@ -129,6 +129,12 @@ except Exception:
129
  pass
130
 
131
 
 
 
 
 
 
 
132
  class InternVisionEmbeddings(nn.Module):
133
  def __init__(self, config: InternVisionConfig):
134
  super().__init__()
@@ -267,11 +273,12 @@ class InternVisionEncoderLayer(nn.Module):
267
  super().__init__()
268
  self.embed_dim = config.hidden_size
269
  self.intermediate_size = config.intermediate_size
 
270
 
271
  self.attn = InternAttention(config)
272
  self.mlp = InternMLP(config)
273
- self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
274
- self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
275
 
276
  self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
277
  self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
 
129
  pass
130
 
131
 
132
+ NORM2FN = {
133
+ 'rms_norm': InternRMSNorm,
134
+ 'layer_norm': nn.LayerNorm,
135
+ }
136
+
137
+
138
  class InternVisionEmbeddings(nn.Module):
139
  def __init__(self, config: InternVisionConfig):
140
  super().__init__()
 
273
  super().__init__()
274
  self.embed_dim = config.hidden_size
275
  self.intermediate_size = config.intermediate_size
276
+ self.norm_type = config.norm_type
277
 
278
  self.attn = InternAttention(config)
279
  self.mlp = InternMLP(config)
280
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
281
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
282
 
283
  self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
284
  self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))