From 2828eb75454cfd0e156ae472a369b24c0e0110b0 Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Wed, 22 Nov 2023 08:45:26 +0100 Subject: [PATCH] Set dynamic max length per batch (#2523) --- onmt/opts.py | 10 ++++++++++ onmt/translate/translator.py | 13 +++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/onmt/opts.py b/onmt/opts.py index 38df61d30f..2458c57540 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -1687,6 +1687,16 @@ def _add_decoding_opts(parser): default=250, help="Maximum prediction length.", ) + group.add( + "--max_length_ratio", + "-max_length_ratio", + type=float, + default=1.25, + help="Maximum prediction length ratio. " + "For European languages 1.25 is large enough; " + "for target Asian characters, increase to 2-3; " + "for special languages (Burmese, Amharic), up to 10.", + ) # Decoding content constraint group.add( "--block_ngram_repeat", diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 8efd646fec..10f89665b5 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -107,6 +107,7 @@ def __init__( n_best=1, min_length=0, max_length=100, + max_length_ratio=1.5, ratio=0.0, beam_size=30, random_sampling_topk=0, @@ -152,6 +153,7 @@ def __init__( self.n_best = n_best self.max_length = max_length + self.max_length_ratio = max_length_ratio self.beam_size = beam_size self.random_sampling_temp = random_sampling_temp @@ -243,6 +245,7 @@ def from_opt( n_best=opt.n_best, min_length=opt.min_length, max_length=opt.max_length, + max_length_ratio=opt.max_length_ratio, ratio=opt.ratio, beam_size=opt.beam_size, random_sampling_topk=opt.random_sampling_topk, @@ -792,6 +795,12 @@ def _align_forward(self, batch, predictions): def translate_batch(self, batch, attn_debug): """Translate a batch of sentences.""" + if self.max_length_ratio > 0: + max_length = int( + min(self.max_length, batch["src"].size(1) * self.max_length_ratio + 5) + ) + else: + max_length = self.max_length with 
torch.no_grad(): if self.sample_from_topk != 0 or self.sample_from_topp != 0: decode_strategy = GreedySearch( @@ -804,7 +813,7 @@ def translate_batch(self, batch, attn_debug): batch_size=len(batch["srclen"]), global_scorer=self.global_scorer, min_length=self.min_length, - max_length=self.max_length, + max_length=max_length, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, @@ -828,7 +837,7 @@ def translate_batch(self, batch, attn_debug): n_best=self.n_best, global_scorer=self.global_scorer, min_length=self.min_length, - max_length=self.max_length, + max_length=max_length, return_attention=attn_debug or self.replace_unk, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs,