diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py index 466cefd127..43865f1467 100644 --- a/onmt/translate/beam_search.py +++ b/onmt/translate/beam_search.py @@ -265,6 +265,12 @@ def update_finished(self): # reset the selection for the next step self.select_indices = self._batch_index.view(_B_new * self.beam_size) + # assert torch.equal( + # self.src_len[self.select_indices], + # self.src_len.view(_B_old, self.beam_size)[non_finished].view( + # _B_new * self.beam_size + # ), + # ) self.src_len = self.src_len[self.select_indices] self.maybe_update_target_prefix(self.select_indices)