Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add use_underscores option #76

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions tests/helptext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
import tyro._strings


def get_helptext(f: Callable, args: List[str] = ["--help"]) -> str:
def get_helptext(
f: Callable, args: List[str] = ["--help"], use_underscores: bool = False
) -> str:
target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(f, args=args)
tyro.cli(f, args=args, use_underscores=use_underscores)

# Check tyro.extras.get_parser().
parser = tyro.extras.get_parser(f)
parser = tyro.extras.get_parser(f, use_underscores=use_underscores)
assert isinstance(parser, argparse.ArgumentParser)

# Returned parser should have formatting information stripped. External tools rarely
Expand All @@ -44,7 +46,7 @@ def get_helptext(f: Callable, args: List[str] = ["--help"]) -> str:
target2 = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target2):
tyro._arguments.USE_RICH = False
tyro.cli(f, args=args)
tyro.cli(f, args=args, use_underscores=use_underscores)
tyro._arguments.USE_RICH = True

if target2.getvalue() != tyro._strings.strip_ansi_sequences(target.getvalue()):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_dcargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,3 +688,34 @@ class A:

assert tyro.cli(A, args="--x".split(" ")).x == ()
assert tyro.cli(A, args="--y".split(" ")).y == []


def test_unknown_args_with_consistent_duplicates_use_underscores() -> None:
@dataclasses.dataclass
class A:
a_b: List[int] = dataclasses.field(default_factory=list)
c_d: List[int] = dataclasses.field(default_factory=list)

# Tests logic for consistent duplicate arguments when performing argument fixing.
# i.e., we can fix arguments if the separator is consistent (all _'s or all -'s).
a, unknown_args = tyro.cli(
A,
args=[
"--a-b",
"5",
"--a-b",
"7",
"--c_d",
"5",
"--c_d",
"7",
"--e-f",
"--e-f",
"--g_h",
"--g_h",
],
return_unknown_args=True,
use_underscores=True,
)
assert a == A(a_b=[7], c_d=[7])
assert unknown_args == ["--e-f", "--e-f", "--g_h", "--g_h"]
84 changes: 84 additions & 0 deletions tests/test_dict_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,36 @@ class HelptextNamedTuple(NamedTuple):


def test_nested_dict() -> None:
loaded_config = {
"batch_size": 32,
"optimizer": {
"learning_rate": 1e-4,
"epsilon": 1e-8,
"scheduler": {"schedule_type": "constant"},
},
}
backup_config = copy.deepcopy(loaded_config)
overrided_config = tyro.cli(
dict,
default=loaded_config,
args=[
"--batch-size",
"16",
"--optimizer.scheduler.schedule_type",
"exponential",
],
)

# Overridden config should be different from loaded config.
assert overrided_config != loaded_config
assert overrided_config["batch_size"] == 16
assert overrided_config["optimizer"]["scheduler"]["schedule_type"] == "exponential"

# Original loaded config should not be mutated.
assert loaded_config == backup_config


def test_nested_dict_use_underscores() -> None:
loaded_config = {
"batch_size": 32,
"optimizer": {
Expand All @@ -329,6 +359,7 @@ def test_nested_dict() -> None:
"--optimizer.scheduler.schedule-type",
"exponential",
],
use_underscores=True,
)

# Overridden config should be different from loaded config.
Expand Down Expand Up @@ -372,6 +403,59 @@ def test_nested_dict_hyphen() -> None:
assert loaded_config == backup_config


def test_nested_dict_hyphen_use_underscores() -> None:
# We do a lot of underscore <=> conversion in the code; this is just to make sure it
# doesn't break anything!
loaded_config = {
"batch-size": 32,
"optimizer": {
"learning-rate": 1e-4,
"epsilon": 1e-8,
"scheduler": {"schedule-type": "constant"},
},
}
backup_config = copy.deepcopy(loaded_config)
overrided_config = tyro.cli(
dict,
default=loaded_config,
args=[
"--batch-size",
"16",
"--optimizer.scheduler.schedule-type",
"exponential",
],
use_underscores=True,
)

# Overridden config should be different from loaded config.
assert overrided_config != loaded_config
assert overrided_config["batch-size"] == 16
assert overrided_config["optimizer"]["scheduler"]["schedule-type"] == "exponential"

# Original loaded config should not be mutated.
assert loaded_config == backup_config

overrided_config = tyro.cli(
dict,
default=loaded_config,
args=[
"--batch_size",
"16",
"--optimizer.scheduler.schedule_type",
"exponential",
],
use_underscores=True,
)

# Overridden config should be different from loaded config.
assert overrided_config != loaded_config
assert overrided_config["batch-size"] == 16
assert overrided_config["optimizer"]["scheduler"]["schedule-type"] == "exponential"

# Original loaded config should not be mutated.
assert loaded_config == backup_config


def test_nested_dict_annotations() -> None:
loaded_config = {
"optimizer": {
Expand Down
98 changes: 98 additions & 0 deletions tests/test_helptext.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,101 @@ def main(child: Child) -> None:

helptext = get_helptext(main)
assert "--child.x | --child.no-x" in helptext


def test_multiple_subparsers_helptext_hyphens() -> None:
@dataclasses.dataclass
class SubcommandOne:
"""2% milk.""" # % symbol is prone to bugs in argparse.

arg_x: int = 0
arg_flag: bool = False

@dataclasses.dataclass
class SubcommandTwo:
arg_y: int = 1

@dataclasses.dataclass
class SubcommandThree:
arg_z: int = 2

@dataclasses.dataclass
class MultipleSubparsers:
# Field a description.
a: Union[SubcommandOne, SubcommandTwo, SubcommandThree]
# Field b description.
b: Union[SubcommandOne, SubcommandTwo, SubcommandThree]
# Field c description.
c: Union[SubcommandOne, SubcommandTwo, SubcommandThree] = dataclasses.field(
default_factory=SubcommandThree
)

helptext = get_helptext(MultipleSubparsers)

assert "2% milk." in helptext
assert "Field a description." in helptext
assert "Field b description." not in helptext
assert "Field c description." not in helptext

helptext = get_helptext(
MultipleSubparsers, args=["a:subcommand-one", "b:subcommand-one", "--help"]
)

assert "2% milk." in helptext
assert "Field a description." not in helptext
assert "Field b description." not in helptext
assert "Field c description." in helptext
assert "(default: c:subcommand-three)" in helptext
assert "--b.arg-x" in helptext
assert "--b.no-arg-flag" in helptext
assert "--b.arg-flag" in helptext


def test_multiple_subparsers_helptext_underscores() -> None:
@dataclasses.dataclass
class SubcommandOne:
"""2% milk.""" # % symbol is prone to bugs in argparse.

arg_x: int = 0
arg_flag: bool = False

@dataclasses.dataclass
class SubcommandTwo:
arg_y: int = 1

@dataclasses.dataclass
class SubcommandThree:
arg_z: int = 2

@dataclasses.dataclass
class MultipleSubparsers:
# Field a description.
a: Union[SubcommandOne, SubcommandTwo, SubcommandThree]
# Field b description.
b: Union[SubcommandOne, SubcommandTwo, SubcommandThree]
# Field c description.
c: Union[SubcommandOne, SubcommandTwo, SubcommandThree] = dataclasses.field(
default_factory=SubcommandThree
)

helptext = get_helptext(MultipleSubparsers, use_underscores=True)

assert "2% milk." in helptext
assert "Field a description." in helptext
assert "Field b description." not in helptext
assert "Field c description." not in helptext

helptext = get_helptext(
MultipleSubparsers,
args=["a:subcommand_one", "b:subcommand_one", "--help"],
use_underscores=True,
)

assert "2% milk." in helptext
assert "Field a description." not in helptext
assert "Field b description." not in helptext
assert "Field c description." in helptext
assert "(default: c:subcommand_three)" in helptext
assert "--b.arg_x" in helptext
assert "--b.no_arg_flag" in helptext
assert "--b.arg_flag" in helptext
8 changes: 5 additions & 3 deletions tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ def __init__(

if option_string.startswith("--"):
if "." not in option_string:
option_string = "--no-" + option_string[2:]
option_string = (
"--no" + _strings.get_delimeter() + option_string[2:]
)
else:
# Loose heuristic for where to add the no- prefix.
# Loose heuristic for where to add the no-/no_ prefix.
left, _, right = option_string.rpartition(".")
option_string = left + ".no-" + right
option_string = left + ".no" + _strings.get_delimeter() + right
self._no_strings.add(option_string)

_option_strings.append(option_string)
Expand Down
Loading
Loading