Skip to content

Commit

Permalink
Support directly passing builtins into dcargs.cli
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 18, 2022
1 parent 362afe5 commit 47875b0
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 49 deletions.
2 changes: 1 addition & 1 deletion dcargs/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,4 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
else:
kwargs[field.name] = value

return f(*args, **kwargs), consumed_keywords # type: ignore
return _resolver.unwrap_origin(f)(*args, **kwargs), consumed_keywords # type: ignore
133 changes: 89 additions & 44 deletions dcargs/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dataclasses
import inspect
import warnings
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, List, Optional, Type, TypeVar, Union, cast

import docstring_parser
from typing_extensions import get_type_hints, is_typeddict
Expand Down Expand Up @@ -89,7 +89,7 @@ def field_list_from_callable(
`f` can be from a dataclass type, regular class type, or function."""

# Unwrap generics.
f, _unused_type_from_typevar = _resolver.resolve_generic_types(f)
f, type_from_typevar = _resolver.resolve_generic_types(f)

# If `f` is a type:
# 1. Set cls to the type.
Expand All @@ -98,7 +98,6 @@ def field_list_from_callable(
if isinstance(f, type):
cls = f
f = cls.__init__ # type: ignore
ignore_self = True

if cls is not None and is_typeddict(cls):
# Handle typed dictionaries.
Expand Down Expand Up @@ -181,51 +180,97 @@ def field_list_from_callable(
default_instance in MISSING_SINGLETONS
), "`default_instance` is only supported for dataclass and TypedDict types."

# Get type annotations, docstrings.
hints = get_type_hints(f)
docstring = inspect.getdoc(f)
docstring_from_arg_name = {}
if docstring is not None:
for param_doc in docstring_parser.parse(docstring).params:
docstring_from_arg_name[param_doc.arg_name] = param_doc.description
del docstring

# Generate field list from function signature.
field_list = []
ignore_self = cls is not None
params = inspect.signature(f).parameters.values()
for param in params:
# For `__init__`, skip self parameter.
if ignore_self:
ignore_self = False
continue

# Get default value.
default = param.default

# Get helptext from docstring.
helptext = docstring_from_arg_name.get(param.name)
if helptext is None and cls is not None:
helptext = _docstrings.get_field_docstring(cls, param.name)

if param.name not in hints:
raise TypeError(
f"Expected fully type-annotated callable, but {f} with arguments"
f" {tuple(map(lambda p: p.name, params))} has no annotation for"
f" '{param.name}'."
)
params = list(inspect.signature(f).parameters.values())
if cls is not None:
# Ignore self parameter.
params = params[1:]

try:
return _field_list_from_params(f, cls, params)
except TypeError as e:
# Try to support passing things like int, str, Dict[K,V], torch.device
# directly into dcargs.cli(). These aren't "type-annotated callables" but
# this a nice-to-have.
param_count = 0
has_kw_only = False
has_var_positional = False
for param in params:
if (
param.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
and param.default is inspect.Parameter.empty
):
param_count += 1
elif param.kind is inspect.Parameter.KEYWORD_ONLY:
has_kw_only = True
elif param.kind is inspect.Parameter.VAR_POSITIONAL:
has_var_positional = True

if not has_kw_only and (
param_count == 1 or (param_count == 0 and has_var_positional)
):
# Things look ok!
if cls is not None:
f = cls
return [
FieldDefinition(
name=_resolver.unwrap_origin(f).__name__,
typ=cast(Type, f),
default=MISSING_NONPROP,
helptext=None,
positional=True,
)
]
else:
raise e

field_list.append(
FieldDefinition(
name=param.name,
# Note that param.annotation does not resolve forward references.
typ=hints[param.name],
default=default,
helptext=helptext,
positional=param.kind is inspect.Parameter.POSITIONAL_ONLY,
)

def _field_list_from_params(
f: Callable, cls: Optional[Type], params: List[inspect.Parameter]
) -> List[FieldDefinition]:
# Get type annotations, docstrings.
docstring = inspect.getdoc(f)
docstring_from_arg_name = {}
if docstring is not None:
for param_doc in docstring_parser.parse(docstring).params:
docstring_from_arg_name[param_doc.arg_name] = param_doc.description
del docstring
hints = get_type_hints(f)

field_list = []
for param in params:
# Get default value.
default = param.default

# Get helptext from docstring.
helptext = docstring_from_arg_name.get(param.name)
if helptext is None and cls is not None:
helptext = _docstrings.get_field_docstring(cls, param.name)

if param.name not in hints:
raise TypeError(
f"Expected fully type-annotated callable, but {f} with arguments"
f" {tuple(map(lambda p: p.name, params))} has no annotation for"
f" '{param.name}'."
)
return field_list

field_list.append(
FieldDefinition(
name=param.name,
# Note that param.annotation does not resolve forward references.
typ=hints[param.name],
default=default,
helptext=helptext,
positional=param.kind is inspect.Parameter.POSITIONAL_ONLY,
)
)

return field_list


def _ensure_dataclass_instance_used_as_default_is_frozen(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="dcargs",
version="0.1.7",
version="0.1.8",
description="Strongly typed, zero-effort CLIs",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
21 changes: 20 additions & 1 deletion tests/test_dcargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import dataclasses
import enum
import pathlib
from typing import Any, AnyStr, Callable, ClassVar, List, Optional, TypeVar, Union
from typing import Any, AnyStr, Callable, ClassVar, Dict, List, Optional, TypeVar, Union

import pytest
import torch
Expand Down Expand Up @@ -453,3 +453,22 @@ def main(device: torch.device) -> torch.device:
return device

assert dcargs.cli(main, args=["--device", "cpu"]) == torch.device("cpu")


def test_torch_device_2():
assert dcargs.cli(torch.device, args=["cpu"]) == torch.device("cpu")


def test_just_int():
assert dcargs.cli(int, args=["123"]) == 123


def test_just_dict():
assert dcargs.cli(Dict[str, str], args="key value key2 value2".split(" ")) == {
"key": "value",
"key2": "value2",
}


def test_just_list():
assert dcargs.cli(List[int], args="1 2 3 4".split(" ")) == [1, 2, 3, 4]
12 changes: 10 additions & 2 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,16 @@ class ChildClass(UnrelatedParentClass, ActualParentClass[int]):
dcargs.cli(ChildClass, args=["--x", "1", "--y", "2", "--z", "3"])


def test_missing_annotation():
def main(a) -> None:
def test_missing_annotation_1():
def main(a, b) -> None:
pass

with pytest.raises(TypeError):
dcargs.cli(main, args=["--help"])


def test_missing_annotation_2():
def main(*, a) -> None:
pass

with pytest.raises(TypeError):
Expand Down

0 comments on commit 47875b0

Please sign in to comment.