diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 7c13c8eea0..9c8c047224 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -169,8 +169,93 @@ jobs:
            -scoring_debug "true" \
            -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_copy \
            -dump_preds /tmp/dump_preds \
+           -position_encoding \
            -copy_attn
           python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_copy -tensorboard_checks valid_metrics
+      - name : Test Transformer training and validation with dynamic scoring and maxrelative
+        run: |
+          python3 train.py \
+           -config data/data.yaml \
+           -src_vocab /tmp/onmt.vocab.src \
+           -tgt_vocab /tmp/onmt.vocab.tgt \
+           -src_vocab_size 1000 \
+           -tgt_vocab_size 1000 \
+           -encoder_type transformer \
+           -decoder_type transformer \
+           -layers 4 \
+           -word_vec_size 16 \
+           -hidden_size 16 \
+           -num_workers 0 -bucket_size 1024 \
+           -heads 2 \
+           -transformer_ff 64 \
+           -accum_count 2 4 8 \
+           -accum_steps 0 15000 30000 \
+           -save_model /tmp/onmt.model \
+           -train_steps 10 -valid_steps 5 \
+           -report_every 2 \
+           -valid_metrics "BLEU" "TER" \
+           -tensorboard "true" \
+           -scoring_debug "true" \
+           -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_relative \
+           -dump_preds /tmp/dump_preds \
+           -max_relative_positions 8
+          python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_relative -tensorboard_checks valid_metrics
+      - name : Test Transformer training and validation with dynamic scoring and rotary
+        run: |
+          python3 train.py \
+           -config data/data.yaml \
+           -src_vocab /tmp/onmt.vocab.src \
+           -tgt_vocab /tmp/onmt.vocab.tgt \
+           -src_vocab_size 1000 \
+           -tgt_vocab_size 1000 \
+           -encoder_type transformer \
+           -decoder_type transformer \
+           -layers 4 \
+           -word_vec_size 16 \
+           -hidden_size 16 \
+           -num_workers 0 -bucket_size 1024 \
+           -heads 2 \
+           -transformer_ff 64 \
+           -accum_count 2 4 8 \
+           -accum_steps 0 15000 30000 \
+           -save_model /tmp/onmt.model \
+           -train_steps 10 -valid_steps 5 \
+           -report_every 2 \
+           -valid_metrics "BLEU" "TER" \
+           -tensorboard "true" \
+           -scoring_debug "true" \
+           -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_rotary \
+           -dump_preds /tmp/dump_preds \
+           -max_relative_positions -1
+          python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_rotary -tensorboard_checks valid_metrics
+      - name : Test Transformer training and validation with dynamic scoring and alibi
+        run: |
+          python3 train.py \
+           -config data/data.yaml \
+           -src_vocab /tmp/onmt.vocab.src \
+           -tgt_vocab /tmp/onmt.vocab.tgt \
+           -src_vocab_size 1000 \
+           -tgt_vocab_size 1000 \
+           -encoder_type transformer \
+           -decoder_type transformer \
+           -layers 4 \
+           -word_vec_size 16 \
+           -hidden_size 16 \
+           -num_workers 0 -bucket_size 1024 \
+           -heads 2 \
+           -transformer_ff 64 \
+           -accum_count 2 4 8 \
+           -accum_steps 0 15000 30000 \
+           -save_model /tmp/onmt.model \
+           -train_steps 10 -valid_steps 5 \
+           -report_every 2 \
+           -valid_metrics "BLEU" "TER" \
+           -tensorboard "true" \
+           -scoring_debug "true" \
+           -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_alibi \
+           -dump_preds /tmp/dump_preds \
+           -max_relative_positions -2
+          python onmt/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_alibi -tensorboard_checks valid_metrics
       - name: Test LM training
         run: |
           python train.py \
diff --git a/data/data_lm/gen-nucleus-sampling-sol2.txt b/data/data_lm/gen-nucleus-sampling-sol2.txt
index 4742cb0a0d..4a1090ea1b 100644
--- a/data/data_lm/gen-nucleus-sampling-sol2.txt
+++ b/data/data_lm/gen-nucleus-sampling-sol2.txt
@@ -1,7 +1,7 @@
-like your equipment !
-in your Presidency or registering erratic full detailed table since Waddington , as well handled separately .
-the importance 's future .
-there is survivors , under new public beforehand ?
-bankers sit on the honour and the old , accusations by Cubase for political leadership in the fifth generation of 0.5 , 1 January in serving the old , 1 .
-can do planet what Mr Titley , should like What is really do so done guarantee - ™ your Gore .
-we can restore confidence in Switzerland .
+well ?
+liberty to decide in favor of the Stability tattoo .
+a moment of the top of the importance of the importance 's future .
+I think that received negative perception that it .
+heard from India majority , Romagna .
+We move should do any experiences .
+our website , you invite shop in the year .
diff --git a/data/data_lm/gen-sampling-beams-sol2.txt b/data/data_lm/gen-sampling-beams-sol2.txt
index 1bd89a15bd..c361bcc408 100644
--- a/data/data_lm/gen-sampling-beams-sol2.txt
+++ b/data/data_lm/gen-sampling-beams-sol2.txt
@@ -1,7 +1,7 @@
-you ! your staff !
-in German Presidency or rather than Andalusia .
-a moment , one century .
-" s operational from 06: 00 on weekdays , clothes hang thick from the and 00 until about 00 until the sea .
-fine words has underlined the fine vegetarian ...
-service are buying any of my property .
-we are much running behind the facts .
+you ! Next to your luck .
+inspired by the absolut well-beeing-feeling of the Tauern Spa and ...
+the top of the top of the top of the moment , health of the importance of capital , the world .
+" Austrian " sold with " Delta , received from twenty years .
+administered by the usefulness of the interinstitutional coherence of the recorded by the present in the young majority of renewable energies .
+do so would like this subsidy is requested .
+800 m2 can 't be seen on payments .
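Note for reviewers: the three CI jobs added above differ only in the -max_relative_positions value, which is what selects the position-encoding scheme (the copy job instead uses -position_encoding, i.e. absolute sinusoidal encodings). Judging from the job names here and the matching blocks in onmt/tests/pull_request_chk.sh further down, a positive value such as 8 enables Shaw-style relative positions, -1 enables rotary embeddings, and -2 enables ALiBi. A minimal sketch of that convention; the dispatch function below is illustrative, not the actual OpenNMT-py option parser:

# Illustrative only: mirrors the flag values exercised by the CI jobs above.
def position_scheme(max_relative_positions: int, position_encoding: bool) -> str:
    if max_relative_positions > 0:
        return "relative"  # Shaw et al. relative positions, window of +/- N
    if max_relative_positions == -1:
        return "rotary"
    if max_relative_positions == -2:
        return "alibi"
    return "absolute" if position_encoding else "none"

assert position_scheme(8, False) == "relative"
assert position_scheme(-1, False) == "rotary"
assert position_scheme(-2, False) == "alibi"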
diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py
index 5fc64f6409..aae076ec85 100644
--- a/onmt/bin/translate.py
+++ b/onmt/bin/translate.py
@@ -9,6 +9,7 @@
 from onmt.utils.parse import ArgumentParser
 from onmt.utils.misc import use_gpu, set_random_seed
 from torch.profiler import profile, record_function, ProfilerActivity
+import time


 def translate(opt):
@@ -52,13 +53,15 @@ def main():
     parser = _get_parser()

     opt = parser.parse_args()
     if opt.profile:
+        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
             with record_function("Translate"):
                 translate(opt)
-        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
-
+        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))
     else:
+        init_time = time.time()
         translate(opt)
+        print("Time w/o python interpreter load/terminate: ", time.time() - init_time)


 if __name__ == "__main__":
diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 5d4ff51afa..84225124ba 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -581,20 +581,18 @@ def forward(self, tgt, enc_out=None, step=None, **kwargs):
                 {"keys": torch.tensor([]), "values": torch.tensor([])},
             )

-        emb = self.embeddings(tgt, step=step)
-        dec_out = emb
-        assert emb.dim() == 3  # len x batch x embedding_dim
+        dec_out = self.embeddings(tgt, step=step)

         pad_idx = self.embeddings.word_padding_idx
-        src_lens = kwargs["src_len"]
+        src_len = kwargs["src_len"]
         src_max_len = self.state["src"].shape[1]
-        src_pad_mask = ~sequence_mask(src_lens, src_max_len)  # [B x slen]
-        src_pad_mask = src_pad_mask.unsqueeze(1)  # [B x 1 x slen]
+        src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze(
+            1
+        )  # [B x 1 x slen]
         tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

         with_align = kwargs.pop("with_align", False)
-        return_attn = kwargs.pop("return_attn", False)
-        return_attn = with_align or self._copy or return_attn
+        return_attn = with_align or self._copy or kwargs.pop("return_attn", False)
         attn_aligns = []
diff --git a/onmt/encoders/mean_encoder.py b/onmt/encoders/mean_encoder.py
index 115d42b9fa..602524559b 100644
--- a/onmt/encoders/mean_encoder.py
+++ b/onmt/encoders/mean_encoder.py
@@ -30,7 +30,7 @@ def forward(self, src, src_len=None):

         if src_len is not None:
             # we avoid padding while mean pooling
-            mask = sequence_mask(src_len).float()
+            mask = (~sequence_mask(src_len)).float()
             mask = mask / src_len.unsqueeze(1).float()
             mean = torch.bmm(mask.unsqueeze(1), emb).squeeze(1)
         else:
diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py
index 727bafa654..184d44881f 100644
--- a/onmt/encoders/transformer.py
+++ b/onmt/encoders/transformer.py
@@ -228,10 +228,9 @@ def from_opt(cls, opt, embeddings):

     def forward(self, src, src_len=None):
         """See :func:`EncoderBase.forward()`"""
         enc_out = self.embeddings(src)
-        mask = ~sequence_mask(src_len).unsqueeze(1)
-        mask = mask.unsqueeze(1)
+        mask = sequence_mask(src_len).unsqueeze(1).unsqueeze(1)
         mask = mask.expand(-1, -1, mask.size(3), -1)
-        # mask is now (batch x 1 x slen x slen)
+        # Padding mask is now (batch x 1 x slen x slen)
         # 1 to be expanded to number of heads in MHA
         # Run the forward pass of every layer of the tranformer.
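For context on the encoder change above: after this patch, sequence_mask (redefined in onmt/utils/misc.py at the end of this diff) returns True at padding positions rather than at real tokens, which is why the negation disappears here and reappears in mean_encoder.py. A self-contained illustration of the new semantics and of the (batch x 1 x slen x slen) expansion used for multi-head attention:

import torch

def sequence_mask(lengths, max_len=None):
    # New semantics introduced by this diff: True marks PADDING positions.
    max_len = max_len or lengths.max()
    return torch.arange(0, max_len, device=lengths.device) >= lengths.unsqueeze(1)

src_len = torch.tensor([2, 4])
mask = sequence_mask(src_len, 4)
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])

# As in TransformerEncoder.forward above:
attn_mask = mask.unsqueeze(1).unsqueeze(1).expand(-1, -1, mask.size(-1), -1)
# attn_mask.shape == (2, 1, 4, 4), broadcastable over attention heads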
diff --git a/onmt/inference_engine.py b/onmt/inference_engine.py
index 1a61d5b09a..5c7f04a985 100755
--- a/onmt/inference_engine.py
+++ b/onmt/inference_engine.py
@@ -255,7 +255,7 @@ def translate_batch(self, batch, opt):
     def _translate(self, infer_iter):
         scores = []
         preds = []
-        for batch in infer_iter:
+        for batch, bucket_idx in infer_iter:
             _scores, _preds = self.translate_batch(batch, self.opt)
             scores += _scores
             preds += _preds
diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py
index 014e6fa07b..40479420e1 100644
--- a/onmt/inputters/dynamic_iterator.py
+++ b/onmt/inputters/dynamic_iterator.py
@@ -163,6 +163,7 @@ def __init__(
             raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}")
         self.skip_empty_level = skip_empty_level
         self.random_shuffler = RandomShuffler()
+        self.bucket_idx = 0

     @classmethod
     def from_opt(
@@ -247,6 +248,14 @@ def _tuple_to_json_with_tokIDs(self, tuple_bucket):
                 bucket.append(numericalize(self.vocabs, example))
         return bucket

+    def _add_indice(self, bucket):
+        indice = 0
+        indexed_bucket = []
+        for ex in bucket:
+            indexed_bucket.append((ex, indice))
+            indice += 1
+        return indexed_bucket
+
     def _bucketing(self):
         """
         Add up to bucket_size examples from the mixed corpora according
@@ -258,17 +267,19 @@ def _bucketing(self):
             _bucket_size = self.bucket_size_init
         else:
             _bucket_size = self.bucket_size
+
         for ex in self.mixer:
             bucket.append(ex)
             if len(bucket) == _bucket_size:
-                yield self._tuple_to_json_with_tokIDs(bucket)
+                yield (self._tuple_to_json_with_tokIDs(bucket), self.bucket_idx)
+                self.bucket_idx += 1
                 bucket = []
                 if _bucket_size < self.bucket_size:
                     _bucket_size += self.bucket_size_increment
                 else:
                     _bucket_size = self.bucket_size
         if bucket:
-            yield self._tuple_to_json_with_tokIDs(bucket)
+            yield (self._tuple_to_json_with_tokIDs(bucket), self.bucket_idx)

     def batch_iter(self, data, batch_size, batch_type="sents", batch_size_multiple=1):
         """Yield elements from data in chunks of batch_size,
@@ -290,11 +301,11 @@ def max_src_tgt(ex):
             return len(ex["src"]["src_ids"])

         minibatch, maxlen, size_so_far, seen = [], 0, 0, set()
-        for ex in data:
+        for ex, indice in data:
             src = ex["src"]["src"]
             if src not in seen or (self.task != CorpusTask.TRAIN):
                 seen.add(src)
-                minibatch.append(ex)
+                minibatch.append((ex, indice))
                 nbsents = len(minibatch)
                 maxlen = max(max_src_tgt(ex), maxlen)
                 size_so_far = batch_size_fn(nbsents, maxlen)
@@ -315,7 +326,7 @@ def max_src_tgt(ex):
                     else:
                         yield minibatch[:-overflowed]
                         minibatch = minibatch[-overflowed:]
-                        maxlen = max([max_src_tgt(ex) for ex in minibatch])
+                        maxlen = max([max_src_tgt(ex) for ex, ind in minibatch])
                         size_so_far = batch_size_fn(len(minibatch), maxlen)
                     seen = set()

@@ -323,11 +334,9 @@ def max_src_tgt(ex):
             yield minibatch

     def __iter__(self):
-        for bucket in self._bucketing():
-            # For TRAIN we need to group examples by length
-            # for faster performance, but otherwise, sequential.
-            if self.task == CorpusTask.TRAIN:
-                bucket = sorted(bucket, key=self.sort_key)
+        for bucket, bucket_idx in self._bucketing():
+            bucket = self._add_indice(bucket)
+            bucket = sorted(bucket, key=lambda x: self.sort_key(x[0]))
             p_batch = list(
                 self.batch_iter(
                     bucket,
@@ -340,13 +349,13 @@ def __iter__(self):
             # otherwise sequential
             if self.task == CorpusTask.TRAIN:
                 p_batch = self.random_shuffler(p_batch)
-            for minibatch in p_batch:
+            for i, minibatch in enumerate(p_batch):
                 # for specific case of rnn_packed need to be sorted
                 # within the batch
                 if self.task == CorpusTask.TRAIN:
-                    minibatch.sort(key=self.sort_key, reverse=True)
+                    minibatch.sort(key=lambda x: self.sort_key(x[0]), reverse=True)
                 tensor_batch = tensorify(self.vocabs, minibatch, self.device)
-                yield tensor_batch
+                yield (tensor_batch, bucket_idx)


 class OnDeviceDatasetIter:
@@ -355,11 +364,11 @@ def __init__(self, data_iter, device):
         self.device = device

     def __iter__(self):
-        for tensor_batch in self.data_iter:
+        for (tensor_batch, bucket_idx) in self.data_iter:
             for key in tensor_batch.keys():
-                if key != "src_ex_vocab":
+                if key not in ["src_ex_vocab", "cid"]:
                     tensor_batch[key] = tensor_batch[key].to(self.device)
-            yield tensor_batch
+            yield (tensor_batch, bucket_idx)


 def build_dynamic_dataset_iter(
diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py
index 504c547a17..5641daae06 100644
--- a/onmt/inputters/text_corpus.py
+++ b/onmt/inputters/text_corpus.py
@@ -172,43 +172,22 @@ def __init__(
         self.stride = stride
         self.offset = offset

-    def _tokenize(self, stream):
-        for example in stream:
+    def _process(self, stream):
+        for i, example in enumerate(stream):
             example["src"] = example["src"].strip("\n").split()
             example["src_original"] = example["src_original"].strip("\n").split()
             if "src_feats" in example:
                 example["src_feats"] = [
                     feat.strip("\n").split() for feat in example["src_feats"]
                 ]
-            if example["tgt"] is not None:
-                example["tgt"] = example["tgt"].strip("\n").split()
-                example["tgt_original"] = example["tgt_original"].strip("\n").split()
+            line_number = i * self.stride + self.offset
+            example["cid_line_number"] = line_number
+            example["cid"] = self.cid
             if "align" in example:
                 example["align"] = example["align"].strip("\n").split()
-            yield example
-
-    def _transform(self, stream):
-        for example in stream:
-            # NOTE: moved to dynamic_iterator.py cf process()
-            # item = self.transform.apply(
-            #     example, is_train=self.infinitely, corpus_name=self.cid)
-            item = (example, self.transform, self.cid)
-            if item is not None:
-                yield item
-        report_msg = self.transform.stats()
-        if report_msg != "":
-            logger.info(
-                "* Transform statistics for {}({:.2f}%):\n{}\n".format(
-                    self.cid, 100 / self.stride, report_msg
-                )
-            )
-
-    def _add_index(self, stream):
-        for i, item in enumerate(stream):
-            example = item[0]
-            line_number = i * self.stride + self.offset
-            example["indices"] = line_number
+
             if example["tgt"] is not None:
+                example["tgt"] = example["tgt"].strip("\n").split()
+                example["tgt_original"] = example["tgt_original"].strip("\n").split()
                 if (
                     len(example["src"]) == 0
                     or len(example["tgt"]) == 0
@@ -221,16 +200,21 @@
                     elif self.skip_empty_level == "warning":
                         logger.warning(empty_msg)
                     if len(example["src"]) == 0 and len(example["tgt"]) == 0:
-                        yield item
+                        yield (example, self.transform, self.cid)
                         continue
-            yield item
+            yield (example, self.transform, self.cid)
+        report_msg = self.transform.stats()
+        if report_msg != "":
+            logger.info(
+                "* Transform statistics for {}({:.2f}%):\n{}\n".format(
+                    self.cid, 100 / self.stride, report_msg
+                )
+            )

     def __iter__(self):
         corpus_stream = self.corpus.load(stride=self.stride, offset=self.offset)
-        tokenized_corpus = self._tokenize(corpus_stream)
-        transformed_corpus = self._transform(tokenized_corpus)
-        indexed_corpus = self._add_index(transformed_corpus)
-        yield from indexed_corpus
+        corpus = self._process(corpus_stream)
+        yield from corpus


 def build_corpora_iters(
diff --git a/onmt/inputters/text_utils.py b/onmt/inputters/text_utils.py
index 4795b9ef0a..605e15259c 100644
--- a/onmt/inputters/text_utils.py
+++ b/onmt/inputters/text_utils.py
@@ -89,6 +89,8 @@ def process(task, bucket, **kwargs):
             transform_cid_to_examples[transform_cid].append(example)

     processed_bucket = []
+    # Careful: this returns the bucket grouped by corpus; examples are
+    # sorted by length and batches shuffled later in the pipeline.
    for (transform, cid), sub_bucket in transform_cid_to_examples.items():
         transf_bucket = transform.batch_apply(
             sub_bucket, is_train=(task == CorpusTask.TRAIN), corpus_name=cid
@@ -103,7 +105,8 @@ def process(task, bucket, **kwargs):
             #       'tgt': {'tgt': ...},
             #       'src_original': ['tok1', ...'tokn'],
             #       'tgt_original': ['tok1', ...'tokm'],
-            #       'indices' : seq in bucket
+            #       'cid': corpus id
+            #       'cid_line_number' : line number in corpus cid
             #       'align': ...,
             #     }
     if len(processed_bucket) > 0:
@@ -173,13 +176,17 @@ def tensorify(vocabs, minibatch, device):
             'tgt': {'tgt': ..., 'tgt_ids': ...},
             'src_original': ['tok1', ...'tokn'],
             'tgt_original': ['tok1', ...'tokm'],
-            'indices' : seq in bucket
+            'cid': corpus id
+            'cid_line_number' : line number in corpus cid
+            'ind_in_bucket': index in bucket
             'align': ...,
         }
     Returns  Dict of batch Tensors
         {'src': [seqlen, batchsize, n_feats+1],
          'tgt' : [seqlen, batchsize, n_feats=1],
-         'indices' : [batchsize],
+         'cid': [batchsize],
+         'cid_line_number' : [batchsize],
+         'ind_in_bucket': [batchsize],
          'srclen': [batchsize],
          'tgtlen': [batchsize],
          'align': alignment sparse tensor
@@ -188,18 +195,18 @@ def tensorify(vocabs, minibatch, device):
     tensor_batch = {}
     tbatchsrc = [
         torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
-        for ex in minibatch
+        for ex, indice in minibatch
     ]
     padidx = vocabs["src"][DefaultTokens.PAD]
     tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx)
-    if "feats" in minibatch[0]["src"]:
+    if "feats" in minibatch[0][0]["src"]:
         tbatchfs = [tbatchsrc]
-        for feat_id in range(len(minibatch[0]["src"]["feats"])):
+        for feat_id in range(len(minibatch[0][0]["src"]["feats"])):
             tbatchfeat = [
                 torch.tensor(
                     ex["src"]["feats"][feat_id], dtype=torch.long, device=device
                 )
-                for ex in minibatch
+                for ex, indice in minibatch
             ]
             padidx = vocabs["src_feats"][feat_id][DefaultTokens.PAD]
             tbatchfeat = pad_sequence(
@@ -212,66 +219,75 @@ def tensorify(vocabs, minibatch, device):
         tbatchsrc = tbatchsrc[:, :, None]

     tensor_batch["src"] = tbatchsrc
-    tensor_batch["indices"] = torch.tensor(
-        [ex["indices"] for ex in minibatch], dtype=torch.long, device=device
-    )
+
     tensor_batch["srclen"] = torch.tensor(
-        [len(ex["src"]["src_ids"]) for ex in minibatch], dtype=torch.long, device=device
+        [len(ex["src"]["src_ids"]) for ex, indice in minibatch],
+        dtype=torch.long,
+        device=device,
     )

-    if minibatch[0]["tgt"] is not None:
+    if minibatch[0][0]["tgt"] is not None:
         tbatchtgt = [
             torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device)
-            for ex in minibatch
+            for ex, indice in minibatch
         ]

         padidx = vocabs["tgt"][DefaultTokens.PAD]
         tbatchtgt = pad_sequence(tbatchtgt, batch_first=True, padding_value=padidx)
         tbatchtgt = tbatchtgt[:, :, None]
         tbatchtgtlen = torch.tensor(
-            [len(ex["tgt"]["tgt_ids"]) for ex in minibatch],
+            [len(ex["tgt"]["tgt_ids"]) for ex, indice in minibatch],
             dtype=torch.long,
             device=device,
         )
         tensor_batch["tgt"] = tbatchtgt
         tensor_batch["tgtlen"] = tbatchtgtlen

-    if "align" in minibatch[0].keys() and minibatch[0]["align"] is not None:
+    if "align" in minibatch[0][0].keys() and minibatch[0][0]["align"] is not None:
         sparse_idx = []
-        for i, ex in enumerate(minibatch):
+        for i, (ex, indice) in enumerate(minibatch):
             for src, tgt in parse_align_idx(ex["align"]):
                 sparse_idx.append([i, tgt + 1, src])
         tbatchalign = torch.tensor(sparse_idx, dtype=torch.long, device=device)
         tensor_batch["align"] = tbatchalign

-    if "src_map" in minibatch[0].keys():
-        src_vocab_size = max([max(ex["src_map"]) for ex in minibatch]) + 1
+    if "src_map" in minibatch[0][0].keys():
+        src_vocab_size = max([max(ex["src_map"]) for ex, indice in minibatch]) + 1
         src_map = torch.zeros(
             len(tensor_batch["srclen"]),
             tbatchsrc.size(1),
             src_vocab_size,
             device=device,
         )
-        for i, ex in enumerate(minibatch):
+        for i, (ex, indice) in enumerate(minibatch):
             for j, t in enumerate(ex["src_map"]):
                 src_map[i, j, t] = 1
         tensor_batch["src_map"] = src_map

-    if "alignment" in minibatch[0].keys():
+    if "alignment" in minibatch[0][0].keys():
         alignment = torch.zeros(
             len(tensor_batch["srclen"]),
             tbatchtgt.size(1),
             dtype=torch.long,
             device=device,
         )
-        for i, ex in enumerate(minibatch):
+        for i, (ex, indice) in enumerate(minibatch):
             alignment[i, : len(ex["alignment"])] = torch.tensor(
                 ex["alignment"], dtype=torch.long, device=device
             )
         tensor_batch["alignment"] = alignment

-    if "src_ex_vocab" in minibatch[0].keys():
-        tensor_batch["src_ex_vocab"] = [ex["src_ex_vocab"] for ex in minibatch]
+    if "src_ex_vocab" in minibatch[0][0].keys():
+        tensor_batch["src_ex_vocab"] = [ex["src_ex_vocab"] for ex, indice in minibatch]
+
+    tensor_batch["ind_in_bucket"] = torch.tensor(
+        [indice for ex, indice in minibatch], dtype=torch.long, device=device
+    )
+    tensor_batch["cid"] = [ex["cid"] for ex, indice in minibatch]
+    tensor_batch["cid_line_number"] = torch.tensor(
+        [ex["cid_line_number"] for ex, indice in minibatch],
+        dtype=torch.long,
+        device=device,
+    )

     return tensor_batch

@@ -285,11 +301,12 @@ def textbatch_to_tensor(vocabs, batch, device, is_train=False):
     numeric = []

     for i, ex in enumerate(batch):
         # Keep it consistent with dynamic data
         ex["srclen"] = len(ex["src"]["src"].split())
-        ex["indices"] = i
+        ex["ind_in_bucket"] = i
+        ex["cid"] = "text"
+        ex["cid_line_number"] = i
         ex["align"] = None
-        numeric.append(numericalize(vocabs, ex))
-    numeric.sort(key=text_sort_key, reverse=True)
-    infer_iter = [tensorify(vocabs, numeric, device)]
+        numeric.append((numericalize(vocabs, ex), i))
+    infer_iter = [(tensorify(vocabs, numeric, device), 0)]  # force bucket_idx to 0
     return infer_iter
diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index e37e4dcce5..c78831a8c4 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -123,6 +123,7 @@ def load_test_model(opt, device_id=0, model_path=None):
         model_opt.attention_dropout = (
             0.0  # required to force no dropout at inference with flash
         )
+
     model = build_base_model(model_opt, vocabs)

     precision = torch.float32
@@ -162,6 +163,7 @@ def load_test_model(opt, device_id=0, model_path=None):
     for name, module in model.named_modules():
         if hasattr(module, "dropout_p"):
             module.dropout_p = 0.0
+
     return vocabs, model, model_opt
diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py
index 6f83381516..34a933aa3a 100644
--- a/onmt/modules/copy_generator.py
+++ b/onmt/modules/copy_generator.py
@@ -19,7 +19,7 @@ def collapse_copy_scores(
                 src_vocab = batch["src_ex_vocab"][b]
             else:
                 batch_id = batch_offset[b] if batch_offset is not None else b
-                index = batch["indices"].data[batch_id]
+                index = batch["ind_in_bucket"].data[batch_id]
                 src_vocab = src_vocabs[index]

             for i in range(1, len(src_vocab)):
@@ -29,8 +29,8 @@ def collapse_copy_scores(
                     blank.append(offset + i)
                     fill.append(ti)
             if blank:
-                blank = torch.Tensor(blank).type_as(batch["indices"].data)
-                fill = torch.Tensor(fill).type_as(batch["indices"].data)
+                blank = torch.Tensor(blank).type_as(batch["ind_in_bucket"].data)
+                fill = torch.Tensor(fill).type_as(batch["ind_in_bucket"].data)
                 score = scores[:, b] if batch_dim == 1 else scores[b]
                 score.index_add_(1, fill, score.index_select(1, blank))
                 score.index_fill_(1, blank, 1e-10)
diff --git a/onmt/modules/global_attention.py b/onmt/modules/global_attention.py
index 493f49005d..9fb2b27c33 100644
--- a/onmt/modules/global_attention.py
+++ b/onmt/modules/global_attention.py
@@ -169,7 +169,7 @@ def forward(self, src, enc_out, src_len=None, coverage=None):
         align = self.score(src, enc_out)

         if src_len is not None:
-            mask = sequence_mask(src_len, max_len=align.size(-1))
+            mask = ~sequence_mask(src_len, max_len=align.size(-1))
             mask = mask.unsqueeze(1)  # Make it broadcastable.
             align.masked_fill_(~mask, -float("inf"))
diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index 31813cb7e3..de5a6d085c 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -53,7 +53,7 @@ def relative_matmul(x: Tensor, z: Tensor, transpose: bool) -> Tensor:
     https://arxiv.org/pdf/1803.02155.pdf
     x shape [batch_size x heads x q_len x k_len]
     """
-    batch_size, heads, length = x.size()
+    batch_size, heads, length, _ = x.size()
     x_t = x.permute(2, 0, 1, 3)
     x_t_r = x_t.contiguous().view(length, heads * batch_size, -1)
     if transpose:
diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh
index 0bb7b97131..99bfc81680 100755
--- a/onmt/tests/pull_request_chk.sh
+++ b/onmt/tests/pull_request_chk.sh
@@ -176,8 +176,7 @@ ${PYTHON} onmt/bin/train.py \
 [ "$?" -eq 0 ] || error_exit
 echo "Succeeded" | tee -a ${LOG_FILE}

-
-echo -n "  [+] Testing NMT training w/ validation with dynamic scoring and copy ..."
+echo -n "  [+] Testing NMT transformer training w/ validation with dynamic scoring and copy ..."
 ${PYTHON} onmt/bin/train.py \
             -config ${DATA_DIR}/data.yaml \
             -src_vocab $TMP_OUT_DIR/onmt.vocab.src \
@@ -200,6 +199,7 @@ ${PYTHON} onmt/bin/train.py \
             -tensorboard "true" \
             -scoring_debug "true" \
             -copy_attn \
+            -position_encoding \
             -dump_preds $TMP_OUT_DIR/dump_pred \
             -tensorboard_log_dir $TMP_OUT_DIR/logs_dynamic-scoring_and_copy >> ${LOG_FILE} 2>&1

@@ -208,6 +208,96 @@ ${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_a
 echo "Succeeded" | tee -a ${LOG_FILE}
 rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_copy

+echo -n "  [+] Testing NMT transformer training w/ validation with dynamic scoring and maxrelative ..."
+${PYTHON} onmt/bin/train.py \
+            -config ${DATA_DIR}/data.yaml \
+            -src_vocab $TMP_OUT_DIR/onmt.vocab.src \
+            -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \
+            -src_vocab_size 1000 \
+            -tgt_vocab_size 1000 \
+            -encoder_type transformer \
+            -decoder_type transformer \
+            -layers 4 \
+            -word_vec_size 16 \
+            -hidden_size 16 \
+            -num_workers 0 -bucket_size 1024 \
+            -heads 2 \
+            -transformer_ff 64 \
+            -train_steps 10 \
+            -report_every 2 \
+            -valid_steps 5 \
+            -valid_metrics "BLEU" "TER" \
+            -tensorboard "true" \
+            -scoring_debug "true" \
+            -max_relative_positions 8 \
+            -dump_preds $TMP_OUT_DIR/dump_pred \
+            -tensorboard_log_dir $TMP_OUT_DIR/logs_dynamic-scoring_and_relative >> ${LOG_FILE} 2>&1
+
+${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_and_relative -tensorboard_checks valid_metrics
+[ "$?" -eq 0 ] || error_exit
+echo "Succeeded" | tee -a ${LOG_FILE}
+rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_relative
+
+echo -n "  [+] Testing NMT transformer training w/ validation with dynamic scoring and rotary ..."
+${PYTHON} onmt/bin/train.py \
+            -config ${DATA_DIR}/data.yaml \
+            -src_vocab $TMP_OUT_DIR/onmt.vocab.src \
+            -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \
+            -src_vocab_size 1000 \
+            -tgt_vocab_size 1000 \
+            -encoder_type transformer \
+            -decoder_type transformer \
+            -layers 4 \
+            -word_vec_size 16 \
+            -hidden_size 16 \
+            -num_workers 0 -bucket_size 1024 \
+            -heads 2 \
+            -transformer_ff 64 \
+            -train_steps 10 \
+            -report_every 2 \
+            -valid_steps 5 \
+            -valid_metrics "BLEU" "TER" \
+            -tensorboard "true" \
+            -scoring_debug "true" \
+            -max_relative_positions -1 \
+            -dump_preds $TMP_OUT_DIR/dump_pred \
+            -tensorboard_log_dir $TMP_OUT_DIR/logs_dynamic-scoring_and_rotary >> ${LOG_FILE} 2>&1
+
+${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_and_rotary -tensorboard_checks valid_metrics
+[ "$?" -eq 0 ] || error_exit
+echo "Succeeded" | tee -a ${LOG_FILE}
+rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_rotary
+
+echo -n "  [+] Testing NMT transformer training w/ validation with dynamic scoring and alibi ..."
+${PYTHON} onmt/bin/train.py \
+            -config ${DATA_DIR}/data.yaml \
+            -src_vocab $TMP_OUT_DIR/onmt.vocab.src \
+            -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \
+            -src_vocab_size 1000 \
+            -tgt_vocab_size 1000 \
+            -encoder_type transformer \
+            -decoder_type transformer \
+            -layers 4 \
+            -word_vec_size 16 \
+            -hidden_size 16 \
+            -num_workers 0 -bucket_size 1024 \
+            -heads 2 \
+            -transformer_ff 64 \
+            -train_steps 10 \
+            -report_every 2 \
+            -valid_steps 5 \
+            -valid_metrics "BLEU" "TER" \
+            -tensorboard "true" \
+            -scoring_debug "true" \
+            -max_relative_positions -2 \
+            -dump_preds $TMP_OUT_DIR/dump_pred \
+            -tensorboard_log_dir $TMP_OUT_DIR/logs_dynamic-scoring_and_alibi >> ${LOG_FILE} 2>&1
+
+${PYTHON} onmt/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_and_alibi -tensorboard_checks valid_metrics
+[ "$?" -eq 0 ] || error_exit
+echo "Succeeded" | tee -a ${LOG_FILE}
+rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_alibi
+
 echo -n "  [+] Testing LM training..."
 ${PYTHON} onmt/bin/train.py \
             -config ${DATA_DIR}/lm_data.yaml \
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 4cee90470e..a570039707 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -235,7 +235,7 @@ def _accum_batches(self, iterator):
         batches = []
         normalization = 0
         self.accum_count = self._accum_count(self.optim.training_step)
-        for batch in iterator:
+        for batch, bucket_idx in iterator:
             batches.append(batch)
             if self.norm_method == "tokens":
                 num_tokens = (
@@ -388,7 +388,7 @@ def validate(self, valid_iter, moving_average=None):
         with torch.no_grad():
             stats = onmt.utils.Statistics()
             start = time.time()
-            for batch in valid_iter:
+            for batch, bucket_idx in valid_iter:
                 src = batch["src"]
                 src_len = batch["srclen"]
                 tgt = batch["tgt"]
diff --git a/onmt/transforms/docify.py b/onmt/transforms/docify.py
index a146697f93..b6ae01b1d4 100644
--- a/onmt/transforms/docify.py
+++ b/onmt/transforms/docify.py
@@ -68,7 +68,9 @@ def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
         doc = {}
         doc["src"] = []
         doc["tgt"] = []
-        doc["indices"] = 0
+        doc["ind_in_bucket"] = 0
+        doc["cid"] = ""
+        doc["cid_line_number"] = 0

         for ex, _, cid in batch:
             if ex["tgt"] is not None:
@@ -80,7 +82,9 @@ def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
                     doc = {}
                     doc["src"] = []
                     doc["tgt"] = []
-                    doc["indices"] = ex["indices"]
+                    doc["ind_in_bucket"] = ex["ind_in_bucket"]
+                    doc["cid"] = ex["cid"]
+                    doc["cid_line_number"] = ex["cid_line_number"]
                 elif cur_len > self.doc_length:
                     if len(doc["src"]) == 0:
                         # case 1st ex is already longer
@@ -106,7 +110,9 @@ def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
                         doc = {}
                         doc["src"] = []
                         doc["tgt"] = []
-                        doc["indices"] = ex["indices"]
+                        doc["ind_in_bucket"] = ex["ind_in_bucket"]
+                        doc["cid"] = ex["cid"]
+                        doc["cid_line_number"] = ex["cid_line_number"]
                 else:
                     cur_len = len(doc["src"] + ex["src"])

         doc["tgt"] = None
@@ -132,7 +138,9 @@ def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
                     trf_batch.append((doc, self, cid))
                     doc = {}
                     doc["src"] = []
-                    doc["indices"] = ex["indices"]
+                    doc["ind_in_bucket"] = ex["ind_in_bucket"]
+                    doc["cid"] = ex["cid"]
+                    doc["cid_line_number"] = ex["cid_line_number"]
         if len(doc["src"]) > 0:
             trf_batch.append((doc, self, cid))
         return trf_batch
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index b02d6b31cd..f9f13a89d8 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -1,7 +1,6 @@
 import torch
 from onmt.translate import penalties
 from onmt.translate.decode_strategy import DecodeStrategy
-
 import warnings


@@ -184,6 +183,51 @@ def _pick(self, log_probs):

         return topk_scores, topk_ids

+    def beams_non_finished(self, i, predictions, attention, step):
+        b = self._batch_offset[i]
+
+        if any(self.is_finished_list[i]):
+            # Store finished hypotheses for this example in the batch.
+            for j in [
+                k for k, fin in enumerate(self.is_finished_list[i]) if fin
+            ]:  # Beam level: finished beam j in example i of batch
+                if self.ratio > 0:
+                    s = self.topk_scores[i, j] / (step + 1)
+                    self.best_scores[b] = max(s, self.best_scores[b])
+                self.hypotheses[b].append(
+                    (
+                        self.topk_scores[i, j],
+                        predictions[i, j, 1:],  # Ignore start_token.
+                        attention[i, j, :, : self.src_len[b]]
+                        if attention is not None
+                        else None,
+                    )
+                )
+                if len(self.hypotheses[b]) >= 2:
+                    self.hypotheses[b] = sorted(
+                        self.hypotheses[b], key=lambda x: x[0], reverse=True
+                    )
+
+        # End condition is the top beam finished and we can return
+        # n_best hypotheses.
+        if self.ratio > 0:
+            pred_len = self.src_len[b] * self.ratio
+            finish_flag = (
+                (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
+            ) or all(self.is_finished_list[i])
+        else:
+            # early stop when top beam is finished
+            finish_flag = self.is_finished_list[i][0]
+
+        if finish_flag and len(self.hypotheses[b]) >= self.n_best:
+            for score, pred, attn in self.hypotheses[b][: self.n_best]:
+                self.scores[b].append(score)
+                self.predictions[b].append(pred)  # ``(batch, n_best,)``
+                self.attention[b].append(attn if attn is not None else [])
+            return False
+        else:
+            return True
+
     def update_finished(self):
         # Penalize beams that finished.
         _B_old = self.topk_log_probs.shape[0]
@@ -201,51 +245,12 @@ def update_finished(self):
             if self.alive_attn is not None
             else None
         )
-        non_finished_batch = []
-
-        for i in range(len(self.is_finished_list)):  # Batch level
-            b = self._batch_offset[i]
-
-            if any(self.is_finished_list[i]):
-                # Store finished hypotheses for this batch.
-                for j in [
-                    k for k, fin in enumerate(self.is_finished_list[i]) if fin
-                ]:  # Beam level: finished beam j in batch i
-                    if self.ratio > 0:
-                        s = self.topk_scores[i, j] / (step + 1)
-                        self.best_scores[b] = max(s, self.best_scores[b])
-
-                    self.hypotheses[b].append(
-                        (
-                            self.topk_scores[i, j],
-                            predictions[i, j, 1:],  # Ignore start_token.
-                            attention[i, j, :, : self.src_len[b]]
-                            if attention is not None
-                            else None,
-                        )
-                    )
-            # End condition is the top beam finished and we can return
-            # n_best hypotheses.
-            if self.ratio > 0:
-                pred_len = self.src_len[b] * self.ratio
-                finish_flag = (
-                    (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
-                ) or all(self.is_finished_list[i])
-            else:
-                # early stop when top beam is finished
-                finish_flag = self.is_finished_list[i][0]
-
-            if finish_flag and len(self.hypotheses[b]) >= self.beam_size:
-                best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)[
-                    : self.n_best
-                ]
-                for score, pred, attn in best_hyp:
-                    self.scores[b].append(score)
-                    self.predictions[b].append(pred)  # ``(batch, n_best,)``
-                    self.attention[b].append(attn if attn is not None else [])
-            else:
-                non_finished_batch.append(i)
+        non_finished_batch = [
+            i
+            for i in range(len(self.is_finished_list))
+            if self.beams_non_finished(i, predictions, attention, step)
+        ]

         non_finished = torch.tensor(non_finished_batch)
         # If all sentences are translated, no need to go further.
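The iterator changes above (onmt/inputters/dynamic_iterator.py, onmt/trainer.py) establish one contract used throughout the rest of this diff: every dataset iterator now yields (tensor_batch, bucket_idx) tuples, and each batch carries "ind_in_bucket", "cid" and "cid_line_number" entries so predictions can be restored to corpus order per bucket (see onmt/translate/translator.py below). A minimal consumer sketch under that contract; the iterator construction itself is elided and the function name is illustrative:

def consume(infer_iter):
    # Group batches by the bucket they came from; within a bucket,
    # tensor_batch["ind_in_bucket"] recovers the original example order.
    per_bucket = {}
    for tensor_batch, bucket_idx in infer_iter:
        per_bucket.setdefault(bucket_idx, []).append(tensor_batch)
    return per_bucket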
diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py
index 8341946a3c..ab462041f4 100644
--- a/onmt/translate/translation.py
+++ b/onmt/translate/translation.py
@@ -40,7 +40,7 @@ def _build_target_tokens(self, src, srclen, pred, attn, voc, dyn_voc):
             voc[tok]
             if tok < len(voc)
             else dyn_voc.ids_to_tokens[tok - len(self.vocabs["src"].ids_to_tokens)]
-            for tok in pred
+            for tok in pred.tolist()
         ]
         if tokens[-1] == DefaultTokens.EOS:
             tokens = tokens[:-1]
@@ -73,7 +73,7 @@ def from_batch(self, translation_batch):
             translation_batch["attention"],
             translation_batch["alignment"],
             translation_batch["gold_score"],
-            batch["indices"],
+            batch["ind_in_bucket"],
         )

         if not any(align):  # when align is a empty nested list
@@ -159,7 +159,7 @@ class Translation(object):
         "gold_sent",
         "gold_score",
         "word_aligns",
-        "indices",
+        "ind_in_bucket",
     ]

     def __init__(
@@ -172,7 +172,7 @@ def __init__(
         tgt_sent,
         gold_score,
         word_aligns,
-        indices,
+        ind_in_bucket,
     ):
         self.src = src
         self.srclen = srclen
@@ -182,7 +182,7 @@ def __init__(
         self.gold_sent = tgt_sent
         self.gold_score = gold_score
         self.word_aligns = word_aligns
-        self.indices = indices
+        self.ind_in_bucket = ind_in_bucket

     def log(self, sent_number, src_raw=""):
         """
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index d0b6d9344a..41c607bf34 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -346,7 +346,7 @@ def _translate(
         def _maybe_retranslate(translations, batch):
             """Here we handle the cases of mismatch in number of segments
             between source and target. We re-translate seg by seg."""
-            inds, perm = torch.sort(batch["indices"])
+            inds, perm = torch.sort(batch["ind_in_bucket"])
             trans_copy = deepcopy(translations)
             inserted_so_far = 0
             for j, trans in enumerate(translations):
@@ -382,7 +382,7 @@ def _maybe_retranslate(translations, batch):
                     t_sub_batch = {
                         "src": t_sub_src.to(device),
                         "srclen": t_sub_src_len.to(device),
-                        "indices": t_sub_src_ind.to(device),
+                        "ind_in_bucket": t_sub_src_ind.to(device),
                     }
                     # new sub-batch ready to be translated
                     sub_data = self.translate_batch(t_sub_batch, attn_debug)
@@ -395,36 +395,29 @@ def _maybe_retranslate(translations, batch):
                 inserted_so_far += len(sub_src) - 1
             return trans_copy

-        for batch in infer_iter:
-            batch_data = self.translate_batch(batch, attn_debug)
-
-            translations = xlation_builder.from_batch(batch_data)
-            if (
-                not isinstance(self, GeneratorLM)
-                and self._tgt_sep_idx != self._tgt_unk_idx
-                and (batch["src"] == self._tgt_sep_idx).any().item()
-            ):
-                # For seq2seq when we need to force doc to spit the same number of sents
-                translations = _maybe_retranslate(translations, batch)
-
+        def _process_bucket(bucket_translations):
+            bucket_scores = []
+            bucket_predictions = []
+            bucket_score = 0
+            bucket_words = 0
+            bucket_gold_score = 0
+            bucket_gold_words = 0
             voc_src = self.vocabs["src"].ids_to_tokens
-
-            for j, trans in enumerate(translations):
-                all_scores += [trans.pred_scores[: self.n_best]]
-                pred_score_total += trans.pred_scores[0]
-                pred_words_total += len(trans.pred_sents[0])
+            bucket_translations = sorted(
+                bucket_translations, key=lambda x: x.ind_in_bucket
+            )
+            for trans in bucket_translations:
+                bucket_scores += [trans.pred_scores[: self.n_best]]
+                bucket_score += trans.pred_scores[0]
+                bucket_words += len(trans.pred_sents[0])
                 if "tgt" in batch.keys():
-                    gold_score_total += trans.gold_score
-                    gold_words_total += len(trans.gold_sent) + 1
+                    bucket_gold_score += trans.gold_score
+                    bucket_gold_words += len(trans.gold_sent) + 1
                 n_best_preds = [
                     " ".join(pred) for pred in trans.pred_sents[: self.n_best]
                 ]
-                n_best_scores = [
-                    score.item() for score in trans.pred_scores[: self.n_best]
-                ]
-
                 if self.report_align:
                     align_pharaohs = [
                         build_align_pharaoh(align)
@@ -441,14 +434,16 @@ def _maybe_retranslate(translations, batch):
                 if transform_pipe is not None:
                     n_best_preds = transform_pipe.batch_apply_reverse(n_best_preds)

-                all_predictions += [n_best_preds]
-
-                out_all = [
-                    pred + "\t" + str(score)
-                    for (pred, score) in zip(n_best_preds, n_best_scores)
-                ]
+                bucket_predictions += [n_best_preds]

                 if self.with_score:
+                    n_best_scores = [
+                        score.item() for score in trans.pred_scores[: self.n_best]
+                    ]
+                    out_all = [
+                        pred + "\t" + str(score)
+                        for (pred, score) in zip(n_best_preds, n_best_scores)
+                    ]
                     self.out_file.write("\n".join(out_all) + "\n")
                 else:
                     self.out_file.write("\n".join(n_best_preds) + "\n")
@@ -496,6 +491,72 @@ def _maybe_retranslate(translations, batch):
                     self.logger.info(output)
                 else:
                     os.write(1, output.encode("utf-8"))
+            return (
+                bucket_scores,
+                bucket_predictions,
+                bucket_score,
+                bucket_words,
+                bucket_gold_score,
+                bucket_gold_words,
+            )
+
+        bucket_translations = []
+        prev_idx = 0
+
+        for batch, bucket_idx in infer_iter:
+
+            batch_data = self.translate_batch(batch, attn_debug)
+
+            translations = xlation_builder.from_batch(batch_data)
+            if (
+                not isinstance(self, GeneratorLM)
+                and self._tgt_sep_idx != self._tgt_unk_idx
+                and (batch["src"] == self._tgt_sep_idx).any().item()
+            ):
+                # For seq2seq when we need to force doc to spit the same number of sents
+                translations = _maybe_retranslate(translations, batch)
+
+            bucket_translations += translations
+
+            if (
+                not isinstance(infer_iter, list)
+                and len(bucket_translations) >= infer_iter.bucket_size
+            ):
+                bucket_idx += 1
+
+            if bucket_idx != prev_idx:
+                prev_idx = bucket_idx
+                (
+                    bucket_scores,
+                    bucket_predictions,
+                    bucket_score,
+                    bucket_words,
+                    bucket_gold_score,
+                    bucket_gold_words,
+                ) = _process_bucket(bucket_translations)
+                all_scores += bucket_scores
+                all_predictions += bucket_predictions
+                pred_score_total += bucket_score
+                pred_words_total += bucket_words
+                gold_score_total += bucket_gold_score
+                gold_words_total += bucket_gold_words
+                bucket_translations = []
+
+        if len(bucket_translations) > 0:
+            (
+                bucket_scores,
+                bucket_predictions,
+                bucket_score,
+                bucket_words,
+                bucket_gold_score,
+                bucket_gold_words,
+            ) = _process_bucket(bucket_translations)
+            all_scores += bucket_scores
+            all_predictions += bucket_predictions
+            pred_score_total += bucket_score
+            pred_words_total += bucket_words
+            gold_score_total += bucket_gold_score
+            gold_words_total += bucket_gold_words

         end_time = time.time()
diff --git a/onmt/utils/misc.py b/onmt/utils/misc.py
index 5e3fa390fb..2b404c4998 100644
--- a/onmt/utils/misc.py
+++ b/onmt/utils/misc.py
@@ -54,14 +54,8 @@ def sequence_mask(lengths, max_len=None):
     """
     Creates a boolean mask from sequence lengths.
""" - batch_size = lengths.numel() max_len = max_len or lengths.max() - return ( - torch.arange(0, max_len, device=lengths.device) - .type_as(lengths) - .repeat(batch_size, 1) - .lt(lengths.unsqueeze(1)) - ) + return torch.arange(0, max_len, device=lengths.device) >= lengths.unsqueeze(1) def tile(x, count, dim=0): diff --git a/tools/LM_scoring.py b/tools/LM_scoring.py index 10acd24eee..823571fa0e 100644 --- a/tools/LM_scoring.py +++ b/tools/LM_scoring.py @@ -102,7 +102,7 @@ def main(): cumul_length += batch["tgt"][:, 1:, 0].ne(padding_idx).sum().cpu() # Now we need to rearrange the batch of ppl # in the original order with indices - sent_ppl_orig = ppl.gather(0, batch["indices"].argsort(0)) + sent_ppl_orig = ppl.gather(0, batch["cid_line_number"].argsort(0)) for j in range(batch_size): ppl_file.write(str(sent_ppl_orig[j].item()) + "\n") logger.info(