Commit 1c423ed by aksell
1 Parent(s): 02840f4

Add PROT BERT

hexviz/attention.py CHANGED
@@ -6,11 +6,11 @@ from urllib import request
 import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
-from models import (ModelType, get_protgpt2, get_protT5, get_tape_bert,
-                    get_zymctrl)
+
+from hexviz.models import (ModelType, get_prot_bert, get_protgpt2, get_protT5,
+                           get_tape_bert, get_zymctrl)
 
 
-@st.cache
 def get_structure(pdb_code: str) -> Structure:
     """
     Get structure from PDB
@@ -83,6 +83,17 @@ def get_attention(
         # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
         attention_stacked = torch.stack([attention for attention in attention_squeezed])
         attentions = attention_stacked
+    # TODO extend attentions to be per token, not per word piece
+    # simplest way to draw attention for multi residue token models for now
+    elif model_type == ModelType.PROT_BERT:
+        tokenizer, model = get_prot_bert()
+        token_idxs = tokenizer.encode(sequence)
+        inputs = torch.tensor(token_idxs).unsqueeze(0)
+        with torch.no_grad():
+            attentions = model(inputs)[-1]
+        # Remove attention from <CLS> (first) and <SEP> (last) token
+        attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
+        attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.PROT_T5:
         # Introduce white-space between all amino acids
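
The new PROT_BERT branch of get_attention can be exercised on its own. A minimal sketch, not part of this commit, assuming torch and transformers are installed, the Rostlab/prot_bert weights can be downloaded, and a transformers version whose model outputs expose .attentions:

import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")

# Mirror the new branch: encode the sequence, run a forward pass, collect attentions
token_idxs = tokenizer.encode("GGG")
inputs = torch.tensor(token_idxs).unsqueeze(0)
with torch.no_grad():
    attentions = model(inputs, output_attentions=True).attentions

# Drop attention to/from the <CLS> (first) and <SEP> (last) token, then stack
# the per-layer tensors into [n_layers, n_heads, n_tok, n_tok]
attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
stacked = torch.stack([attention.squeeze(0) for attention in attentions])
print(stacked.shape)  # the new test expects torch.Size([30, 16, 3, 3]) for "GGG"
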
hexviz/models.py CHANGED
@@ -4,8 +4,8 @@ from typing import Tuple
 import streamlit as st
 import torch
 from tape import ProteinBertModel, TAPETokenizer
-from transformers import (AutoTokenizer, GPT2LMHeadModel, T5EncoderModel,
-                          T5Tokenizer)
+from transformers import (AutoTokenizer, BertForMaskedLM, BertTokenizer,
+                          GPT2LMHeadModel, T5EncoderModel, T5Tokenizer)
 
 
 class ModelType(str, Enum):
@@ -13,6 +13,7 @@ class ModelType(str, Enum):
     PROT_T5 = "prot_t5_xl_half_uniref50-enc"
     ZymCTRL = "ZymCTRL"
     ProtGPT2 = "ProtGPT2"
+    PROT_BERT = "ProtBert"
 
 
 class Model:
@@ -42,6 +43,12 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
     model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
     return tokenizer, model
 
+@st.cache
+def get_prot_bert() -> Tuple[BertTokenizer, BertForMaskedLM]:
+    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
+    model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
+    return tokenizer, model
+
 @st.cache
 def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
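
As a quick sanity check, not part of this commit, the new cached loader can be inspected directly. A minimal sketch, assuming the hexviz package is importable and the weights can be downloaded; ProtBert's config should line up with the 30-layer, 16-head shape asserted in the new test:

from hexviz.models import get_prot_bert

tokenizer, model = get_prot_bert()
# BertConfig fields for Rostlab/prot_bert; the new test expects 30 layers x 16 heads
print(model.config.num_hidden_layers, model.config.num_attention_heads)
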
tests/test_attention.py CHANGED
@@ -51,6 +51,13 @@ def test_get_attention_tape():
     assert result is not None
     assert result.shape == torch.Size([12,12,13,13])
 
+def test_get_attention_prot_bert():
+
+    result = get_attention("GGG", model_type=ModelType.PROT_BERT)
+
+    assert result is not None
+    assert result.shape == torch.Size([30, 16, 3, 3])
+
 def test_get_unidirection_sum_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor
     attention= torch.tensor([[[[1, 2, 3, 4],
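
The new test can be run on its own, e.g. with pytest installed:

import pytest

# Equivalent to running `pytest tests/test_attention.py::test_get_attention_prot_bert -q`
# from the repository root; downloads the ProtBert weights on first use.
pytest.main(["tests/test_attention.py::test_get_attention_prot_bert", "-q"])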