Commit 2354c30

fix valid transform at scoring and tokenize with onmt_tokenize when docify
vince62s committed Mar 12, 2024
1 parent 64afa4c commit 2354c30
Showing 4 changed files with 19 additions and 17 deletions.
8 changes: 4 additions & 4 deletions onmt/tests/test_transform.py
@@ -327,12 +327,12 @@ def test_pyonmttok_bpe(self):
                 "struc■",
                 "tion■",
                 ":",
-                "\n■",
+                "⦅newline⦆■",
                 "in■",
                 "struc■",
                 "tion■",
-                "\n■",
-                "\n■",
+                "⦅newline⦆■",
+                "⦅newline⦆■",
                 "#■",
                 "#■",
                 "#",
@@ -342,7 +342,7 @@
                 "on■",
                 "se",
                 ":",
-                "\n",
+                "⦅newline⦆",
                 "<blank>",
                 "respon■",
                 "se",
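
Why the expected tokens changed: the docify transform encodes line breaks as the placeholder ⦅newline⦆ (DefaultTokens.SEP), and with this commit the onmt_tokenize path no longer rewrites that placeholder to a literal "\n" before subword tokenization, so it now appears (joiner-annotated) in the expected BPE output. A minimal sketch of the invariant in plain Python; SEP's value follows onmt.constants, and the sample tokens are illustrative, not the actual test data:

    # The placeholder inserted by docify for line breaks.
    SEP = "⦅newline⦆"  # DefaultTokens.SEP

    tokens = ["instruction:", SEP, "instruction", SEP, SEP, "### Response:"]
    sentence = " ".join(tokens)

    # Previously _tokenize() did " ".join(tokens).replace(SEP, "\n"),
    # destroying the placeholder before BPE; now it survives intact.
    assert SEP in sentence and "\n" not in sentence
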
2 changes: 1 addition & 1 deletion onmt/trainer.py
@@ -297,7 +297,7 @@ def train(
         logger.info(
             "Start training loop and validate every %d steps...", valid_steps
         )
-        logger.info("Scoring with: {}".format(self.scoring_preparator.transform))
+        logger.info("Scoring with: {}".format(self.scoring_preparator.transforms))

         total_stats = onmt.utils.Statistics()
         report_stats = onmt.utils.Statistics()
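
This one-line rename follows from the scoring_utils.py refactor below: ScoringPreparator no longer builds a TransformPipe up front, so its old "transform" attribute is gone and the trainer now logs the list of transform names ("transforms") instead.
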
11 changes: 10 additions & 1 deletion onmt/transforms/tokenize.py
@@ -157,7 +157,7 @@ def _tokenize(self, tokens, side="src", is_train=False):
         """Tokenize a list of words."""
         # This method embeds a custom logic to correctly handle certain placeholders
         # in case the tokenizer doesn't preserve them.
-        sentence = " ".join(tokens).replace(DefaultTokens.SEP, "\n")
+        sentence = " ".join(tokens)
         # Locate the end-of-sentence placeholders.
         sent_list = sentence.split(DefaultTokens.EOS)
         # Tokenize each sentence separately.
@@ -257,6 +257,7 @@ def tokenize_string(self, string, side="src", is_train=False):
         """Apply subword sampling or deterministic subwording"""
         sp_model = self.load_models[side]
         nbest_size = self.tgt_subword_nbest if side == "tgt" else self.src_subword_nbest
+        string = string.replace(DefaultTokens.SEP, "\n")
         if is_train is False or nbest_size in [0, 1]:
             # derterministic subwording
             tokens = sp_model.encode(string, out_type=str)
@@ -441,6 +442,9 @@ def _parse_opts(self):
         self.src_other_kwargs = self.opts.src_onmttok_kwargs
         self.tgt_other_kwargs = self.opts.tgt_onmttok_kwargs
         self.gpt2_pretok = self.opts.gpt2_pretok
+        self.preserve_placeholders = self.opts.tgt_onmttok_kwargs.get(
+            "preserve_placeholders", False
+        )

     @classmethod
     def get_specials(cls, opts):
@@ -558,6 +562,11 @@ def tokenize_string(self, sentence, side="src", is_train=False):
                     segmented.extend(["Ċ", "Ċ"])
                 else:
                     segmented.append(s)
+        elif (
+            self.src_subword_type == "sentencepiece" and not self.preserve_placeholders
+        ):
+            sentence = sentence.replace(DefaultTokens.SEP, "\n")
+            segmented = tokenizer(sentence)
         else:
             segmented = tokenizer(sentence)
         return segmented
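
Taken together, these hunks move the placeholder handling out of the shared _tokenize() path: the SEP -> "\n" rewrite now happens inside each tokenize_string(), unconditionally for SentencePiece and, for the OpenNMT Tokenizer, only when preserve_placeholders is off. A simplified sketch of the resulting branch as a standalone function under assumed names; "tokenizer" stands in for the wrapped pyonmttok callable, and the gpt2_pretok branch is omitted:

    SEP = "⦅newline⦆"  # DefaultTokens.SEP

    def onmt_tokenize_string(sentence, src_subword_type, preserve_placeholders, tokenizer):
        if src_subword_type == "sentencepiece" and not preserve_placeholders:
            # SentencePiece cannot protect the placeholder itself,
            # so fall back to a literal newline before segmenting.
            sentence = sentence.replace(SEP, "\n")
        return tokenizer(sentence)
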
15 changes: 4 additions & 11 deletions onmt/utils/scoring_utils.py
@@ -19,16 +19,10 @@ def __init__(self, vocabs, opt):
         if self.opt.dump_preds is not None:
             if not os.path.exists(self.opt.dump_preds):
                 os.makedirs(self.opt.dump_preds)
         self.transforms = opt.transforms
-        transforms_cls = get_transforms_cls(self.transforms)
-        transforms = make_transforms(self.opt, transforms_cls, self.vocabs)
-        self.transform = TransformPipe.build_from(transforms.values())

     def warm_up(self, transforms):
-        self.transforms = transforms
-        transforms_cls = get_transforms_cls(self.transforms)
-        transforms = make_transforms(self.opt, transforms_cls, self.vocabs)
-        self.transform = TransformPipe.build_from(transforms.values())
+        self.transforms_cls = get_transforms_cls(transforms)

     def translate(self, model, gpu_rank, step):
         """Compute and save the sentences predicted by the
@@ -84,7 +78,7 @@ def translate(self, model, gpu_rank, step):

         # Reinstantiate the validation iterator

-        transforms_cls = get_transforms_cls(model_opt._all_transform)
+        #transforms_cls = get_transforms_cls(model_opt._all_transform)
         model_opt.num_workers = 0
         model_opt.tgt = None
@@ -100,7 +94,7 @@

         valid_iter = build_dynamic_dataset_iter(
             model_opt,
-            transforms_cls,
+            self.transforms_cls,
             translator.vocabs,
             task=CorpusTask.VALID,
             tgt="",  # This force to clear the target side (needed when using tgt_file_prefix)
@@ -125,12 +119,11 @@

         # Flatten predictions
         preds = [x.lstrip() for sublist in preds for x in sublist]
-
         # Save results
         if len(preds) > 0 and self.opt.scoring_debug:
             path = os.path.join(self.opt.dump_preds, f"preds.valid_step_{step}.txt")
             with open(path, "a") as file:
-                for i in range(len(preds)):
+                for i in range(len(raw_srcs)):
                     file.write("SOURCE: {}\n".format(raw_srcs[i]))
                     file.write("REF: {}\n".format(raw_refs[i]))
                     file.write("PRED: {}\n\n".format(preds[i]))
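
Net effect of this file: ScoringPreparator no longer instantiates a TransformPipe eagerly. warm_up() only resolves the classes of the transforms actually used for validation, translate() hands them to build_dynamic_dataset_iter(), which builds them with the live options, and the debug dump now iterates over the sources so each SOURCE/REF/PRED triple stays aligned. A reduced sketch of the new lifecycle, keeping only the lines this diff touches (the import path follows OpenNMT-py as of this commit; error handling and the rest of translate() are omitted):

    from onmt.transforms import get_transforms_cls

    class ScoringPreparator:
        def __init__(self, vocabs, opt):
            self.vocabs = vocabs
            self.opt = opt
            self.transforms = opt.transforms  # transform names only, nothing built yet

        def warm_up(self, transforms):
            # Resolve the transform classes once; build_dynamic_dataset_iter()
            # instantiates them later, inside translate().
            self.transforms_cls = get_transforms_cls(transforms)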
