robinzixuan committed
Commit 51eda75
1 parent: 8b594fd

Upload modeling_bert.py

Files changed (1):
  1. modeling_bert.py +14 -0
modeling_bert.py CHANGED
@@ -21,6 +21,7 @@ import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 from functools import partial
+from enum import Flag, auto
 import torch
 import torch.utils.checkpoint
 from packaging import version
@@ -56,6 +57,18 @@ from transformers.utils import (
 )
 from .configuration_bert import BertConfig
 
+class BaseEnumOptions(Flag):
+    def __str__(self):
+        return self.name
+
+    @classmethod
+    def list_names(cls):
+        return [m.name for m in cls]
+class AttentionGateType(BaseEnumOptions):
+    none = 0
+    unconditional_per_head = 1
+    conditional_per_head = 2
+    conditional_per_token = 3
 
 def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
     """
@@ -91,6 +104,7 @@ def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
     return torch.clip(stretched_out, 0, 1)
 
 
+
 def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
     sm_out = softmax_1(data, dim=dim, **kw)
     stretched_out = sm_out * (eta - gamma) + gamma
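Lifted out of the diff into a self-contained form, the new enum helpers run as-is; the usage lines at the end are illustrative additions, not part of the commit:

from enum import Flag


class BaseEnumOptions(Flag):
    # Convenience base: print members by bare name, list all member names.
    def __str__(self):
        return self.name

    @classmethod
    def list_names(cls):
        return [m.name for m in cls]


class AttentionGateType(BaseEnumOptions):
    none = 0
    unconditional_per_head = 1
    conditional_per_head = 2
    conditional_per_token = 3


# Illustrative usage (not in the commit):
gate = AttentionGateType.conditional_per_head
print(gate)                             # conditional_per_head
print(AttentionGateType.list_names())   # member names; see caveat below

One caveat: the members carry sequential values (0 through 3) on an enum.Flag base, and Flag iteration changed in Python 3.11 to yield only canonical single-bit members, so list_names() can return different results across Python versions; a plain enum.Enum base would enumerate all four members consistently.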
 
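For context on the surrounding helpers, which the hunks show only in fragments, here is a minimal sketch of how softmax_n_shifted_zeros, softmax_1, and the clipped variants plausibly fit together. The formula for softmax_n_shifted_zeros (softmax computed as if n extra zero-valued logits were present) and the unshown function bodies are assumptions inferred from the visible lines, not the file's actual code:

import torch
import torch.nn.functional as F


def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
    # Assumed semantics: softmax over `input` plus n implicit zero-valued
    # logits, i.e. exp(x) / (sum(exp(x)) + n), computed stably by shifting
    # everything (including the implicit zeros) by the running max.
    shift = input.max(dim=dim, keepdim=True).values.clamp(min=0.0)
    numerator = torch.exp(input - shift)
    denominator = numerator.sum(dim=dim, keepdim=True) + n * torch.exp(-shift)
    return numerator / denominator


def softmax_1(input: torch.Tensor, dim=-1, **kw) -> torch.Tensor:
    # "Off-by-one" softmax: one implicit zero logit, so rows can sum to < 1.
    # Extra kwargs are accepted for signature compatibility and ignored here.
    return softmax_n_shifted_zeros(input, 1, dim=dim)


def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
    # Stretch softmax output from [0, 1] onto [gamma, eta], then clip back,
    # so exact 0 and exact 1 become reachable attention weights.
    sm_out = F.softmax(data, dim=dim, **kw)
    stretched_out = sm_out * (eta - gamma) + gamma
    return torch.clip(stretched_out, 0, 1)


def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
    # Same stretch-and-clip, applied on top of softmax_1.
    sm_out = softmax_1(data, dim=dim, **kw)
    stretched_out = sm_out * (eta - gamma) + gamma
    return torch.clip(stretched_out, 0, 1)

With eta slightly above 1 and gamma slightly below 0, attention weights can saturate at exactly 0 or exactly 1 rather than only approaching them asymptotically, a stretch-and-clip trick used to keep attention quantization-friendly; together with the AttentionGateType options above, this suggests the commit is wiring quantization-oriented attention variants into BERT.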