diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py index 44b0269976..91cd056bcb 100644 --- a/onmt/translate/beam_search.py +++ b/onmt/translate/beam_search.py @@ -241,7 +241,7 @@ def update_finished(self): ] for n, (score, pred, attn) in enumerate(best_hyp): self.scores[b].append(score) - self.predictions[b].append(pred) # ``(batch, n_best,)`` + self.predictions[b].append(pred.cpu()) # ``(batch, n_best,)`` self.attention[b].append(attn if attn is not None else []) else: non_finished_batch.append(i) diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index 3b402cc9e8..f2272d3c3d 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -271,7 +271,7 @@ def update_finished(self): best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True) for score, pred, attn in best_hyp: self.scores[b].append(score) - self.predictions[b].append(pred) + self.predictions[b].append(pred.cpu()) self.attention[b].append(attn) return is_alive = ~self.is_finished.view(-1)