habdine committed on
Commit
6ccdea6
1 Parent(s): f0f09a0

Upload code

Files changed (2)
  1. configuration_prot2text.py +74 -0
  2. modeling_prot2text.py +200 -0
configuration_prot2text.py ADDED
@@ -0,0 +1,74 @@
+ """ Prot2Text configuration"""
+ 
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers import AutoConfig
+ from transformers.utils import logging
+ 
+ 
+ logger = logging.get_logger(__name__)
+ 
+ 
+ class Prot2TextConfig(PretrainedConfig):
+     model_type = "prot2text"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     _keys_to_ignore_on_load_missing = [r"transformer"]
+ 
+     def __init__(
+         self,
+         cross_esm_graph=True,
+         decoder_start_token_id=50257,
+         early_stopping=True,
+         eos_token_id=50258,
+         bos_token_id=50257,
+         esm=True,
+         esm_model_name="facebook/esm2_t6_8M_UR50D",
+         gpt_model_name="gpt2",
+         length_penalty=2.0,
+         max_new_tokens=256,
+         no_repeat_ngram_size=3,
+         pad_token_id=50256,
+         prot2text_version="1.1",
+         rgcn=True,
+         rgc_input_dim=67,
+         rgcn_n_layers=6,
+         gpt_config=None,
+         esm_config=None,
+         **kwargs,
+     ):
+         self.cross_esm_graph = cross_esm_graph
+         self.decoder_start_token_id = decoder_start_token_id
+         self.early_stopping = early_stopping
+         self.eos_token_id = eos_token_id
+         self.esm = esm
+         self.esm_model_name = esm_model_name
+         self.gpt_model_name = gpt_model_name
+         self.length_penalty = length_penalty
+         self.max_new_tokens = max_new_tokens
+         self.no_repeat_ngram_size = no_repeat_ngram_size
+         self.pad_token_id = pad_token_id
+         self.prot2text_version = prot2text_version
+         self.rgcn = rgcn
+         self.rgc_input_dim = rgc_input_dim
+         self.rgcn_n_layers = rgcn_n_layers
+         if gpt_config is None:
+             # default GPT-2 decoder configuration with cross-attention enabled
+             self.gpt_config = AutoConfig.from_pretrained(
+                 gpt_model_name,
+                 _name_or_path=gpt_model_name,
+                 is_encoder_decoder=True,
+                 use_cache=False,
+                 add_cross_attention=True,
+                 bos_token_id=bos_token_id,
+                 decoder_start_token_id=decoder_start_token_id,
+                 eos_token_id=eos_token_id,
+                 max_new_tokens=max_new_tokens,
+                 pad_token_id=50256,
+                 vocab_size=50259,
+                 num_beams=1,
+                 max_length=256,
+                 min_length=1,
+             ).to_dict()
+         else:
+             self.gpt_config = gpt_config
+         if esm_config is None:
+             self.esm_config = AutoConfig.from_pretrained(esm_model_name).to_dict()
+         else:
+             self.esm_config = esm_config
+ 
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
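For reference, a minimal sketch of how this configuration might be instantiated outside the repository. It assumes the file above is importable as `configuration_prot2text` and that the `gpt2` and `facebook/esm2_t6_8M_UR50D` base configs can be fetched (or are already cached); the output directory name is only a placeholder.

from configuration_prot2text import Prot2TextConfig

# build a config with the defaults defined above; this pulls the base
# GPT-2 and ESM-2 configs via AutoConfig.from_pretrained
config = Prot2TextConfig(prot2text_version="1.1")

# the nested gpt_config / esm_config dicts are serialized into config.json
config.save_pretrained("./prot2text_config")   # placeholder path
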
modeling_prot2text.py ADDED
@@ -0,0 +1,200 @@
+ from transformers import GPT2Config, AutoTokenizer
+ from transformers import PretrainedConfig, PreTrainedModel
+ import transformers
+ from typing import Optional, Tuple, Callable, List
+ import torch
+ import torch.nn as nn
+ from .utils import CABlock, _GPT2LMHeadModel
+ from .configuration_prot2text import Prot2TextConfig
+ from transformers.generation.configuration_utils import GenerationConfig
+ from transformers.generation.logits_process import LogitsProcessorList
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
+ 
+ 
+ class Prot2TextModel(PreTrainedModel):
+     config_class = Prot2TextConfig
+     _keys_to_ignore_on_load_missing = [r"transformer"]
+     base_model_prefix = "decoder"
+ 
+     def __init__(self, config):
+         super().__init__(config)
+ 
+         self.gpt_config = GPT2Config.from_dict(config.gpt_config)
+ 
+         # define the GPT-2 decoder
+         self.decoder = _GPT2LMHeadModel(self.gpt_config)
+ 
+         # if ESM encodes the protein sequence, define the ESM encoder,
+         # the projection layer and the fusion layers
+         if config.esm:
+             self.esm_config = PretrainedConfig.from_dict(config.esm_config)
+             self.esm = transformers.EsmModel(self.esm_config)
+             self.to_embedding = nn.Linear(self.esm_config.hidden_size, self.gpt_config.n_embd)
+             if config.cross_esm_graph and config.rgcn:
+                 self.h = nn.ModuleList([CABlock(self.gpt_config, layer_idx=i) for i in range(4)])
+                 self.ln_f = nn.LayerNorm(self.gpt_config.n_embd, eps=self.gpt_config.layer_norm_epsilon)
+ 
+         self.config = config
+ 
+     def get_encoder(self):
+         return self.encoder
+ 
+     def get_decoder(self):
+         return self.decoder
+ 
+     def get_input_embeddings(self):
+         if hasattr(self, "transformer"):
+             return self.transformer.wte
+         return self.decoder.transformer.wte
+ 
+     def warm_up(self, gpt_model=None, esm_model=None):
+         # initialize the encoder and/or decoder from pretrained checkpoints
+         if esm_model is not None:
+             self.esm = transformers.EsmModel.from_pretrained(esm_model)
+         if gpt_model is not None:
+             self.decoder = _GPT2LMHeadModel.from_pretrained(gpt_model, add_cross_attention=True, use_cache=False)
+             self.decoder.resize_token_embeddings(self.gpt_config.vocab_size)
+             self.decoder.config = self.gpt_config
+ 
+     def forward(self,
+                 encoder_input_ids: Optional[torch.LongTensor] = None,
+                 edge_index: Optional[torch.LongTensor] = None,
+                 batch: Optional[torch.LongTensor] = None,
+                 x: Optional[torch.FloatTensor] = None,
+                 edge_type: Optional[torch.LongTensor] = None,
+                 decoder_input_ids: Optional[torch.LongTensor] = None,
+                 past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+                 past_key_values_graph_esm: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+                 decoder_attention_mask: Optional[torch.FloatTensor] = None,
+                 attention_mask: Optional[torch.FloatTensor] = None,
+                 token_type_ids: Optional[torch.LongTensor] = None,
+                 position_ids: Optional[torch.LongTensor] = None,
+                 head_mask: Optional[torch.FloatTensor] = None,
+                 inputs_embeds: Optional[torch.FloatTensor] = None,
+                 encoder_hidden_states: Optional[torch.Tensor] = None,
+                 encoder_attention_mask: Optional[torch.FloatTensor] = None,
+                 labels: Optional[torch.LongTensor] = None,
+                 use_cache: Optional[bool] = None,
+                 output_attentions: Optional[bool] = None,
+                 output_hidden_states: Optional[bool] = None,
+                 return_dict: Optional[bool] = None,
+                 get_graph_emb: Optional[bool] = False,
+                 **delete_args,
+                 ):
+         use_cache = use_cache if use_cache is not None else self.gpt_config.use_cache
+         return_dict = return_dict if return_dict is not None else self.gpt_config.use_return_dict
+ 
+         if decoder_input_ids is not None and len(decoder_input_ids.size()) == 3:
+             decoder_input_ids = decoder_input_ids.squeeze(0)
+ 
+         if self.config.esm:
+             if self.config.prot2text_version == '1.0':
+                 if encoder_input_ids.size()[1] != 1021:
+                     raise ValueError("For this version of the model you need to pad/truncate the amino-acid sequence for the ESM model to 1021")
+ 
+             # encode the sequence with ESM and project it to the GPT-2 embedding size
+             esm_emb = self.esm(input_ids=encoder_input_ids, attention_mask=attention_mask, return_dict=return_dict).last_hidden_state
+             esm_emb = self.to_embedding(esm_emb)
+             graph_emb = esm_emb
+         else:
+             attention_mask = None
+         if self.config.prot2text_version == '1.0':
+             attention_mask = None
+         if get_graph_emb:
+             return graph_emb
+ 
+         transformer_outputs = self.decoder(input_ids=decoder_input_ids,
+                                            past_key_values=past_key_values,
+                                            attention_mask=decoder_attention_mask,
+                                            token_type_ids=token_type_ids,
+                                            position_ids=position_ids,
+                                            head_mask=head_mask,
+                                            inputs_embeds=inputs_embeds,
+                                            encoder_hidden_states=graph_emb,
+                                            encoder_attention_mask=attention_mask,
+                                            labels=labels,
+                                            use_cache=use_cache,
+                                            output_attentions=output_attentions,
+                                            output_hidden_states=output_hidden_states,
+                                            return_dict=return_dict,
+                                            )
+ 
+         return transformer_outputs
+ 
+     @torch.no_grad()
+     def generate_protein_description(self,
+                                      protein_pdbID=None,
+                                      protein_sequence=None,
+                                      edge_index: Optional[torch.LongTensor] = None,
+                                      x: Optional[torch.FloatTensor] = None,
+                                      edge_type: Optional[torch.LongTensor] = None,
+                                      tokenizer=None,
+                                      device='cpu'
+                                      ):
+ 
+         if self.config.esm and not self.config.rgcn and protein_sequence is None:
+             raise ValueError(
+                 "The model you are trying to use is based only on the protein sequence, please provide an amino-acid protein_sequence"
+             )
+         if self.config.rgcn and protein_pdbID is None and (x is None or edge_index is None or edge_type is None):
+             raise ValueError(
+                 "The model you are trying to use is based on the protein structure, please provide an AlphaFold ID (an internet connection is required when using protein_pdbID) or provide the triplet inputs: x (node features), edge_index and edge_type"
+             )
+         if self.config.esm:
+             esmtokenizer = AutoTokenizer.from_pretrained(self.config.esm_model_name)
+ 
+         if protein_pdbID is None and protein_sequence is None:
+             raise ValueError(
+                 "you need to provide either a protein AlphaFold ID or an amino-acid sequence"
+             )
+ 
+         # tokenize the amino-acid sequence for the ESM encoder and build the BOS-only decoder prompt
+         seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
+         inputs = {}
+         inputs['encoder_input_ids'] = seq['input_ids']
+         inputs['attention_mask'] = seq['attention_mask']
+         inputs['decoder_input_ids'] = inputs['encoder_input_ids'][:, 0:1].clone()
+         inputs['decoder_input_ids'][:, 0] = tokenizer.bos_token_id
+ 
+         self.to(device)
+         inputs = {k: v.to(device=device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
+         encoder_state = dict()
+         encoder_state['hidden_states'] = self(**inputs, get_graph_emb=True, output_attentions=True)
+         generated = tokenizer.batch_decode(self.decoder.generate(input_ids=inputs['decoder_input_ids'], encoder_outputs=encoder_state, use_cache=True), skip_special_tokens=True)
+ 
+         return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+ 
+     @torch.no_grad()
+     def generate(self,
+                  inputs: Optional[torch.Tensor] = None,
+                  generation_config: Optional[GenerationConfig] = None,
+                  logits_processor: Optional[LogitsProcessorList] = None,
+                  stopping_criteria: Optional[StoppingCriteriaList] = None,
+                  prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+                  synced_gpus: Optional[bool] = None,
+                  assistant_model: Optional["PreTrainedModel"] = None,
+                  streamer: Optional["BaseStreamer"] = None,
+                  **kwargs,
+                  ):
+         # compute the encoder (sequence/graph) embeddings once, then delegate generation to the GPT-2 decoder
+         encoder_state = self(**kwargs, get_graph_emb=True)
+         input_ids = kwargs['decoder_input_ids']
+         attention_mask = kwargs['decoder_attention_mask']
+         kwargs['encoder_attention_mask'] = kwargs['attention_mask']
+         if not self.config.cross_esm_graph and self.config.rgcn and self.config.esm:
+             t_add = torch.ones((kwargs['encoder_attention_mask'].size(0), 1)).to(kwargs['encoder_attention_mask'].get_device())
+             kwargs['encoder_attention_mask'] = torch.cat((t_add, kwargs['encoder_attention_mask']), dim=1)
+         # drop the graph/sequence inputs that the underlying GPT-2 generate() does not expect
+         for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids', 'decoder_input_ids', 'decoder_attention_mask', 'batch', 'attention_mask', 'max_length',
+                     '_num_nodes', 'node_id', 'name', 'sequence', 'distance_matrix', 'distance', 'coordinates', 'ptr', 'num_nodes']:
+             if key in kwargs.keys():
+                 kwargs.pop(key)
+         return self.decoder.generate(input_ids=input_ids,
+                                      generation_config=generation_config,
+                                      logits_processor=logits_processor,
+                                      stopping_criteria=stopping_criteria,
+                                      prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                      synced_gpus=synced_gpus,
+                                      assistant_model=assistant_model,
+                                      streamer=streamer,
+                                      encoder_outputs={'hidden_states': encoder_state, 'attentions': 0},
+                                      **kwargs
+                                      )
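
Below is a hedged usage sketch for the sequence-only path exposed by `generate_protein_description`. It assumes these modules are published in a Hub repository whose config maps `Prot2TextModel` through `auto_map` (so `trust_remote_code=True` resolves it) and that the matching GPT-2 tokenizer, extended with the `<|graph_token|>` and `<|stop_token|>` special tokens, is stored alongside the weights; the repository id and the amino-acid sequence are placeholders, not part of this commit.

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "username/prot2text-base"   # placeholder repository id
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(repo_id)                  # GPT-2 tokenizer with the added special tokens
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)  # resolves to Prot2TextModel via the remote code

description = model.generate_protein_description(
    protein_sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",           # toy amino-acid sequence
    tokenizer=tokenizer,
    device=device,
)
print(description)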