From 2354c3059d08c2e093d77f585ff938b08c7c6908 Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 12 Mar 2024 16:48:53 +0100 Subject: [PATCH] fix valid transform at scoring and tokenize with onmt_tokenize when docify --- onmt/tests/test_transform.py | 8 ++++---- onmt/trainer.py | 2 +- onmt/transforms/tokenize.py | 11 ++++++++++- onmt/utils/scoring_utils.py | 15 ++++----------- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 4eef58d15f..78d1554c77 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -327,12 +327,12 @@ def test_pyonmttok_bpe(self): "struc■", "tion■", ":", - "\n■", + "⦅newline⦆■", "in■", "struc■", "tion■", - "\n■", - "\n■", + "⦅newline⦆■", + "⦅newline⦆■", "#■", "#■", "#", @@ -342,7 +342,7 @@ def test_pyonmttok_bpe(self): "on■", "se", ":", - "\n", + "⦅newline⦆", "", "respon■", "se", diff --git a/onmt/trainer.py b/onmt/trainer.py index d580b19fd1..1b30d0e729 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -297,7 +297,7 @@ def train( logger.info( "Start training loop and validate every %d steps...", valid_steps ) - logger.info("Scoring with: {}".format(self.scoring_preparator.transform)) + logger.info("Scoring with: {}".format(self.scoring_preparator.transforms)) total_stats = onmt.utils.Statistics() report_stats = onmt.utils.Statistics() diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py index 17424a94bb..e0162774b4 100644 --- a/onmt/transforms/tokenize.py +++ b/onmt/transforms/tokenize.py @@ -157,7 +157,7 @@ def _tokenize(self, tokens, side="src", is_train=False): """Tokenize a list of words.""" # This method embeds a custom logic to correctly handle certain placeholders # in case the tokenizer doesn't preserve them. - sentence = " ".join(tokens).replace(DefaultTokens.SEP, "\n") + sentence = " ".join(tokens) # Locate the end-of-sentence placeholders. sent_list = sentence.split(DefaultTokens.EOS) # Tokenize each sentence separately. @@ -257,6 +257,7 @@ def tokenize_string(self, string, side="src", is_train=False): """Apply subword sampling or deterministic subwording""" sp_model = self.load_models[side] nbest_size = self.tgt_subword_nbest if side == "tgt" else self.src_subword_nbest + string = string.replace(DefaultTokens.SEP, "\n") if is_train is False or nbest_size in [0, 1]: # derterministic subwording tokens = sp_model.encode(string, out_type=str) @@ -441,6 +442,9 @@ def _parse_opts(self): self.src_other_kwargs = self.opts.src_onmttok_kwargs self.tgt_other_kwargs = self.opts.tgt_onmttok_kwargs self.gpt2_pretok = self.opts.gpt2_pretok + self.preserve_placeholders = self.opts.tgt_onmttok_kwargs.get( + "preserve_placeholders", False + ) @classmethod def get_specials(cls, opts): @@ -558,6 +562,11 @@ def tokenize_string(self, sentence, side="src", is_train=False): segmented.extend(["Ċ", "Ċ"]) else: segmented.append(s) + elif ( + self.src_subword_type == "sentencepiece" and not self.preserve_placeholders + ): + sentence = sentence.replace(DefaultTokens.SEP, "\n") + segmented = tokenizer(sentence) else: segmented = tokenizer(sentence) return segmented diff --git a/onmt/utils/scoring_utils.py b/onmt/utils/scoring_utils.py index d0b5f7f8e5..7f407bd90f 100644 --- a/onmt/utils/scoring_utils.py +++ b/onmt/utils/scoring_utils.py @@ -19,16 +19,10 @@ def __init__(self, vocabs, opt): if self.opt.dump_preds is not None: if not os.path.exists(self.opt.dump_preds): os.makedirs(self.opt.dump_preds) - self.transforms = opt.transforms - transforms_cls = get_transforms_cls(self.transforms) - transforms = make_transforms(self.opt, transforms_cls, self.vocabs) - self.transform = TransformPipe.build_from(transforms.values()) def warm_up(self, transforms): self.transforms = transforms - transforms_cls = get_transforms_cls(self.transforms) - transforms = make_transforms(self.opt, transforms_cls, self.vocabs) - self.transform = TransformPipe.build_from(transforms.values()) + self.transforms_cls = get_transforms_cls(transforms) def translate(self, model, gpu_rank, step): """Compute and save the sentences predicted by the @@ -84,7 +78,7 @@ def translate(self, model, gpu_rank, step): # Reinstantiate the validation iterator - transforms_cls = get_transforms_cls(model_opt._all_transform) + #transforms_cls = get_transforms_cls(model_opt._all_transform) model_opt.num_workers = 0 model_opt.tgt = None @@ -100,7 +94,7 @@ def translate(self, model, gpu_rank, step): valid_iter = build_dynamic_dataset_iter( model_opt, - transforms_cls, + self.transforms_cls, translator.vocabs, task=CorpusTask.VALID, tgt="", # This force to clear the target side (needed when using tgt_file_prefix) @@ -125,12 +119,11 @@ def translate(self, model, gpu_rank, step): # Flatten predictions preds = [x.lstrip() for sublist in preds for x in sublist] - # Save results if len(preds) > 0 and self.opt.scoring_debug: path = os.path.join(self.opt.dump_preds, f"preds.valid_step_{step}.txt") with open(path, "a") as file: - for i in range(len(preds)): + for i in range(len(raw_srcs)): file.write("SOURCE: {}\n".format(raw_srcs[i])) file.write("REF: {}\n".format(raw_refs[i])) file.write("PRED: {}\n\n".format(preds[i]))