diff --git a/README.md b/README.md index 7d53c0f2..1220456b 100644 --- a/README.md +++ b/README.md @@ -177,8 +177,8 @@ Returns:
-In the simplest case, `dcargs.cli()` can be used to run a function with -arguments populated from the CLI. +In the simplest case, `dcargs.cli()` can be used to run a function with arguments +populated from the CLI. **Code ([link](examples/01_functions.py)):** @@ -238,8 +238,8 @@ hello 10
-Common pattern: use `dcargs.cli()` to instantiate a dataclass. The outputted -instance can be used as a typed alternative for an argparse namespace. +Common pattern: use `dcargs.cli()` to instantiate a dataclass. The outputted instance +can be used as a typed alternative for an argparse namespace. **Code ([link](examples/02_dataclasses.py)):** @@ -299,8 +299,8 @@ Args(field1='hello', field2=5)
-We can generate argument parsers from more advanced type annotations, like enums -and tuple types. +We can generate argument parsers from more advanced type annotations, like enums and +tuple types. **Code ([link](examples/03_enums_and_containers.py)):** @@ -389,8 +389,8 @@ TrainConfig(dataset_sources=(PosixPath('data'),), image_dimensions=(32
-Booleans can either be expected to be explicitly passed in, or, if given a -default value, automatically converted to flags. +Booleans can either be expected to be explicitly passed in, or, if given a default +value, automatically converted to flags. **Code ([link](examples/04_flags.py)):** @@ -629,9 +629,9 @@ usage: 05_hierarchical_configs.py [-h] --out-dir PATH
-We can integrate `dcargs.cli()` into common configuration patterns: here, we -select one of multiple possible base configurations, and then use the CLI to -either override (existing) or fill in (missing) values. +We can integrate `dcargs.cli()` into common configuration patterns: here, we select +one of multiple possible base configurations, and then use the CLI to either override +(existing) or fill in (missing) values. **Code ([link](examples/06_base_configs.py)):** @@ -775,7 +775,7 @@ arguments: (required) --activation {<class 'torch.nn.modules.activation.ReLU'>} Activation to use. Not specifiable via the - commandline. (not parsable) + commandline. (fixed) optimizer arguments: Optimizer parameters. @@ -814,7 +814,7 @@ arguments: (required) --activation {<class 'torch.nn.modules.activation.GELU'>} Activation to use. Not specifiable via the - commandline. (not parsable) + commandline. (fixed) optimizer arguments: Optimizer parameters. @@ -839,9 +839,8 @@ ExperimentConfig(dataset='imagenet-50', optimizer=AdamOptimizer(learni
-`typing.Literal[]` can be used to restrict inputs to a fixed set of literal -choices; `typing.Union[]` can be used to restrict inputs to a fixed set of -types. +`typing.Literal[]` can be used to restrict inputs to a fixed set of literal choices; +`typing.Union[]` can be used to restrict inputs to a fixed set of types. **Code ([link](examples/07_literals_and_unions.py)):** @@ -932,8 +931,7 @@ arguments:
-Positional-only arguments in functions are converted to positional CLI -arguments. +Positional-only arguments in functions are converted to positional CLI arguments. **Code ([link](examples/08_positional_args.py)):** @@ -1053,8 +1051,7 @@ background_rgb=(1.0, 0.0, 0.0)
-Unions over nested types (classes or dataclasses) are populated using -subparsers. +Unions over nested types (classes or dataclasses) are populated using subparsers. **Code ([link](examples/09_subparsers.py)):** @@ -1275,11 +1272,8 @@ AdamOptimizer(learning_rate=0.0003, betas=(0.9, 0.999))
-Dictionary inputs can be specified using either a standard `Dict[K, V]` -annotation, or a `TypedDict` type. - -Note that setting `total=False` for `TypedDict` is currently not (but reasonably -could be) supported. +Dictionary inputs can be specified using either a standard `Dict[K, V]` annotation, +or a `TypedDict` type. **Code ([link](examples/11_dictionaries.py)):** @@ -1289,21 +1283,22 @@ from typing import Dict, Tuple, TypedDict import dcargs -class DictionarySchema(TypedDict): +class DictionarySchema( + TypedDict, + # Setting `total=False` specifies that not all keys need to exist. + total=False, +): learning_rate: float betas: Tuple[float, float] def main( + typed_dict: DictionarySchema, standard_dict: Dict[str, float] = { "learning_rate": 3e-4, "beta1": 0.9, "beta2": 0.999, }, - typed_dict: DictionarySchema = { - "learning_rate": 3e-4, - "betas": (0.9, 0.999), - }, ) -> None: assert isinstance(standard_dict, dict) assert isinstance(typed_dict, dict) @@ -1321,9 +1316,9 @@ if __name__ == "__main__":
 $ python ./11_dictionaries.py --help
-usage: 11_dictionaries.py [-h] [--standard-dict STR FLOAT [STR FLOAT ...]]
-                          [--typed-dict.learning-rate FLOAT]
+usage: 11_dictionaries.py [-h] [--typed-dict.learning-rate FLOAT]
                           [--typed-dict.betas FLOAT FLOAT]
+                          [--standard-dict STR FLOAT [STR FLOAT ...]]
 
 arguments:
   -h, --help            show this help message and exit
@@ -1334,9 +1329,23 @@ arguments:
 typed_dict arguments:
 
   --typed-dict.learning-rate FLOAT
-                        (default: 0.0003)
+                        Setting `total=False` specifies that not all keys need
+                        to exist. (unset by default)
   --typed-dict.betas FLOAT FLOAT
-                        (default: 0.9 0.999)
+                        Setting `total=False` specifies that not all keys need
+                        to exist. (unset by default)
+
+ +
+$ python ./11_dictionaries.py --typed-dict.learning-rate 3e-4
+Standard dict: {'learning_rate': 0.0003, 'beta1': 0.9, 'beta2': 0.999}
+Typed dict: {'learning_rate': 0.0003}
+
+ +
+$ python ./11_dictionaries.py --typed-dict.betas 0.9 0.999
+Standard dict: {'learning_rate': 0.0003, 'beta1': 0.9, 'beta2': 0.999}
+Typed dict: {'betas': (0.9, 0.999)}
 
diff --git a/dcargs/_arguments.py b/dcargs/_arguments.py index efc216d9..36663042 100644 --- a/dcargs/_arguments.py +++ b/dcargs/_arguments.py @@ -263,6 +263,8 @@ def _rule_generate_helptext( default_text = f"(sets: {arg.field.name}=True)" elif lowered.action == "store_false": default_text = f"(sets: {arg.field.name}=False)" + elif arg.field.default is _fields.EXCLUDE_FROM_CALL: + default_text = "(unset by default)" elif lowered.nargs is not None and hasattr(default, "__iter__"): # For tuple types, we might have default as (0, 1, 2, 3). # For list types, we might have default as [0, 1, 2, 3]. diff --git a/dcargs/_calling.py b/dcargs/_calling.py index 580bd51e..3d7aaa96 100644 --- a/dcargs/_calling.py +++ b/dcargs/_calling.py @@ -155,9 +155,10 @@ def get_value_from_arg(prefixed_field_name: str) -> Any: ) consumed_keywords |= consumed_keywords_child - if field.positional: - args.append(value) - else: - kwargs[field.name] = value + if value is not _fields.EXCLUDE_FROM_CALL: + if field.positional: + args.append(value) + else: + kwargs[field.name] = value return f(*args, **kwargs), consumed_keywords # type: ignore diff --git a/dcargs/_fields.py b/dcargs/_fields.py index 1c905587..4cdc890f 100644 --- a/dcargs/_fields.py +++ b/dcargs/_fields.py @@ -9,7 +9,7 @@ import docstring_parser from typing_extensions import get_type_hints, is_typeddict -from . import _docstrings, _resolver +from . import _docstrings, _instantiators, _parsers, _resolver @dataclasses.dataclass(frozen=True) @@ -44,11 +44,16 @@ class NonpropagatingMissingType(_Singleton): pass +class ExcludeFromKwargsType(_Singleton): + pass + + # We have two types of missing sentinels: a propagating missing value, which when set as # a default will set all child values of nested structures as missing as well, and a # nonpropagating missing sentinel, which does not override child defaults. MISSING_PROP = PropagatingMissingType() MISSING_NONPROP = NonpropagatingMissingType() +EXCLUDE_FROM_CALL = ExcludeFromKwargsType() # Note that our "public" missing API will always be the propagating missing sentinel. MISSING_PUBLIC: Any = MISSING_PROP @@ -98,16 +103,28 @@ def field_list_from_callable( if cls is not None and is_typeddict(cls): # Handle typed dictionaries. field_list = [] - no_default_instance = default_instance in MISSING_SINGLETONS - assert no_default_instance or isinstance(default_instance, dict) + valid_default_instance = ( + default_instance not in MISSING_SINGLETONS + and default_instance is not EXCLUDE_FROM_CALL + ) + assert not valid_default_instance or isinstance(default_instance, dict) for name, typ in get_type_hints(cls).items(): + if valid_default_instance: + default = default_instance.get(name, MISSING_PROP) # type: ignore + elif getattr(cls, "__total__") is False: + default = EXCLUDE_FROM_CALL + if _parsers.is_possibly_nested_type(typ): + raise _instantiators.UnsupportedTypeAnnotationError( + "`total=False` not supported for nested structures." + ) + else: + default = MISSING_PROP + field_list.append( FieldDefinition( name=name, typ=typ, - default=MISSING_PROP - if no_default_instance - else default_instance.get(name, MISSING_PROP), # type: ignore + default=default, helptext=_docstrings.get_field_docstring(cls, name), positional=False, ) diff --git a/dcargs/_parsers.py b/dcargs/_parsers.py index 87dbdef5..aff0fce1 100644 --- a/dcargs/_parsers.py +++ b/dcargs/_parsers.py @@ -50,7 +50,7 @@ ) -def _is_possibly_nested_type(typ: Any) -> bool: +def is_possibly_nested_type(typ: Any) -> bool: """Heuristics for determining whether a type can be treated as a 'nested type', where a single field has multiple corresponding arguments (eg for nested dataclasses or classes). @@ -189,10 +189,10 @@ def from_callable( continue else: field = dataclasses.replace(field, typ=type(field.default)) - assert _is_possibly_nested_type(field.typ) + assert is_possibly_nested_type(field.typ) # (2) Handle nested callables. - if _is_possibly_nested_type(field.typ): + if is_possibly_nested_type(field.typ): nested_parser = ParserSpecification.from_callable( field.typ, description=None, @@ -361,7 +361,7 @@ def from_field( # We don't use sets here to retain order of subcommands. options = [type_from_typevar.get(typ, typ) for typ in get_args(field.typ)] options_no_none = [o for o in options if o != type(None)] # noqa - if not all(map(_is_possibly_nested_type, options_no_none)): + if not all(map(is_possibly_nested_type, options_no_none)): return None parser_from_name: Dict[str, ParserSpecification] = {} diff --git a/examples/11_dictionaries.py b/examples/11_dictionaries.py index 4f51ca7e..ef788b71 100644 --- a/examples/11_dictionaries.py +++ b/examples/11_dictionaries.py @@ -1,11 +1,10 @@ """Dictionary inputs can be specified using either a standard `Dict[K, V]` annotation, or a `TypedDict` type. -Note that setting `total=False` for `TypedDict` is currently not (but reasonably could be) -supported. - Usage: `python ./11_dictionaries.py --help` +`python ./11_dictionaries.py --typed-dict.learning-rate 3e-4` +`python ./11_dictionaries.py --typed-dict.betas 0.9 0.999` """ from typing import Dict, Tuple, TypedDict @@ -13,21 +12,22 @@ import dcargs -class DictionarySchema(TypedDict): +class DictionarySchema( + TypedDict, + # Setting `total=False` specifies that not all keys need to exist. + total=False, +): learning_rate: float betas: Tuple[float, float] def main( + typed_dict: DictionarySchema, standard_dict: Dict[str, float] = { "learning_rate": 3e-4, "beta1": 0.9, "beta2": 0.999, }, - typed_dict: DictionarySchema = { - "learning_rate": 3e-4, - "betas": (0.9, 0.999), - }, ) -> None: assert isinstance(standard_dict, dict) assert isinstance(typed_dict, dict) diff --git a/tests/test_dict_namedtuple.py b/tests/test_dict_namedtuple.py index 3b957bf9..cfe791b0 100644 --- a/tests/test_dict_namedtuple.py +++ b/tests/test_dict_namedtuple.py @@ -63,22 +63,55 @@ def test_basic_typeddict(): class ManyTypesTypedDict(TypedDict): i: int s: str - f: float - p: pathlib.Path assert dcargs.cli( ManyTypesTypedDict, - args=[ - "--i", - "5", - "--s", - "5", - "--f", - "5", - "--p", - "~", - ], - ) == dict(i=5, s="5", f=5.0, p=pathlib.Path("~")) + args="--i 5 --s 5".split(" "), + ) == dict(i=5, s="5") + + with pytest.raises(SystemExit): + dcargs.cli(ManyTypesTypedDict, args="--i 5".split(" ")) + + with pytest.raises(SystemExit): + dcargs.cli(ManyTypesTypedDict, args="--s 5".split(" ")) + + +def test_total_false_typeddict(): + class ManyTypesTypedDict(TypedDict, total=False): + i: int + s: str + + assert dcargs.cli( + ManyTypesTypedDict, + args="--i 5 --s 5".split(" "), + ) == dict(i=5, s="5") + + assert dcargs.cli(ManyTypesTypedDict, args="--i 5".split(" ")) == dict(i=5) + assert dcargs.cli(ManyTypesTypedDict, args="--s 5".split(" ")) == dict(s="5") + + +def test_total_false_nested_typeddict(): + class ChildTypedDict(TypedDict, total=False): + i: int + s: str + + class ParentTypedDict(TypedDict, total=False): + child: ChildTypedDict + + with pytest.raises(dcargs.UnsupportedTypeAnnotationError): + dcargs.cli( + ParentTypedDict, + args="--child.i 5 --child.s 5".split(" "), + ) + + with pytest.raises(dcargs.UnsupportedTypeAnnotationError): + assert ( + dcargs.cli( + ParentTypedDict, + args=[""], + ) + == {} + ) def test_nested_typeddict():