# zsp / model.py
from huggingface_hub import HfApi, ModelFilter
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.modeling_outputs import MaskedLMOutput

# Function to fetch suitable ESM models from the HuggingFace Hub
def get_models() -> list[str]:
    """Fetch suitable ESM models from the HuggingFace Hub, newest first."""
    out = [
        m.modelId for m in HfApi().list_models(
            filter=ModelFilter(
                author="facebook", model_name="esm", task="fill-mask"
            ),
            sort="lastModified",
            direction=-1,
        )
    ]
    if not out:
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return out
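
# Example (illustrative; the actual IDs and their order depend on the Hub at query time):
#     get_models()  # -> ["facebook/esm2_t33_650M_UR50D", ...]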

# Class to wrap ESM models
class Model:
    """Wrapper for ESM masked-language models."""

    def __init__(self, model_name: str = ""):
        """Load the selected model and tokenizer."""
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            # Move the model to the GPU if CUDA is available
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def tokenise(self, sequence: str) -> BatchEncoding:
        """Convert an input sequence into a batch of token IDs."""
        return self.batch_converter(sequence, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run the model on a batch of tokens."""
        return self.model(batch_tokens, **kwargs)

    def __getitem__(self, key: str) -> int:
        """Get the token ID for a character (residue or special token)."""
        return self.alphabet[key]
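
    # Usage sketch (hypothetical checkpoint; any ESM fill-mask model should work):
    #     model = Model("facebook/esm2_t6_8M_UR50D")
    #     tokens = model.tokenise("MKTAYIAKQR").input_ids  # (1, len+2): BOS + residues + EOS
    #     mask_id = model['<mask>']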
    def run_model(self, data):
        """Score the substitutions in `data` with this model."""
        def label_row(row, token_probs):
            """Score a single substitution row such as 'M1A'."""
            # Split the row into wild type, 0-based index, and mutant type
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # Score is the log-probability of the mutant minus that of the wild
            # type; the +1 offset skips the BOS token added by the tokenizer
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()

        # Tokenise the sequence and move the tokens to the model's device
        batch_tokens = self.tokenise(data.seq).input_ids.to(self.model.device)
        # Compute log-probabilities without updating the model parameters
        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
        # Store the token probabilities on the data object
        data.token_probs = token_probs.cpu().numpy()
        if data.scoring_strategy.startswith("masked-marginals"):
            all_token_probs = []
            # For each token position in the batch
            for i in range(batch_tokens.size(1)):
                if i in data.resi:
                    # Mask the selected position and recompute its probabilities
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = self['<mask>']
                    with torch.no_grad():
                        masked_token_probs = torch.log_softmax(
                            self(batch_tokens_masked).logits, dim=-1
                        )
                else:
                    # Positions outside the selection keep the unmasked probabilities
                    masked_token_probs = token_probs
                all_token_probs.append(masked_token_probs[:, i])
            # Reassemble a (1, seq_len, vocab_size) tensor of per-position probabilities
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
        # Apply label_row to each row of the substitutions dataframe and
        # store the scores under this model's name
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row['0'],
                token_probs,
            ),
            axis=1,
        )
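

# Minimal end-to-end sketch. The `data` container below is a stand-in built
# from the attributes run_model actually reads (seq, resi, scoring_strategy,
# sub, out); the real application presumably supplies its own object with
# these fields, so treat every value here as an assumption.
if __name__ == "__main__":
    from types import SimpleNamespace

    import pandas as pd

    data = SimpleNamespace(
        seq="MKTAYIAKQR",                      # toy sequence (assumption)
        resi=[1],                              # token index of residue 1 (BOS sits at index 0)
        scoring_strategy="masked-marginals",
        sub=pd.DataFrame({"0": ["M1A"]}),      # substitution M1A: wild-type M at position 1 -> A
        out=pd.DataFrame(index=[0]),           # scores land here, one column per model
    )
    model = Model("facebook/esm2_t6_8M_UR50D")  # small ESM-2 checkpoint for a quick test
    model.run_model(data)
    print(data.out)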