Skip to content

Commit

Permalink
clean up opts
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Jan 3, 2024
1 parent 8bbd18a commit 46c1397
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 87 deletions.
3 changes: 1 addition & 2 deletions docs/source/examples/replicate_vicuna/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def prune_history(user_messages_sizes, bot_messages_sizes, max_history_size):

def _get_parser():
parser = ArgumentParser(description="chatbot.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
opts.model_opts(parser)
return parser

Expand Down
3 changes: 1 addition & 2 deletions docs/source/examples/replicate_vicuna/simple_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@

def _get_parser():
parser = ArgumentParser(description="simple_inference_engine_py.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
opts.model_opts(parser)
return parser

Expand Down
4 changes: 1 addition & 3 deletions eval_llm/MMLU-FR/run_mmlu_opennmt_fr.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ def evaluate(opt):

def _get_parser():
parser = ArgumentParser(description="run_mmlu_opennmt_fr.py")

opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down
4 changes: 1 addition & 3 deletions eval_llm/MMLU/run_mmlu_opennmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,7 @@ def evaluate(opt):

def _get_parser():
parser = ArgumentParser(description="run_mmlu_opennmt.py")

opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down
4 changes: 2 additions & 2 deletions onmt/bin/build_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import set_random_seed, check_path
from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.opts import data_prepare_opts
from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
from onmt.inputters.text_utils import process, append_features_to_text
from onmt.transforms import make_transforms, get_transforms_cls
Expand Down Expand Up @@ -273,7 +273,7 @@ def save_counter(counter, save_path):

def _get_parser():
parser = ArgumentParser(description="build_vocab.py")
dynamic_prepare_opts(parser, build_vocab_only=True)
data_prepare_opts(parser, build_vocab_only=True)
return parser


Expand Down
4 changes: 2 additions & 2 deletions onmt/bin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _get_parser():
parser.add_argument("--url_root", type=str, default="/translator")
parser.add_argument("--debug", "-d", action="store_true")
parser.add_argument(
"--config", "-c", type=str, default="./available_models/conf.json"
"--model_config", "-m", type=str, default="./available_models/conf.json"
)
return parser

Expand All @@ -155,7 +155,7 @@ def main():
parser = _get_parser()
args = parser.parse_args()
start(
args.config,
args.model_config,
url_root=args.url_root,
host=args.ip,
port=args.port,
Expand Down
5 changes: 2 additions & 3 deletions onmt/bin/translate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from onmt.inference_engine import InferenceEnginePY
from onmt.opts import config_opts, translate_opts
from onmt.opts import translate_opts
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
from torch.profiler import profile, record_function, ProfilerActivity
Expand All @@ -23,8 +23,7 @@ def translate(opt):

def _get_parser():
parser = ArgumentParser(description="translate.py")
config_opts(parser)
translate_opts(parser, dynamic=True)
translate_opts(parser)
return parser


Expand Down
115 changes: 54 additions & 61 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def _add_reproducibility_opts(parser):
)


def _add_dynamic_corpus_opts(parser, build_vocab_only=False):
"""Options related to training corpus, type: a list of dictionary."""
def _add_dataset_opts(parser, build_vocab_only=False):
"""Options related to training datasets, type: a list of dictionary."""
group = parser.add_argument_group("Data")
group.add(
"-data",
Expand Down Expand Up @@ -278,7 +278,7 @@ def _add_features_opts(parser):
)


def _add_dynamic_vocab_opts(parser, build_vocab_only=False):
def _add_vocab_opts(parser, build_vocab_only=False):
"""Options related to vocabulary and features.
Add all options relate to vocabulary or features to parser.
Expand Down Expand Up @@ -412,7 +412,7 @@ def _add_dynamic_vocab_opts(parser, build_vocab_only=False):
)


def _add_dynamic_transform_opts(parser):
def _add_transform_opts(parser):
"""Options related to transforms.
Options that specified in the definitions of each transform class
Expand All @@ -422,17 +422,17 @@ def _add_dynamic_transform_opts(parser):
transform_cls.add_options(parser)


def dynamic_prepare_opts(parser, build_vocab_only=False):
def data_prepare_opts(parser, build_vocab_only=False):
"""Options related to data prepare in dynamic mode.
Add all dynamic data prepare related options to parser.
If `build_vocab_only` set to True, then only contains options that
will be used in `onmt/bin/build_vocab.py`.
"""
config_opts(parser)
_add_dynamic_corpus_opts(parser, build_vocab_only=build_vocab_only)
_add_dynamic_vocab_opts(parser, build_vocab_only=build_vocab_only)
_add_dynamic_transform_opts(parser)
_add_dataset_opts(parser, build_vocab_only=build_vocab_only)
_add_vocab_opts(parser, build_vocab_only=build_vocab_only)
_add_transform_opts(parser)

if build_vocab_only:
_add_reproducibility_opts(parser)
Expand Down Expand Up @@ -1125,6 +1125,39 @@ def _add_train_general_opts(parser):
help="Type of the source input. " "Options are [text].",
)

group.add(
"-bucket_size",
"--bucket_size",
type=int,
default=262144,
help="""A bucket is a buffer of bucket_size examples to pick
from the various Corpora. The dynamic iterator batches
batch_size batchs from the bucket and shuffle them.""",
)
group.add(
"-bucket_size_init",
"--bucket_size_init",
type=int,
default=-1,
help="""The bucket is initalized with this awith this
amount of examples (optional)""",
)
group.add(
"-bucket_size_increment",
"--bucket_size_increment",
type=int,
default=0,
help="""The bucket size is incremented with this
amount of examples (optional)""",
)
group.add(
"-prefetch_factor",
"--prefetch_factor",
type=int,
default=200,
help="""number of mini-batches loaded in advance to avoid the
GPU waiting during the refilling of the bucket.""",
)
group.add(
"--save_model",
"-save_model",
Expand Down Expand Up @@ -1541,43 +1574,6 @@ def _add_train_general_opts(parser):
_add_logging_opts(parser, is_train=True)


def _add_train_dynamic_data(parser):
group = parser.add_argument_group("Dynamic data")
group.add(
"-bucket_size",
"--bucket_size",
type=int,
default=262144,
help="""A bucket is a buffer of bucket_size examples to pick
from the various Corpora. The dynamic iterator batches
batch_size batchs from the bucket and shuffle them.""",
)
group.add(
"-bucket_size_init",
"--bucket_size_init",
type=int,
default=-1,
help="""The bucket is initalized with this awith this
amount of examples (optional)""",
)
group.add(
"-bucket_size_increment",
"--bucket_size_increment",
type=int,
default=0,
help="""The bucket size is incremented with this
amount of examples (optional)""",
)
group.add(
"-prefetch_factor",
"--prefetch_factor",
type=int,
default=200,
help="""number of mini-batches loaded in advance to avoid the
GPU waiting during the refilling of the bucket.""",
)


def _add_quant_opts(parser):
group = parser.add_argument_group("Quant options")
group.add(
Expand Down Expand Up @@ -1624,13 +1620,10 @@ def _add_quant_opts(parser):

def train_opts(parser):
"""All options used in train."""
# options relate to data preprare
dynamic_prepare_opts(parser, build_vocab_only=False)
data_prepare_opts(parser, build_vocab_only=False)
distributed_opts(parser)
# options relate to train
model_opts(parser)
_add_train_general_opts(parser)
_add_train_dynamic_data(parser)
_add_quant_opts(parser)


Expand Down Expand Up @@ -1796,8 +1789,9 @@ def _add_decoding_opts(parser):
)


def translate_opts(parser, dynamic=False):
def translate_opts(parser):
"""Translation / inference options"""
config_opts(parser)
group = parser.add_argument_group("Model")
group.add(
"--model",
Expand Down Expand Up @@ -1929,18 +1923,17 @@ def translate_opts(parser, dynamic=False):
)
group.add("--gpu", "-gpu", type=int, default=-1, help="Device to run on")

if dynamic:
group.add(
"-transforms",
"--transforms",
default=[],
nargs="+",
choices=AVAILABLE_TRANSFORMS.keys(),
help="Default transform pipeline to apply to data.",
)
group.add(
"-transforms",
"--transforms",
default=[],
nargs="+",
choices=AVAILABLE_TRANSFORMS.keys(),
help="Default transform pipeline to apply to data.",
)

# Adding options related to Transforms
_add_dynamic_transform_opts(parser)
# Adding options related to Transforms
_add_transform_opts(parser)

_add_quant_opts(parser)

Expand Down
4 changes: 2 additions & 2 deletions onmt/tests/test_data_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os

from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.opts import data_prepare_opts
from onmt.train_single import prepare_transforms_vocabs
from onmt.constants import CorpusName

Expand All @@ -17,7 +17,7 @@

def get_default_opts():
parser = ArgumentParser(description="data sample prepare")
dynamic_prepare_opts(parser)
data_prepare_opts(parser)

default_opts = [
"-config",
Expand Down
3 changes: 1 addition & 2 deletions onmt/tests/test_inference_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

def _get_parser():
parser = ArgumentParser(description="simple_inference_engine_py.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down
3 changes: 1 addition & 2 deletions onmt/utils/scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from onmt.utils.parse import ArgumentParser
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.opts import config_opts, translate_opts
from onmt.opts import translate_opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
Expand Down Expand Up @@ -51,7 +51,6 @@ def translate(self, model, gpu_rank, step):

# Set "default" translation options on empty cfgfile
parser = ArgumentParser()
config_opts(parser)
translate_opts(parser)
base_args = ["-model", "dummy"] + ["-src", "dummy"]
opt = parser.parse_args(base_args)
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"onmt_server=onmt.bin.server:main",
"onmt_train=onmt.bin.train:main",
"onmt_translate=onmt.bin.translate:main",
"onmt_translate_dynamic=onmt.bin.translate_dynamic:main",
"onmt_release_model=onmt.bin.release_model:main",
"onmt_average_models=onmt.bin.average_models:main",
"onmt_build_vocab=onmt.bin.build_vocab:main",
Expand Down
3 changes: 1 addition & 2 deletions tools/LM_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@

def _get_parser():
parser = ArgumentParser(description="LM_scoring.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down

0 comments on commit 46c1397

Please sign in to comment.