Set dynamic max length per batch (OpenNMT#2523)
vince62s authored Nov 22, 2023
1 parent 1b8c290 commit 2828eb7
Showing 2 changed files with 21 additions and 2 deletions.
10 changes: 10 additions & 0 deletions onmt/opts.py
@@ -1687,6 +1687,16 @@ def _add_decoding_opts(parser):
         default=250,
         help="Maximum prediction length.",
     )
+    group.add(
+        "--max_length_ratio",
+        "-max_length_ratio",
+        type=float,
+        default=1.25,
+        help="Maximum prediction length ratio. "
+        "For European languages, 1.25 is usually large enough; "
+        "for Asian target characters, increase to 2-3; "
+        "for some languages (Burmese, Amharic), up to 10.",
+    )
     # Decoding content constraint
     group.add(
         "--block_ngram_repeat",
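For context, here is a minimal, hypothetical argparse stand-in for the new flag; OpenNMT-py actually registers it through its own parser wrapper (the group.add call in the hunk above), so only the option names, types, and defaults are taken from the commit.

import argparse

# Hypothetical stand-in: mirrors the option added in onmt/opts.py above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_length", type=int, default=250, help="Maximum prediction length."
)
parser.add_argument(
    "--max_length_ratio",
    type=float,
    default=1.25,
    help="Maximum prediction length ratio relative to the source length.",
)

opt = parser.parse_args(["--max_length_ratio", "2.5"])
print(opt.max_length, opt.max_length_ratio)  # -> 250 2.5

Leaving --max_length_ratio at its default keeps behaviour bounded by --max_length as before; raising it (e.g. to 2-3 for Asian target scripts, per the help text) loosens the per-batch cap computed in translator.py below.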
13 changes: 11 additions & 2 deletions onmt/translate/translator.py
@@ -107,6 +107,7 @@ def __init__(
         n_best=1,
         min_length=0,
         max_length=100,
+        max_length_ratio=1.5,
         ratio=0.0,
         beam_size=30,
         random_sampling_topk=0,
@@ -152,6 +153,7 @@ def __init__(
 
         self.n_best = n_best
         self.max_length = max_length
+        self.max_length_ratio = max_length_ratio
 
         self.beam_size = beam_size
         self.random_sampling_temp = random_sampling_temp
@@ -243,6 +245,7 @@ def from_opt(
             n_best=opt.n_best,
             min_length=opt.min_length,
             max_length=opt.max_length,
+            max_length_ratio=opt.max_length_ratio,
             ratio=opt.ratio,
             beam_size=opt.beam_size,
             random_sampling_topk=opt.random_sampling_topk,
@@ -792,6 +795,12 @@ def _align_forward(self, batch, predictions):
 
     def translate_batch(self, batch, attn_debug):
         """Translate a batch of sentences."""
+        if self.max_length_ratio > 0:
+            max_length = int(
+                min(self.max_length, batch["src"].size(1) * self.max_length_ratio + 5)
+            )
+        else:
+            max_length = self.max_length
         with torch.no_grad():
             if self.sample_from_topk != 0 or self.sample_from_topp != 0:
                 decode_strategy = GreedySearch(
@@ -804,7 +813,7 @@ def translate_batch(self, batch, attn_debug):
                     batch_size=len(batch["srclen"]),
                     global_scorer=self.global_scorer,
                     min_length=self.min_length,
-                    max_length=self.max_length,
+                    max_length=max_length,
                     block_ngram_repeat=self.block_ngram_repeat,
                     exclusion_tokens=self._exclusion_idxs,
                     return_attention=attn_debug or self.replace_unk,
@@ -828,7 +837,7 @@ def translate_batch(self, batch, attn_debug):
                     n_best=self.n_best,
                     global_scorer=self.global_scorer,
                     min_length=self.min_length,
-                    max_length=self.max_length,
+                    max_length=max_length,
                     return_attention=attn_debug or self.replace_unk,
                     block_ngram_repeat=self.block_ngram_repeat,
                     exclusion_tokens=self._exclusion_idxs,
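The computation added to translate_batch above can be read in isolation as the sketch below; it is standalone, and the assumption that batch["src"].size(1) is the padded source length (OpenNMT-py batches carry token tensors of shape [batch, length, features]) is mine, not spelled out in the diff.

import torch

def dynamic_max_length(max_length, max_length_ratio, src_len):
    # Mirror of the logic in the hunk above: cap decoding at
    # ratio * source_length + 5, but never above the absolute max_length;
    # a non-positive ratio falls back to the fixed max_length.
    if max_length_ratio > 0:
        return int(min(max_length, src_len * max_length_ratio + 5))
    return max_length

src = torch.zeros(8, 20, 1, dtype=torch.long)  # hypothetical batch: 8 sentences, 20 source tokens
print(dynamic_max_length(250, 1.25, src.size(1)))  # 30  (20 * 1.25 + 5)
print(dynamic_max_length(250, 2.5, src.size(1)))   # 55  (20 * 2.5 + 5)
print(dynamic_max_length(250, 0.0, src.size(1)))   # 250 (ratio disabled)

Short batches therefore get a much tighter decoding limit than the global --max_length of 250, which is what the commit title means by a dynamic max length per batch.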
