Skip to content

Commit

Permalink
Support total=False for TypedDict
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 18, 2022
1 parent 745c13f commit ae521cb
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 69 deletions.
77 changes: 43 additions & 34 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ Returns:
</summary>
<blockquote>

In the simplest case, `dcargs.cli()` can be used to run a function with
arguments populated from the CLI.
In the simplest case, `dcargs.cli()` can be used to run a function with arguments
populated from the CLI.

**Code ([link](examples/01_functions.py)):**

Expand Down Expand Up @@ -238,8 +238,8 @@ hello 10</samp>
</summary>
<blockquote>

Common pattern: use `dcargs.cli()` to instantiate a dataclass. The outputted
instance can be used as a typed alternative for an argparse namespace.
Common pattern: use `dcargs.cli()` to instantiate a dataclass. The outputted instance
can be used as a typed alternative for an argparse namespace.

**Code ([link](examples/02_dataclasses.py)):**

Expand Down Expand Up @@ -299,8 +299,8 @@ Args(field1=&#x27;hello&#x27;, field2=5)</samp>
</summary>
<blockquote>

We can generate argument parsers from more advanced type annotations, like enums
and tuple types.
We can generate argument parsers from more advanced type annotations, like enums and
tuple types.

**Code ([link](examples/03_enums_and_containers.py)):**

Expand Down Expand Up @@ -389,8 +389,8 @@ TrainConfig(dataset_sources=(PosixPath(&#x27;data&#x27;),), image_dimensions=(32
</summary>
<blockquote>

Booleans can either be expected to be explicitly passed in, or, if given a
default value, automatically converted to flags.
Booleans can either be expected to be explicitly passed in, or, if given a default
value, automatically converted to flags.

**Code ([link](examples/04_flags.py)):**

Expand Down Expand Up @@ -629,9 +629,9 @@ usage: 05_hierarchical_configs.py [-h] --out-dir PATH
</summary>
<blockquote>

We can integrate `dcargs.cli()` into common configuration patterns: here, we
select one of multiple possible base configurations, and then use the CLI to
either override (existing) or fill in (missing) values.
We can integrate `dcargs.cli()` into common configuration patterns: here, we select
one of multiple possible base configurations, and then use the CLI to either override
(existing) or fill in (missing) values.

**Code ([link](examples/06_base_configs.py)):**

Expand Down Expand Up @@ -775,7 +775,7 @@ arguments:
(required)
--activation {&lt;class &#x27;torch.nn.modules.activation.ReLU&#x27;&gt;}
Activation to use. Not specifiable via the
commandline. (not parsable)
commandline. (fixed)

optimizer arguments:
Optimizer parameters.
Expand Down Expand Up @@ -814,7 +814,7 @@ arguments:
(required)
--activation {&lt;class &#x27;torch.nn.modules.activation.GELU&#x27;&gt;}
Activation to use. Not specifiable via the
commandline. (not parsable)
commandline. (fixed)

optimizer arguments:
Optimizer parameters.
Expand All @@ -839,9 +839,8 @@ ExperimentConfig(dataset=&#x27;imagenet-50&#x27;, optimizer=AdamOptimizer(learni
</summary>
<blockquote>

`typing.Literal[]` can be used to restrict inputs to a fixed set of literal
choices; `typing.Union[]` can be used to restrict inputs to a fixed set of
types.
`typing.Literal[]` can be used to restrict inputs to a fixed set of literal choices;
`typing.Union[]` can be used to restrict inputs to a fixed set of types.

**Code ([link](examples/07_literals_and_unions.py)):**

Expand Down Expand Up @@ -932,8 +931,7 @@ arguments:
</summary>
<blockquote>

Positional-only arguments in functions are converted to positional CLI
arguments.
Positional-only arguments in functions are converted to positional CLI arguments.

**Code ([link](examples/08_positional_args.py)):**

Expand Down Expand Up @@ -1053,8 +1051,7 @@ background_rgb=(1.0, 0.0, 0.0)</samp>
</summary>
<blockquote>

Unions over nested types (classes or dataclasses) are populated using
subparsers.
Unions over nested types (classes or dataclasses) are populated using subparsers.

**Code ([link](examples/09_subparsers.py)):**

Expand Down Expand Up @@ -1275,11 +1272,8 @@ AdamOptimizer(learning_rate=0.0003, betas=(0.9, 0.999))</samp>
</summary>
<blockquote>

Dictionary inputs can be specified using either a standard `Dict[K, V]`
annotation, or a `TypedDict` type.

Note that setting `total=False` for `TypedDict` is currently not (but reasonably
could be) supported.
Dictionary inputs can be specified using either a standard `Dict[K, V]` annotation,
or a `TypedDict` type.

**Code ([link](examples/11_dictionaries.py)):**

Expand All @@ -1289,21 +1283,22 @@ from typing import Dict, Tuple, TypedDict
import dcargs


class DictionarySchema(TypedDict):
class DictionarySchema(
TypedDict,
# Setting `total=False` specifies that not all keys need to exist.
total=False,
):
learning_rate: float
betas: Tuple[float, float]


def main(
typed_dict: DictionarySchema,
standard_dict: Dict[str, float] = {
"learning_rate": 3e-4,
"beta1": 0.9,
"beta2": 0.999,
},
typed_dict: DictionarySchema = {
"learning_rate": 3e-4,
"betas": (0.9, 0.999),
},
) -> None:
assert isinstance(standard_dict, dict)
assert isinstance(typed_dict, dict)
Expand All @@ -1321,9 +1316,9 @@ if __name__ == "__main__":

<pre>
<samp>$ <kbd>python ./11_dictionaries.py --help</kbd>
usage: 11_dictionaries.py [-h] [--standard-dict STR FLOAT [STR FLOAT ...]]
[--typed-dict.learning-rate FLOAT]
usage: 11_dictionaries.py [-h] [--typed-dict.learning-rate FLOAT]
[--typed-dict.betas FLOAT FLOAT]
[--standard-dict STR FLOAT [STR FLOAT ...]]

arguments:
-h, --help show this help message and exit
Expand All @@ -1334,9 +1329,23 @@ arguments:
typed_dict arguments:

--typed-dict.learning-rate FLOAT
(default: 0.0003)
Setting `total=False` specifies that not all keys need
to exist. (unset by default)
--typed-dict.betas FLOAT FLOAT
(default: 0.9 0.999)</samp>
Setting `total=False` specifies that not all keys need
to exist. (unset by default)</samp>
</pre>

<pre>
<samp>$ <kbd>python ./11_dictionaries.py --typed-dict.learning-rate 3e-4</kbd>
Standard dict: {&#x27;learning_rate&#x27;: 0.0003, &#x27;beta1&#x27;: 0.9, &#x27;beta2&#x27;: 0.999}
Typed dict: {&#x27;learning_rate&#x27;: 0.0003}</samp>
</pre>

<pre>
<samp>$ <kbd>python ./11_dictionaries.py --typed-dict.betas 0.9 0.999</kbd>
Standard dict: {&#x27;learning_rate&#x27;: 0.0003, &#x27;beta1&#x27;: 0.9, &#x27;beta2&#x27;: 0.999}
Typed dict: {&#x27;betas&#x27;: (0.9, 0.999)}</samp>
</pre>

</blockquote>
Expand Down
2 changes: 2 additions & 0 deletions dcargs/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def _rule_generate_helptext(
default_text = f"(sets: {arg.field.name}=True)"
elif lowered.action == "store_false":
default_text = f"(sets: {arg.field.name}=False)"
elif arg.field.default is _fields.EXCLUDE_FROM_CALL:
default_text = "(unset by default)"
elif lowered.nargs is not None and hasattr(default, "__iter__"):
# For tuple types, we might have default as (0, 1, 2, 3).
# For list types, we might have default as [0, 1, 2, 3].
Expand Down
9 changes: 5 additions & 4 deletions dcargs/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
)
consumed_keywords |= consumed_keywords_child

if field.positional:
args.append(value)
else:
kwargs[field.name] = value
if value is not _fields.EXCLUDE_FROM_CALL:
if field.positional:
args.append(value)
else:
kwargs[field.name] = value

return f(*args, **kwargs), consumed_keywords # type: ignore
29 changes: 23 additions & 6 deletions dcargs/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import docstring_parser
from typing_extensions import get_type_hints, is_typeddict

from . import _docstrings, _resolver
from . import _docstrings, _instantiators, _parsers, _resolver


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -44,11 +44,16 @@ class NonpropagatingMissingType(_Singleton):
pass


class ExcludeFromKwargsType(_Singleton):
pass


# We have two types of missing sentinels: a propagating missing value, which when set as
# a default will set all child values of nested structures as missing as well, and a
# nonpropagating missing sentinel, which does not override child defaults.
MISSING_PROP = PropagatingMissingType()
MISSING_NONPROP = NonpropagatingMissingType()
EXCLUDE_FROM_CALL = ExcludeFromKwargsType()

# Note that our "public" missing API will always be the propagating missing sentinel.
MISSING_PUBLIC: Any = MISSING_PROP
Expand Down Expand Up @@ -98,16 +103,28 @@ def field_list_from_callable(
if cls is not None and is_typeddict(cls):
# Handle typed dictionaries.
field_list = []
no_default_instance = default_instance in MISSING_SINGLETONS
assert no_default_instance or isinstance(default_instance, dict)
valid_default_instance = (
default_instance not in MISSING_SINGLETONS
and default_instance is not EXCLUDE_FROM_CALL
)
assert not valid_default_instance or isinstance(default_instance, dict)
for name, typ in get_type_hints(cls).items():
if valid_default_instance:
default = default_instance.get(name, MISSING_PROP) # type: ignore
elif getattr(cls, "__total__") is False:
default = EXCLUDE_FROM_CALL
if _parsers.is_possibly_nested_type(typ):
raise _instantiators.UnsupportedTypeAnnotationError(
"`total=False` not supported for nested structures."
)
else:
default = MISSING_PROP

field_list.append(
FieldDefinition(
name=name,
typ=typ,
default=MISSING_PROP
if no_default_instance
else default_instance.get(name, MISSING_PROP), # type: ignore
default=default,
helptext=_docstrings.get_field_docstring(cls, name),
positional=False,
)
Expand Down
8 changes: 4 additions & 4 deletions dcargs/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)


def _is_possibly_nested_type(typ: Any) -> bool:
def is_possibly_nested_type(typ: Any) -> bool:
"""Heuristics for determining whether a type can be treated as a 'nested type',
where a single field has multiple corresponding arguments (eg for nested
dataclasses or classes).
Expand Down Expand Up @@ -189,10 +189,10 @@ def from_callable(
continue
else:
field = dataclasses.replace(field, typ=type(field.default))
assert _is_possibly_nested_type(field.typ)
assert is_possibly_nested_type(field.typ)

# (2) Handle nested callables.
if _is_possibly_nested_type(field.typ):
if is_possibly_nested_type(field.typ):
nested_parser = ParserSpecification.from_callable(
field.typ,
description=None,
Expand Down Expand Up @@ -361,7 +361,7 @@ def from_field(
# We don't use sets here to retain order of subcommands.
options = [type_from_typevar.get(typ, typ) for typ in get_args(field.typ)]
options_no_none = [o for o in options if o != type(None)] # noqa
if not all(map(_is_possibly_nested_type, options_no_none)):
if not all(map(is_possibly_nested_type, options_no_none)):
return None

parser_from_name: Dict[str, ParserSpecification] = {}
Expand Down
16 changes: 8 additions & 8 deletions examples/11_dictionaries.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
"""Dictionary inputs can be specified using either a standard `Dict[K, V]` annotation,
or a `TypedDict` type.
Note that setting `total=False` for `TypedDict` is currently not (but reasonably could be)
supported.
Usage:
`python ./11_dictionaries.py --help`
`python ./11_dictionaries.py --typed-dict.learning-rate 3e-4`
`python ./11_dictionaries.py --typed-dict.betas 0.9 0.999`
"""

from typing import Dict, Tuple, TypedDict

import dcargs


class DictionarySchema(TypedDict):
class DictionarySchema(
TypedDict,
# Setting `total=False` specifies that not all keys need to exist.
total=False,
):
learning_rate: float
betas: Tuple[float, float]


def main(
typed_dict: DictionarySchema,
standard_dict: Dict[str, float] = {
"learning_rate": 3e-4,
"beta1": 0.9,
"beta2": 0.999,
},
typed_dict: DictionarySchema = {
"learning_rate": 3e-4,
"betas": (0.9, 0.999),
},
) -> None:
assert isinstance(standard_dict, dict)
assert isinstance(typed_dict, dict)
Expand Down
Loading

0 comments on commit ae521cb

Please sign in to comment.