diff --git a/onmt/transforms/transform.py b/onmt/transforms/transform.py index e111b5869d..4784b354a2 100644 --- a/onmt/transforms/transform.py +++ b/onmt/transforms/transform.py @@ -264,13 +264,16 @@ def _repr_args(self): def make_transforms(opts, transforms_cls, vocabs): """Build transforms in `transforms_cls` with vocab of `fields`.""" transforms = {} - for name, transform_cls in transforms_cls.items(): - if transform_cls.require_vocab() and vocabs is None: - logger.warning(f"{transform_cls.__name__} require vocab to apply, skip it.") - continue - transform_obj = transform_cls(opts) - transform_obj.warm_up(vocabs) - transforms[name] = transform_obj + if transforms_cls: + for name, transform_cls in transforms_cls.items(): + if transform_cls.require_vocab() and vocabs is None: + logger.warning( + f"{transform_cls.__name__} require vocab to apply, skip it." + ) + continue + transform_obj = transform_cls(opts) + transform_obj.warm_up(vocabs) + transforms[name] = transform_obj return transforms diff --git a/onmt/utils/scoring_utils.py b/onmt/utils/scoring_utils.py index 7f407bd90f..7dc5c7163a 100644 --- a/onmt/utils/scoring_utils.py +++ b/onmt/utils/scoring_utils.py @@ -5,7 +5,7 @@ from onmt.opts import translate_opts from onmt.constants import CorpusTask from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter -from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe +from onmt.transforms import get_transforms_cls class ScoringPreparator: @@ -19,6 +19,8 @@ def __init__(self, vocabs, opt): if self.opt.dump_preds is not None: if not os.path.exists(self.opt.dump_preds): os.makedirs(self.opt.dump_preds) + self.transforms = None + self.transforms_cls = None def warm_up(self, transforms): self.transforms = transforms @@ -78,7 +80,7 @@ def translate(self, model, gpu_rank, step): # Reinstantiate the validation iterator - #transforms_cls = get_transforms_cls(model_opt._all_transform) + # transforms_cls = get_transforms_cls(model_opt._all_transform) model_opt.num_workers = 0 model_opt.tgt = None