Commit a0f4e63

better code
vince62s committed Oct 26, 2023
1 parent ef7a6fd commit a0f4e63
Showing 1 changed file with 50 additions and 45 deletions.
95 changes: 50 additions & 45 deletions onmt/translate/beam_search.py
@@ -1,7 +1,6 @@
 import torch
 from onmt.translate import penalties
 from onmt.translate.decode_strategy import DecodeStrategy
-
 import warnings

@@ -184,6 +183,51 @@ def _pick(self, log_probs):
 
         return topk_scores, topk_ids
 
+    def beams_non_finished(self, i, predictions, attention, step):
+        b = self._batch_offset[i]
+
+        if any(self.is_finished_list[i]):
+            # Store finished hypotheses for this example in the batch.
+            for j in [
+                k for k, fin in enumerate(self.is_finished_list[i]) if fin
+            ]:  # Beam level: finished beam j in example i of batch
+                if self.ratio > 0:
+                    s = self.topk_scores[i, j] / (step + 1)
+                    self.best_scores[b] = max(s, self.best_scores[b])
+                self.hypotheses[b].append(
+                    (
+                        self.topk_scores[i, j],
+                        predictions[i, j, 1:],  # Ignore start_token.
+                        attention[i, j, :, : self.src_len[b]]
+                        if attention is not None
+                        else None,
+                    )
+                )
+            if len(self.hypotheses[b]) >= 2:
+                self.hypotheses[b] = sorted(
+                    self.hypotheses[b], key=lambda x: x[0], reverse=True
+                )
+
+            # End condition is the top beam finished and we can return
+            # n_best hypotheses.
+            if self.ratio > 0:
+                pred_len = self.src_len[b] * self.ratio
+                finish_flag = (
+                    (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
+                ) or all(self.is_finished_list[i])
+            else:
+                # early stop when top beam is finished
+                finish_flag = self.is_finished_list[i][0]
+
+            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
+                for score, pred, attn in self.hypotheses[b][: self.n_best]:
+                    self.scores[b].append(score)
+                    self.predictions[b].append(pred)  # ``(batch, n_best,)``
+                    self.attention[b].append(attn if attn is not None else [])
+                return False
+            else:
+                return True
+        return True
 
     def update_finished(self):
         # Penalize beams that finished.
         _B_old = self.topk_log_probs.shape[0]
@@ -201,51 +245,12 @@ def update_finished(self):
             if self.alive_attn is not None
             else None
         )
-        non_finished_batch = []
 
-        for i in range(len(self.is_finished_list)):  # Batch level
-            b = self._batch_offset[i]
-
-            if any(self.is_finished_list[i]):
-                # Store finished hypotheses for this batch.
-                for j in [
-                    k for k, fin in enumerate(self.is_finished_list[i]) if fin
-                ]:  # Beam level: finished beam j in batch i
-                    if self.ratio > 0:
-                        s = self.topk_scores[i, j] / (step + 1)
-                        self.best_scores[b] = max(s, self.best_scores[b])
-
-                    self.hypotheses[b].append(
-                        (
-                            self.topk_scores[i, j],
-                            predictions[i, j, 1:],  # Ignore start_token.
-                            attention[i, j, :, : self.src_len[b]]
-                            if attention is not None
-                            else None,
-                        )
-                    )
-
-            # End condition is the top beam finished and we can return
-            # n_best hypotheses.
-            if self.ratio > 0:
-                pred_len = self.src_len[b] * self.ratio
-                finish_flag = (
-                    (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
-                ) or all(self.is_finished_list[i])
-            else:
-                # early stop when top beam is finished
-                finish_flag = self.is_finished_list[i][0]
-
-            if finish_flag and len(self.hypotheses[b]) >= self.beam_size:
-                best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)[
-                    : self.n_best
-                ]
-                for score, pred, attn in best_hyp:
-                    self.scores[b].append(score)
-                    self.predictions[b].append(pred)  # ``(batch, n_best,)``
-                    self.attention[b].append(attn if attn is not None else [])
-            else:
-                non_finished_batch.append(i)
+        non_finished_batch = [
+            i
+            for i in range(len(self.is_finished_list))
+            if self.beams_non_finished(i, predictions, attention, step)
+        ]
 
         non_finished = torch.tensor(non_finished_batch)
         # If all sentences are translated, no need to go further.
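
A note on what the refactor amounts to. The inline per-example loop in update_finished becomes a predicate, beams_non_finished, which returns True while example i still needs decoding steps, so update_finished can collect the surviving batch indices with a single list comprehension. The sketch below is illustration only, not code from this commit: it restates the stopping rule with plain Python types instead of tensors. The function name example_is_alive, the n_hyps parameter, and the toy values are hypothetical; topk_scores, best_score, src_len, ratio, and n_best mirror names in the diff.

# Illustration only (not part of the commit): the end condition tested
# inside beams_non_finished, restated with plain Python types.
def example_is_alive(topk_scores, is_finished, best_score,
                     src_len, ratio, n_hyps, n_best):
    """Return True while this example still needs decoding steps."""
    if ratio > 0:
        # Length-ratio mode: stop once the live top beam, normalized by
        # the allowed prediction length, can no longer beat the best
        # finished hypothesis, or once every beam has finished.
        pred_len = src_len * ratio
        finish_flag = (topk_scores[0] / pred_len) <= best_score or all(is_finished)
    else:
        # Default mode: stop as soon as the top beam has finished.
        finish_flag = is_finished[0]
    # Even then, stop only if enough finished hypotheses were stored.
    return not (finish_flag and n_hyps >= n_best)

# Toy check (default mode): top beam finished and one hypothesis stored,
# so with n_best=1 the example can be extracted from the batch.
print(example_is_alive([-0.5, -1.2], [True, False], best_score=-0.4,
                       src_len=7, ratio=0.0, n_hyps=1, n_best=1))  # False

Expressing the per-example test as a boolean predicate keeps the finish bookkeeping in one place and is what lets the loop in update_finished collapse to the five-line comprehension shown in the diff.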
