diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py index de3fd14ff6..d37a81e65f 100644 --- a/onmt/inputters/dynamic_iterator.py +++ b/onmt/inputters/dynamic_iterator.py @@ -279,27 +279,29 @@ def batch_size_fn(nbsents, maxlen): else: raise ValueError(f"Invalid argument batch_type={batch_type}") - minibatch, maxlen, size_so_far, seen = [], 0, 0, [] + def max_src_tgt(ex): + """return the max tokens btw src and tgt in the sequence.""" + if ex["tgt"]: + return max(len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"])) + return len(ex["src"]["src_ids"]) + + minibatch, maxlen, size_so_far, seen = [], 0, 0, set() for ex in data: - if (ex["src"]["src"] not in seen) or (self.task != CorpusTask.TRAIN): - seen.append(ex["src"]["src"]) + src = ex["src"]["src"] + if src not in seen or (self.task != CorpusTask.TRAIN): + seen.add(src) minibatch.append(ex) nbsents = len(minibatch) - maxlen = max(text_sort_key(ex), maxlen) + maxlen = max(max_src_tgt(ex), maxlen) size_so_far = batch_size_fn(nbsents, maxlen) if size_so_far >= batch_size: - overflowed = 0 - if size_so_far > batch_size: - overflowed += 1 - if batch_size_multiple > 1: - overflowed += ( - len(minibatch) - overflowed - ) % batch_size_multiple + overflowed = 1 if size_so_far > batch_size else 0 + overflowed += (nbsents - overflowed) % batch_size_multiple if overflowed == 0: yield minibatch - minibatch, maxlen, size_so_far, seen = [], 0, 0, [] + minibatch, maxlen, size_so_far, seen = [], 0, 0, set() else: - if overflowed == len(minibatch): + if overflowed == nbsents: logger.warning( "The batch will be filled until we reach" " %d, its size may exceed %d tokens" @@ -308,10 +310,9 @@ def batch_size_fn(nbsents, maxlen): else: yield minibatch[:-overflowed] minibatch = minibatch[-overflowed:] - maxlen, size_so_far, seen = 0, 0, [] - for i, ex in enumerate(minibatch): - maxlen = max(text_sort_key(ex), maxlen) - size_so_far = batch_size_fn(i + 1, maxlen) + maxlen = max([max_src_tgt(ex) for ex in minibatch]) + size_so_far = batch_size_fn(len(minibatch), maxlen) + seen = set() if minibatch: yield minibatch diff --git a/onmt/inputters/text_utils.py b/onmt/inputters/text_utils.py index c44c0d4174..21082383ff 100644 --- a/onmt/inputters/text_utils.py +++ b/onmt/inputters/text_utils.py @@ -59,7 +59,7 @@ def append_features_to_text(text, features): def text_sort_key(ex): """Sort using the number of tokens in the sequence.""" if ex["tgt"]: - return max(len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"])) + return len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"]) return len(ex["src"]["src_ids"])