aksell committed
Commit 3de389d
1 parent: 97dc04c

Plot avg attention, not sum

Files changed (2):
  1. hexviz/attention.py +9 -10
  2. tests/test_attention.py +4 -4
hexviz/attention.py CHANGED
@@ -85,21 +85,20 @@ def get_attention(
 
     return attentions
 
-def unidirectional_sum_filtered(attention, layer, head, threshold):
+def unidirectional_avg_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
     attention_head = attention[layer, head]
-    unidirectional_sum_for_head = []
+    unidirectional_avg_for_head = []
     for i in range(seq_len):
         for j in range(i, seq_len):
             # Attention matrices for BERT models are asymmetric.
-            # Bidirectional attention is reduced to one value by adding the
-            # attention values
-            # TODO think... does this operation make sense?
+            # Bidirectional attention is represented by the average of the two values
             sum = attention_head[i, j].item() + attention_head[j, i].item()
-            if sum >= threshold:
-                unidirectional_sum_for_head.append((sum, i, j))
-    return unidirectional_sum_for_head
-
+            avg = sum / 2
+            if avg >= threshold:
+                unidirectional_avg_for_head.append((avg, i, j))
+    return unidirectional_avg_for_head
+
 @st.cache
 def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
     # fetch structure
@@ -110,7 +109,7 @@ def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0
     attention_pairs = []
     for i, sequence in enumerate(sequences):
         attention = get_attention(sequence=sequence, model_type=model_type)
-        attention_unidirectional = unidirectional_sum_filtered(attention, layer, head, threshold)
+        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
         chain = list(structure.get_chains())[i]
         for attn_value, res_1, res_2 in attention_unidirectional:
             try:
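The change keeps the same pairwise filter but compares the threshold against the mean of the two attention directions rather than their sum, so the threshold stays on the scale of individual attention weights. A minimal vectorized sketch of the same computation (the function name and the use of torch.triu_indices are illustrative, not part of the commit):

import torch

def unidirectional_avg_filtered_vectorized(attention, layer, head, threshold):
    # Hypothetical tensor-op equivalent of unidirectional_avg_filtered.
    attention_head = attention[layer, head]
    # Symmetrize: average the (i -> j) and (j -> i) attention values.
    avg = (attention_head + attention_head.T) / 2
    # Visit each unordered pair once, diagonal included (j >= i),
    # matching the loop order `for i ... for j in range(i, seq_len)`.
    i, j = torch.triu_indices(*attention_head.shape)
    values = avg[i, j]
    mask = values >= threshold
    return list(zip(values[mask].tolist(), i[mask].tolist(), j[mask].tolist()))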
tests/test_attention.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from Bio.PDB.Structure import Structure
 
 from hexviz.attention import (ModelType, get_attention, get_sequences,
-                              get_structure, unidirectional_sum_filtered)
+                              get_structure, unidirectional_avg_filtered)
 
 
 def test_get_structure():
@@ -58,14 +58,14 @@ def test_get_attention_prot_bert():
     assert result is not None
     assert result.shape == torch.Size([30, 16, 3, 3])
 
-def test_get_unidirection_sum_filtered():
+def test_get_unidirection_avg_filtered():
     # 1 head, 1 layer, 4-residue attention tensor
     attention = torch.tensor([[[[1, 2, 3, 4],
                                 [2, 5, 6, 7],
                                 [3, 6, 8, 9],
                                 [4, 7, 9, 11]]]], dtype=torch.float32)
 
-    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+    result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert result is not None
     assert len(result) == 10
@@ -74,6 +74,6 @@ def test_get_unidirection_sum_filtered():
                                 [2, 5, 6],
                                 [4, 7, 91]]]], dtype=torch.float32)
 
-    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+    result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert len(result) == 6
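The expected lengths follow from the iteration bounds: with a threshold of 0 every pair with j >= i passes, so a sequence of length n yields n * (n + 1) / 2 tuples, which is 10 for the 4-residue tensor and 6 for the 3-residue one. A quick check against the first test case, reusing its tensor (the expected first tuple follows from averaging attention_head[0, 0] with itself):

import torch
from hexviz.attention import unidirectional_avg_filtered

attention = torch.tensor([[[[1, 2, 3, 4],
                            [2, 5, 6, 7],
                            [3, 6, 8, 9],
                            [4, 7, 9, 11]]]], dtype=torch.float32)

result = unidirectional_avg_filtered(attention, 0, 0, 0)
assert len(result) == 4 * 5 // 2  # 10 unordered pairs, diagonal included
# The tensor is symmetric, so each average equals the raw value:
# the first pair (i=0, j=0) averages 1.0 with itself.
assert result[0] == (1.0, 0, 0)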