|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BeamHypotheses(object):
    """Bounded n-best list of scored hypotheses.

    Stores at most ``num_beams`` ``(score, hypothesis)`` pairs in ``beams``
    and keeps ``worst_score`` equal to the lowest score currently retained.
    """

    def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
        """Create an empty n-best list.

        Args:
            num_beams: Maximum number of hypotheses kept in ``beams``.
            length_penalty: Exponent applied to the length when turning a
                summed log-probability into a score.
            early_stopping: When true, ``is_done`` returns ``True`` as soon
                as the list holds ``num_beams`` hypotheses.
        """
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.beams = []
        # High sentinel so the min() in add() picks up the first stored
        # score (for any score below 1e9).
        self.worst_score = 1e9

    def __len__(self):
        """Return how many hypotheses are currently stored."""
        return len(self.beams)

    def add(self, hyp, sum_logprobs, length):
        """Offer a hypothesis to the list.

        The score is ``sum_logprobs / length ** length_penalty``. The
        hypothesis is stored when the list is not yet full or when its score
        is strictly greater than ``worst_score``; otherwise it is ignored.
        When storing pushes the list past ``num_beams``, the lowest-scoring
        entry is removed.

        Args:
            hyp: The hypothesis object to store alongside its score.
            sum_logprobs: Summed log-probability of the hypothesis.
            length: Length used to normalise ``sum_logprobs``.
        """
        score = sum_logprobs / length ** self.length_penalty

        has_room = len(self) < self.num_beams
        if not has_room and not score > self.worst_score:
            return

        self.beams.append((score, hyp))

        if len(self) <= self.num_beams:
            # Still within capacity: just track the lowest score seen.
            self.worst_score = min(score, self.worst_score)
            return

        # Over capacity: rank by (score, position) so ties evict the
        # earliest-inserted entry, drop the lowest, and promote the
        # runner-up score to be the new worst.
        ranked = sorted(
            (entry_score, position)
            for position, (entry_score, _) in enumerate(self.beams)
        )
        evict_at = ranked[0][1]
        del self.beams[evict_at]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs, cur_len):
        """Report whether generation for this sentence can stop.

        Returns ``False`` while fewer than ``num_beams`` hypotheses are
        stored. Once full, returns ``True`` immediately if ``early_stopping``
        is set; otherwise returns whether ``worst_score`` is at least
        ``best_sum_logprobs / cur_len ** length_penalty``.

        Args:
            best_sum_logprobs: Summed log-probability of the best candidate
                still being generated.
            cur_len: Length used to normalise ``best_sum_logprobs``.
        """
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True

        candidate_score = best_sum_logprobs / cur_len ** self.length_penalty
        return self.worst_score >= candidate_score
|
|
|
|