diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index b02d6b31cd..f9f13a89d8 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -1,7 +1,6 @@
 import torch
 from onmt.translate import penalties
 from onmt.translate.decode_strategy import DecodeStrategy
-
 import warnings
@@ -184,6 +183,51 @@ def _pick(self, log_probs):
         return topk_scores, topk_ids

+    def beams_non_finished(self, i, predictions, attention, step):
+        b = self._batch_offset[i]
+
+        if any(self.is_finished_list[i]):
+            # Store finished hypotheses for this example in the batch.
+            for j in [
+                k for k, fin in enumerate(self.is_finished_list[i]) if fin
+            ]:  # Beam level: finished beam j in example i of batch
+                if self.ratio > 0:
+                    s = self.topk_scores[i, j] / (step + 1)
+                    self.best_scores[b] = max(s, self.best_scores[b])
+                self.hypotheses[b].append(
+                    (
+                        self.topk_scores[i, j],
+                        predictions[i, j, 1:],  # Ignore start_token.
+                        attention[i, j, :, : self.src_len[b]]
+                        if attention is not None
+                        else None,
+                    )
+                )
+            if len(self.hypotheses[b]) >= 2:
+                self.hypotheses[b] = sorted(
+                    self.hypotheses[b], key=lambda x: x[0], reverse=True
+                )
+
+        # End condition is the top beam finished and we can return
+        # n_best hypotheses.
+        if self.ratio > 0:
+            pred_len = self.src_len[b] * self.ratio
+            finish_flag = (
+                (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
+            ) or all(self.is_finished_list[i])
+        else:
+            # early stop when top beam is finished
+            finish_flag = self.is_finished_list[i][0]
+
+        if finish_flag and len(self.hypotheses[b]) >= self.n_best:
+            for score, pred, attn in self.hypotheses[b][: self.n_best]:
+                self.scores[b].append(score)
+                self.predictions[b].append(pred)  # ``(batch, n_best,)``
+                self.attention[b].append(attn if attn is not None else [])
+            return False
+        else:
+            return True
+
     def update_finished(self):
         # Penalize beams that finished.
         _B_old = self.topk_log_probs.shape[0]
@@ -201,51 +245,12 @@ def update_finished(self):
             if self.alive_attn is not None
             else None
         )
-        non_finished_batch = []
-
-        for i in range(len(self.is_finished_list)):  # Batch level
-            b = self._batch_offset[i]
-
-            if any(self.is_finished_list[i]):
-                # Store finished hypotheses for this batch.
-                for j in [
-                    k for k, fin in enumerate(self.is_finished_list[i]) if fin
-                ]:  # Beam level: finished beam j in batch i
-                    if self.ratio > 0:
-                        s = self.topk_scores[i, j] / (step + 1)
-                        self.best_scores[b] = max(s, self.best_scores[b])
-
-                    self.hypotheses[b].append(
-                        (
-                            self.topk_scores[i, j],
-                            predictions[i, j, 1:],  # Ignore start_token.
-                            attention[i, j, :, : self.src_len[b]]
-                            if attention is not None
-                            else None,
-                        )
-                    )
-            # End condition is the top beam finished and we can return
-            # n_best hypotheses.
-            if self.ratio > 0:
-                pred_len = self.src_len[b] * self.ratio
-                finish_flag = (
-                    (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
-                ) or all(self.is_finished_list[i])
-            else:
-                # early stop when top beam is finished
-                finish_flag = self.is_finished_list[i][0]
-
-            if finish_flag and len(self.hypotheses[b]) >= self.beam_size:
-                best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)[
-                    : self.n_best
-                ]
-                for score, pred, attn in best_hyp:
-                    self.scores[b].append(score)
-                    self.predictions[b].append(pred)  # ``(batch, n_best,)``
-                    self.attention[b].append(attn if attn is not None else [])
-            else:
-                non_finished_batch.append(i)
+        non_finished_batch = [
+            i
+            for i in range(len(self.is_finished_list))
+            if self.beams_non_finished(i, predictions, attention, step)
+        ]
         non_finished = torch.tensor(non_finished_batch)
         # If all sentences are translated, no need to go further.
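The refactor moves the per-example bookkeeping out of `update_finished` and into `beams_non_finished`, which appends any newly finished beams to `self.hypotheses[b]` and returns `True` only while example `i` still needs decoding, so the batch can be filtered with a single list comprehension. Below is a minimal sketch of that predicate-driven filtering pattern; the toy flags and dummy tensor are illustrative stand-ins, not the real `BeamSearch` state:

```python
import torch

# Toy stand-in for the per-example survival predicate (illustrative
# only, not the real is_finished_list / beams_non_finished state).
still_decoding = [True, False, True, False]

# Same pattern as the diff: keep the batch indices whose predicate
# says the example is not finished.
non_finished_batch = [i for i, alive in enumerate(still_decoding) if alive]
non_finished = torch.tensor(non_finished_batch)  # tensor([0, 2])

# Per-example state tensors can then be shrunk to the surviving rows,
# e.g. with index_select along the batch dimension.
topk_log_probs = torch.randn(4, 5)  # (batch, beam_size), dummy values
topk_log_probs = topk_log_probs.index_select(0, non_finished)
print(topk_log_probs.shape)  # torch.Size([2, 5])
```

Folding the side effects (storing finished hypotheses, updating best scores) and the survival test into one helper that returns a boolean lets the comprehension read as a pure filter over batch indices, which is the structure the new `update_finished` relies on.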