Skip to content

Commit

Permalink
Final tweaks for 0.3.0
Browse files Browse the repository at this point in the history
- Make --field_name and --field-name interchangeable
- Add prefix_names option to extras.subcommand_type_from_defaults()
  • Loading branch information
brentyi committed Sep 7, 2022
1 parent b65d1f1 commit a24b209
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 16 deletions.
16 changes: 15 additions & 1 deletion dcargs/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,26 @@ def cli(
prefix="", # Used for recursive calls.
)

# Read and fix arguments. If the user passes in --field_name instead of
# --field-name, correct for them.
args = sys.argv[1:] if args is None else args

def fix_arg(arg: str) -> str:
if not arg.startswith("--"):
return arg
if "=" in arg:
arg, _, val = arg.partition("=")
return arg.replace("_", "-") + "=" + val
else:
return arg.replace("_", "-")

args = list(map(fix_arg, args))

# If we pass in the --dcargs-print-completion flag: turn termcolor off, and get the
# shell we want to generate a completion script for (bash/zsh/tcsh).
#
# Note that shtab also offers an add_argument_to() functions that fulfills a similar
# goal, but manual parsing of argv is convenient for turning off colors.
args = sys.argv[1:] if args is None else args
print_completion = len(args) >= 2 and args[0] == "--dcargs-print-completion"

formatting_context = _argparse_formatter.ansi_context()
Expand Down
2 changes: 1 addition & 1 deletion dcargs/_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from . import _resolver

dummy_field_name = "__dcargs_dummy_field_name__"
dummy_field_name = "__dcargs_dummy_field__"


def _strip_dummy_field_names(parts: Iterable[str]) -> Iterable[str]:
Expand Down
4 changes: 2 additions & 2 deletions dcargs/extras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The :mod:`dcargs.extras` submodule contains helpers that complement :func:`dcargs.cli()`, but
aren't considered part of the core interface."""

from ._base_configs import subcommand_union_from_mapping
from ._base_configs import subcommand_type_from_defaults
from ._serialization import from_yaml, to_yaml

__all__ = ["subcommand_union_from_mapping", "to_yaml", "from_yaml"]
__all__ = ["subcommand_type_from_defaults", "to_yaml", "from_yaml"]
28 changes: 21 additions & 7 deletions dcargs/extras/_base_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
T = TypeVar("T")


def subcommand_union_from_mapping(
default_from_name: Mapping[str, T], descriptions: Mapping[str, str] = {}
def subcommand_type_from_defaults(
defaults: Mapping[str, T],
descriptions: Mapping[str, str] = {},
*,
prefix_names: bool = True,
) -> Type[T]:
"""Returns a Union type for defining subcommands that choose between nested types.
"""Construct a Union type for defining subcommands that choose between defaults.
For example, when `default` is set to:
This can most commonly be used to create a "base configuration" pattern:
https://brentyi.github.io/dcargs/examples/10_base_configs/
For example, when `defaults` is set to:
```python
{
Expand All @@ -36,7 +42,7 @@ def subcommand_union_from_mapping(
]
```
This can be used directly in dcargs.cli:
The resulting type can be used directly in dcargs.cli:
```python
config = dcargs.cli(subcommand_union_from_mapping(default_from_name))
Expand Down Expand Up @@ -70,8 +76,16 @@ def train(
return Union.__getitem__( # type: ignore
tuple(
Annotated.__class_getitem__( # type: ignore
(type(v), subcommand(k, default=v, description=descriptions.get(k)))
(
type(v),
subcommand(
k,
default=v,
description=descriptions.get(k),
prefix_name=prefix_names,
),
)
)
for k, v in default_from_name.items()
for k, v in defaults.items()
)
)
2 changes: 1 addition & 1 deletion examples/05_hierarchical_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class OptimizerConfig:
@dataclasses.dataclass
class ExperimentConfig:
# Various configurable options for our optimizer.
optimizer: OptimizerConfig
optimizer_config: OptimizerConfig

# Batch size.
batch_size: int = 32
Expand Down
4 changes: 2 additions & 2 deletions examples/10_base_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ class ExperimentConfig:

if __name__ == "__main__":
config = dcargs.cli(
dcargs.extras.subcommand_union_from_mapping(base_configs, descriptions),
dcargs.extras.subcommand_type_from_defaults(base_configs, descriptions),
)
# Note that this is equivalent to:
# ^Note that this is equivalent to:
#
# config = dcargs.cli(
# Union[
Expand Down
25 changes: 25 additions & 0 deletions tests/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,31 @@ class Nested:
dcargs.cli(Nested, args=["--x", "1"])


def test_nested_accidental_underscores():
@dataclasses.dataclass
class B:
arg_name: str

@dataclasses.dataclass
class Nested:
x: int
child_struct: B

assert (
dcargs.cli(Nested, args=["--x", "1", "--child-struct.arg-name", "three_five"])
== dcargs.cli(
Nested, args=["--x", "1", "--child_struct.arg_name", "three_five"]
)
== dcargs.cli(
Nested, args=["--x", "1", "--child_struct.arg-name", "three_five"]
)
== dcargs.cli(Nested, args=["--x", "1", "--child_struct.arg_name=three_five"])
== Nested(x=1, child_struct=B(arg_name="three_five"))
)
with pytest.raises(SystemExit):
dcargs.cli(Nested, args=["--x", "1"])


def test_nested_default():
@dataclasses.dataclass
class B:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_union_from_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_union_from_mapping():
"two": A(2),
"three": A(3),
}
ConfigUnion = dcargs.extras.subcommand_union_from_mapping(base_configs)
ConfigUnion = dcargs.extras.subcommand_type_from_defaults(base_configs)

assert dcargs.cli(ConfigUnion, args="one".split(" ")) == A(1)
assert dcargs.cli(ConfigUnion, args="two".split(" ")) == A(2)
Expand All @@ -32,7 +32,7 @@ def test_union_from_mapping_in_function():

# Hack for mypy. Not needed for pyright.
ConfigUnion = A
ConfigUnion = dcargs.extras.subcommand_union_from_mapping(base_configs) # type: ignore
ConfigUnion = dcargs.extras.subcommand_type_from_defaults(base_configs) # type: ignore

def main(config: ConfigUnion, flag: bool = False) -> Optional[A]:
if flag:
Expand Down

0 comments on commit a24b209

Please sign in to comment.