import torch
from transformers import StoppingCriteria


class StopWordsCriteria(StoppingCriteria):
    """Stop generation once the generated tokens end with a given token-id sequence.

    Only the first row of ``input_ids`` is inspected, so this criterion does not
    support batched generation.
    """

    def __init__(self, stop_indices: list):
        """Store the stop sequence.

        Args:
            stop_indices: Token ids of the stop sequence, in order. Generation
                stops when the tail of the first sequence equals these ids.
        """
        # Normalise to a list so the list equality check in __call__ also works
        # when a tuple (or other iterable) is passed in.
        self.stop_indices = list(stop_indices)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the first sequence in ``input_ids`` ends with ``stop_indices``.

        Args:
            input_ids: Token ids generated so far. Indexed as ``input_ids[0, -n:]``
                (assumes shape (batch, seq_len); only row 0 is checked).
            scores: Unused; required by the ``StoppingCriteria`` interface.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            False when the stop sequence is empty or longer than the current
            sequence. Otherwise, whether the last ``len(stop_indices)`` tokens of
            the first sequence equal ``stop_indices``.
        """
        # Batch inference is not supported: only the first sequence is checked.
        n = len(self.stop_indices)
        # An empty stop sequence must not halt generation. A stop sequence longer
        # than what has been generated cannot match yet (and would otherwise
        # index out of range).
        if n == 0 or n > input_ids.shape[-1]:
            return False
        # One .tolist() transfer instead of n per-element tensor truth tests,
        # each of which would force a separate host/device synchronisation.
        return input_ids[0, -n:].tolist() == self.stop_indices