Spaces:
Sleeping
Sleeping
Implement get_attention for tape BERT
Browse files- poetry.lock +118 -1
- protention/attention.py +30 -11
- pyproject.toml +1 -0
- tests/test_attention.py +10 -1
poetry.lock
CHANGED
@@ -171,6 +171,38 @@ category = "main"
|
|
171 |
optional = false
|
172 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
[[package]]
|
175 |
name = "cachetools"
|
176 |
version = "5.3.0"
|
@@ -572,6 +604,14 @@ MarkupSafe = ">=2.0"
|
|
572 |
[package.extras]
|
573 |
i18n = ["Babel (>=2.7)"]
|
574 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
[[package]]
|
576 |
name = "jsonpointer"
|
577 |
version = "2.3"
|
@@ -749,6 +789,14 @@ category = "main"
|
|
749 |
optional = false
|
750 |
python-versions = "*"
|
751 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
752 |
[[package]]
|
753 |
name = "markdown-it-py"
|
754 |
version = "2.2.0"
|
@@ -1474,6 +1522,36 @@ pygments = ">=2.13.0,<3.0.0"
|
|
1474 |
[package.extras]
|
1475 |
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
1476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1477 |
[[package]]
|
1478 |
name = "semver"
|
1479 |
version = "2.13.0"
|
@@ -1613,6 +1691,37 @@ python-versions = ">=3.8"
|
|
1613 |
[package.dependencies]
|
1614 |
mpmath = ">=0.19"
|
1615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1616 |
[[package]]
|
1617 |
name = "terminado"
|
1618 |
version = "0.17.1"
|
@@ -1983,7 +2092,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
|
|
1983 |
[metadata]
|
1984 |
lock-version = "1.1"
|
1985 |
python-versions = "^3.10"
|
1986 |
-
content-hash = "
|
1987 |
|
1988 |
[metadata.files]
|
1989 |
altair = []
|
@@ -2027,6 +2136,8 @@ beautifulsoup4 = []
|
|
2027 |
biopython = []
|
2028 |
bleach = []
|
2029 |
blinker = []
|
|
|
|
|
2030 |
cachetools = []
|
2031 |
certifi = []
|
2032 |
cffi = []
|
@@ -2071,6 +2182,7 @@ ipywidgets = []
|
|
2071 |
isoduration = []
|
2072 |
jedi = []
|
2073 |
jinja2 = []
|
|
|
2074 |
jsonpointer = []
|
2075 |
jsonschema = []
|
2076 |
jupyter-client = []
|
@@ -2082,6 +2194,7 @@ jupyter-server-terminals = []
|
|
2082 |
jupyterlab-pygments = []
|
2083 |
jupyterlab-widgets = []
|
2084 |
lit = []
|
|
|
2085 |
markdown-it-py = []
|
2086 |
markupsafe = []
|
2087 |
matplotlib-inline = []
|
@@ -2205,6 +2318,8 @@ requests = []
|
|
2205 |
rfc3339-validator = []
|
2206 |
rfc3986-validator = []
|
2207 |
rich = []
|
|
|
|
|
2208 |
semver = []
|
2209 |
send2trash = [
|
2210 |
{file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
|
@@ -2225,6 +2340,8 @@ stack-data = []
|
|
2225 |
stmol = []
|
2226 |
streamlit = []
|
2227 |
sympy = []
|
|
|
|
|
2228 |
terminado = []
|
2229 |
tinycss2 = []
|
2230 |
tokenizers = []
|
|
|
171 |
optional = false
|
172 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
173 |
|
174 |
+
[[package]]
|
175 |
+
name = "boto3"
|
176 |
+
version = "1.26.95"
|
177 |
+
description = "The AWS SDK for Python"
|
178 |
+
category = "main"
|
179 |
+
optional = false
|
180 |
+
python-versions = ">= 3.7"
|
181 |
+
|
182 |
+
[package.dependencies]
|
183 |
+
botocore = ">=1.29.95,<1.30.0"
|
184 |
+
jmespath = ">=0.7.1,<2.0.0"
|
185 |
+
s3transfer = ">=0.6.0,<0.7.0"
|
186 |
+
|
187 |
+
[package.extras]
|
188 |
+
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
189 |
+
|
190 |
+
[[package]]
|
191 |
+
name = "botocore"
|
192 |
+
version = "1.29.95"
|
193 |
+
description = "Low-level, data-driven core of boto 3."
|
194 |
+
category = "main"
|
195 |
+
optional = false
|
196 |
+
python-versions = ">= 3.7"
|
197 |
+
|
198 |
+
[package.dependencies]
|
199 |
+
jmespath = ">=0.7.1,<2.0.0"
|
200 |
+
python-dateutil = ">=2.1,<3.0.0"
|
201 |
+
urllib3 = ">=1.25.4,<1.27"
|
202 |
+
|
203 |
+
[package.extras]
|
204 |
+
crt = ["awscrt (==0.16.9)"]
|
205 |
+
|
206 |
[[package]]
|
207 |
name = "cachetools"
|
208 |
version = "5.3.0"
|
|
|
604 |
[package.extras]
|
605 |
i18n = ["Babel (>=2.7)"]
|
606 |
|
607 |
+
[[package]]
|
608 |
+
name = "jmespath"
|
609 |
+
version = "1.0.1"
|
610 |
+
description = "JSON Matching Expressions"
|
611 |
+
category = "main"
|
612 |
+
optional = false
|
613 |
+
python-versions = ">=3.7"
|
614 |
+
|
615 |
[[package]]
|
616 |
name = "jsonpointer"
|
617 |
version = "2.3"
|
|
|
789 |
optional = false
|
790 |
python-versions = "*"
|
791 |
|
792 |
+
[[package]]
|
793 |
+
name = "lmdb"
|
794 |
+
version = "1.4.0"
|
795 |
+
description = "Universal Python binding for the LMDB 'Lightning' Database"
|
796 |
+
category = "main"
|
797 |
+
optional = false
|
798 |
+
python-versions = "*"
|
799 |
+
|
800 |
[[package]]
|
801 |
name = "markdown-it-py"
|
802 |
version = "2.2.0"
|
|
|
1522 |
[package.extras]
|
1523 |
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
1524 |
|
1525 |
+
[[package]]
|
1526 |
+
name = "s3transfer"
|
1527 |
+
version = "0.6.0"
|
1528 |
+
description = "An Amazon S3 Transfer Manager"
|
1529 |
+
category = "main"
|
1530 |
+
optional = false
|
1531 |
+
python-versions = ">= 3.7"
|
1532 |
+
|
1533 |
+
[package.dependencies]
|
1534 |
+
botocore = ">=1.12.36,<2.0a.0"
|
1535 |
+
|
1536 |
+
[package.extras]
|
1537 |
+
crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
|
1538 |
+
|
1539 |
+
[[package]]
|
1540 |
+
name = "scipy"
|
1541 |
+
version = "1.9.3"
|
1542 |
+
description = "Fundamental algorithms for scientific computing in Python"
|
1543 |
+
category = "main"
|
1544 |
+
optional = false
|
1545 |
+
python-versions = ">=3.8"
|
1546 |
+
|
1547 |
+
[package.dependencies]
|
1548 |
+
numpy = ">=1.18.5,<1.26.0"
|
1549 |
+
|
1550 |
+
[package.extras]
|
1551 |
+
test = ["pytest", "pytest-cov", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack"]
|
1552 |
+
doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-panels (>=0.5.2)", "matplotlib (>2)", "numpydoc", "sphinx-tabs"]
|
1553 |
+
dev = ["mypy", "typing-extensions", "pycodestyle", "flake8"]
|
1554 |
+
|
1555 |
[[package]]
|
1556 |
name = "semver"
|
1557 |
version = "2.13.0"
|
|
|
1691 |
[package.dependencies]
|
1692 |
mpmath = ">=0.19"
|
1693 |
|
1694 |
+
[[package]]
|
1695 |
+
name = "tape-proteins"
|
1696 |
+
version = "0.5"
|
1697 |
+
description = "Repostory of Protein Benchmarking and Modeling"
|
1698 |
+
category = "main"
|
1699 |
+
optional = false
|
1700 |
+
python-versions = "*"
|
1701 |
+
|
1702 |
+
[package.dependencies]
|
1703 |
+
biopython = "*"
|
1704 |
+
boto3 = "*"
|
1705 |
+
filelock = "*"
|
1706 |
+
lmdb = "*"
|
1707 |
+
requests = "*"
|
1708 |
+
scipy = "*"
|
1709 |
+
tensorboardX = "*"
|
1710 |
+
tqdm = "*"
|
1711 |
+
|
1712 |
+
[[package]]
|
1713 |
+
name = "tensorboardx"
|
1714 |
+
version = "2.6"
|
1715 |
+
description = "TensorBoardX lets you watch Tensors Flow without Tensorflow"
|
1716 |
+
category = "main"
|
1717 |
+
optional = false
|
1718 |
+
python-versions = "*"
|
1719 |
+
|
1720 |
+
[package.dependencies]
|
1721 |
+
numpy = "*"
|
1722 |
+
packaging = "*"
|
1723 |
+
protobuf = ">=3.8.0,<4"
|
1724 |
+
|
1725 |
[[package]]
|
1726 |
name = "terminado"
|
1727 |
version = "0.17.1"
|
|
|
2092 |
[metadata]
|
2093 |
lock-version = "1.1"
|
2094 |
python-versions = "^3.10"
|
2095 |
+
content-hash = "ad6054ae4a119d961e9941f135489d1b89310303aefc27d3132fbd1ed1c35a0f"
|
2096 |
|
2097 |
[metadata.files]
|
2098 |
altair = []
|
|
|
2136 |
biopython = []
|
2137 |
bleach = []
|
2138 |
blinker = []
|
2139 |
+
boto3 = []
|
2140 |
+
botocore = []
|
2141 |
cachetools = []
|
2142 |
certifi = []
|
2143 |
cffi = []
|
|
|
2182 |
isoduration = []
|
2183 |
jedi = []
|
2184 |
jinja2 = []
|
2185 |
+
jmespath = []
|
2186 |
jsonpointer = []
|
2187 |
jsonschema = []
|
2188 |
jupyter-client = []
|
|
|
2194 |
jupyterlab-pygments = []
|
2195 |
jupyterlab-widgets = []
|
2196 |
lit = []
|
2197 |
+
lmdb = []
|
2198 |
markdown-it-py = []
|
2199 |
markupsafe = []
|
2200 |
matplotlib-inline = []
|
|
|
2318 |
rfc3339-validator = []
|
2319 |
rfc3986-validator = []
|
2320 |
rich = []
|
2321 |
+
s3transfer = []
|
2322 |
+
scipy = []
|
2323 |
semver = []
|
2324 |
send2trash = [
|
2325 |
{file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
|
|
|
2340 |
stmol = []
|
2341 |
streamlit = []
|
2342 |
sympy = []
|
2343 |
+
tape-proteins = []
|
2344 |
+
tensorboardx = []
|
2345 |
terminado = []
|
2346 |
tinycss2 = []
|
2347 |
tokenizers = []
|
protention/attention.py
CHANGED
@@ -1,11 +1,16 @@
|
|
|
|
1 |
from io import StringIO
|
2 |
from urllib import request
|
3 |
|
4 |
import torch
|
5 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
|
|
6 |
from transformers import T5EncoderModel, T5Tokenizer
|
7 |
|
8 |
|
|
|
|
|
|
|
9 |
def get_structure(pdb_code: str) -> Structure:
|
10 |
"""
|
11 |
Get structure from PDB
|
@@ -46,9 +51,14 @@ def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
|
|
46 |
|
47 |
return tokenizer, model
|
48 |
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
def get_attention(
|
51 |
-
pdb_code: str,
|
52 |
):
|
53 |
"""
|
54 |
Get attention from T5
|
@@ -57,13 +67,22 @@ def get_attention(
|
|
57 |
structure = get_structure(pdb_code)
|
58 |
# Get list of sequences
|
59 |
sequences = get_sequences(structure)
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
from io import StringIO
|
3 |
from urllib import request
|
4 |
|
5 |
import torch
|
6 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
7 |
+
from tape import ProteinBertModel, TAPETokenizer
|
8 |
from transformers import T5EncoderModel, T5Tokenizer
|
9 |
|
10 |
|
11 |
+
class Model(str, Enum):
|
12 |
+
tape_bert = "bert-base"
|
13 |
+
|
14 |
def get_structure(pdb_code: str) -> Structure:
|
15 |
"""
|
16 |
Get structure from PDB
|
|
|
51 |
|
52 |
return tokenizer, model
|
53 |
|
54 |
+
def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
|
55 |
+
tokenizer = TAPETokenizer()
|
56 |
+
model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
|
57 |
+
return tokenizer, model
|
58 |
+
|
59 |
|
60 |
def get_attention(
|
61 |
+
pdb_code: str, model: Model = Model.tape_bert
|
62 |
):
|
63 |
"""
|
64 |
Get attention from T5
|
|
|
67 |
structure = get_structure(pdb_code)
|
68 |
# Get list of sequences
|
69 |
sequences = get_sequences(structure)
|
70 |
+
# TODO handle multiple sequences
|
71 |
+
sequence = sequences[0]
|
72 |
+
|
73 |
+
match model:
|
74 |
+
case model.tape_bert:
|
75 |
+
tokenizer, model = get_tape_bert()
|
76 |
+
token_idxs = tokenizer.encode(sequence).tolist()
|
77 |
+
inputs = torch.tensor(token_idxs).unsqueeze(0)
|
78 |
+
with torch.no_grad():
|
79 |
+
attns = model(inputs)[-1]
|
80 |
+
# Remove attention from <CLS> (first) and <SEP> (last) token
|
81 |
+
attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
|
82 |
+
attns = torch.stack([attn.squeeze(0) for attn in attns])
|
83 |
+
case model.prot_T5:
|
84 |
+
# Space separate sequences
|
85 |
+
sequences = [" ".join(sequence) for sequence in sequences]
|
86 |
+
tokenizer, model = get_protT5()
|
87 |
+
|
88 |
+
return attns
|
pyproject.toml
CHANGED
@@ -12,6 +12,7 @@ biopython = "^1.81"
|
|
12 |
transformers = "^4.27.1"
|
13 |
torch = "^2.0.0"
|
14 |
sentencepiece = "^0.1.97"
|
|
|
15 |
|
16 |
[tool.poetry.dev-dependencies]
|
17 |
pytest = "^7.2.2"
|
|
|
12 |
transformers = "^4.27.1"
|
13 |
torch = "^2.0.0"
|
14 |
sentencepiece = "^0.1.97"
|
15 |
+
tape-proteins = "^0.5"
|
16 |
|
17 |
[tool.poetry.dev-dependencies]
|
18 |
pytest = "^7.2.2"
|
tests/test_attention.py
CHANGED
@@ -1,7 +1,9 @@
|
|
|
|
1 |
from Bio.PDB.Structure import Structure
|
2 |
from transformers import T5EncoderModel, T5Tokenizer
|
3 |
|
4 |
-
from protention.attention import
|
|
|
5 |
|
6 |
|
7 |
def test_get_structure():
|
@@ -33,3 +35,10 @@ def test_get_protT5():
|
|
33 |
|
34 |
assert isinstance(tokenizer, T5Tokenizer)
|
35 |
assert isinstance(model, T5EncoderModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
from Bio.PDB.Structure import Structure
|
3 |
from transformers import T5EncoderModel, T5Tokenizer
|
4 |
|
5 |
+
from protention.attention import (Model, get_attention, get_protT5,
|
6 |
+
get_sequences, get_structure)
|
7 |
|
8 |
|
9 |
def test_get_structure():
|
|
|
35 |
|
36 |
assert isinstance(tokenizer, T5Tokenizer)
|
37 |
assert isinstance(model, T5EncoderModel)
|
38 |
+
|
39 |
+
def test_get_attention_tape():
|
40 |
+
|
41 |
+
result = get_attention("1AKE", model=Model.tape_bert)
|
42 |
+
|
43 |
+
assert result is not None
|
44 |
+
assert result.shape == torch.Size([12,12,456,456])
|