Skip to content

Commit

Permalink
Merge pull request #5 from brentyi/feat/subcommand_config
Browse files Browse the repository at this point in the history
v0.3.0
  • Loading branch information
brentyi authored Sep 7, 2022
2 parents 767e4df + 5833870 commit 943f93e
Show file tree
Hide file tree
Showing 51 changed files with 2,022 additions and 851 deletions.
6 changes: 3 additions & 3 deletions dcargs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from . import extras
from ._cli import cli, generate_parser
from . import conf, extras
from ._cli import cli
from ._fields import MISSING_PUBLIC as MISSING
from ._instantiators import UnsupportedTypeAnnotationError

__all__ = [
"conf",
"extras",
"cli",
"generate_parser",
"MISSING",
"UnsupportedTypeAnnotationError",
]
Expand Down
37 changes: 36 additions & 1 deletion dcargs/_argparse_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import shutil
from typing import Any, ContextManager, Generator

import termcolor

from . import _strings


Expand All @@ -15,6 +17,22 @@ def monkeypatch_len(obj: Any) -> int:
return len(obj)


def dummy_termcolor_context() -> ContextManager[None]:
"""Context for turning termcolor off."""

def dummy_colored(*args, **kwargs) -> str:
return args[0]

@contextlib.contextmanager
def inner() -> Generator[None, None, None]:
orig_colored = termcolor.colored
termcolor.colored = dummy_colored
yield
termcolor.colored = orig_colored

return inner()


def ansi_context() -> ContextManager[None]:
"""Context for working with ANSI codes + argparse:
- Applies a temporary monkey patch for making argparse ignore ANSI codes when
Expand All @@ -25,6 +43,7 @@ def ansi_context() -> ContextManager[None]:
@contextlib.contextmanager
def inner() -> Generator[None, None, None]:
if not hasattr(argparse, "len"):
# Sketchy, but seems to work.
argparse.len = monkeypatch_len # type: ignore
try:
# Use Colorama to support coloring in Windows shells.
Expand Down Expand Up @@ -136,7 +155,8 @@ def _format_action(self, action):
)
)
# </new>
parts.append("%*s%s\n" % (indent_first, "", help_lines[0]))

parts.append("%*s%s\n" % (indent_first, "", help_lines[0])) # type: ignore
for line in help_lines[1:]:
parts.append("%*s%s\n" % (help_position, "", line))

Expand All @@ -150,3 +170,18 @@ def _format_action(self, action):

# return a single string
return self._join_parts(parts)

def _split_lines(self, text, width):
text = self._whitespace_matcher.sub(" ", text).strip()
# The textwrap module is used only for formatting help.
# Delay its import for speeding up the common usage of argparse.
import textwrap as textwrap

# Sketchy, but seems to work.
textwrap.len = monkeypatch_len # type: ignore
out = textwrap.wrap(text, width)
del textwrap.len # type: ignore
return out

def _fill_text(self, text, width, indent):
return "".join(indent + line for line in text.splitlines(keepends=True))
52 changes: 37 additions & 15 deletions dcargs/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

import termcolor

from . import _fields, _instantiators, _strings
from . import _fields, _instantiators, _resolver, _strings
from .conf import _markers

try:
# Python >=3.8.
Expand Down Expand Up @@ -95,9 +96,9 @@ class LoweredArgumentDefinition:

def is_fixed(self) -> bool:
"""If the instantiator is set to `None`, even after all argument
transformations, it means that we weren't able to determine a valid instantiator
for an argument. We then mark the argument as 'fixed', with a value always equal
to the field default."""
transformations, it means that we don't have a valid instantiator for an
argument. We then mark the argument as 'fixed', with a value always equal to the
field default."""
return self.instantiator is None

# From here on out, all fields correspond 1:1 to inputs to argparse's
Expand Down Expand Up @@ -132,26 +133,35 @@ def _rule_handle_boolean_flags(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
) -> LoweredArgumentDefinition:
if arg.type_from_typevar.get(arg.field.typ, arg.field.typ) is not bool: # type: ignore
if _resolver.apply_type_from_typevar(arg.field.typ, arg.type_from_typevar) is not bool: # type: ignore
return lowered

if lowered.default is False and not arg.field.positional:
if (
arg.field.default in _fields.MISSING_SINGLETONS
or arg.field.is_positional()
or _markers.FLAG_CONVERSION_OFF in arg.field.markers
):
# Treat bools as a normal parameter.
return lowered
elif arg.field.default is False:
# Default `False` => --flag passed in flips to `True`.
return dataclasses.replace(
lowered,
action="store_true",
instantiator=lambda x: x, # argparse will directly give us a bool!
)
elif lowered.default is True and not arg.field.positional:
elif arg.field.default is True:
# Default `True` => --no-flag passed in flips to `False`.
return dataclasses.replace(
lowered,
action="store_false",
instantiator=lambda x: x, # argparse will directly give us a bool!
)
else:
# Treat bools as a normal parameter.
return lowered

assert False, (
"Expected a boolean as a default for {arg.field.name}, but got"
" {lowered.default}."
)


def _rule_recursive_instantiator_from_type(
Expand All @@ -166,6 +176,14 @@ def _rule_recursive_instantiator_from_type(
Conversions from strings to our desired types happen in the instantiator; this is a
bit more flexible, and lets us handle more complex types like enums and multi-type
tuples."""
if _markers.FIXED in arg.field.markers:
return dataclasses.replace(
lowered,
instantiator=None,
metavar=termcolor.colored("{fixed}", color="red"),
required=False,
default=_fields.MISSING_PROP,
)
if lowered.instantiator is not None:
return lowered
try:
Expand All @@ -175,7 +193,11 @@ def _rule_recursive_instantiator_from_type(
)
except _instantiators.UnsupportedTypeAnnotationError as e:
if arg.field.default in _fields.MISSING_SINGLETONS:
raise e
raise _instantiators.UnsupportedTypeAnnotationError(
"Unsupported type annotation for the field"
f" {_strings.make_field_name([arg.prefix, arg.field.name])}. To"
" suppress this error, assign the field a default value."
) from e
else:
# For fields with a default, we'll get by even if there's no instantiator
# available.
Expand Down Expand Up @@ -238,9 +260,9 @@ def _rule_generate_helptext(
# https://stackoverflow.com/questions/21168120/python-argparse-errors-with-in-help-string
docstring_help = docstring_help.replace("%", "%%")
help_parts.append(docstring_help)
elif arg.field.positional and arg.field.name != _strings.dummy_field_name:
elif arg.field.is_positional() and arg.field.name != _strings.dummy_field_name:
# Place the type in the helptext. Note that we skip this for dummy fields, which
# will sitll have the type in the metavar.
# will still have the type in the metavar.
help_parts.append(str(lowered.metavar))

default = lowered.default
Expand Down Expand Up @@ -287,7 +309,7 @@ def _rule_set_name_or_flag(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
) -> LoweredArgumentDefinition:
if arg.field.positional:
if arg.field.is_positional():
name_or_flag = _strings.make_field_name([arg.prefix, arg.field.name])
elif lowered.action == "store_false":
name_or_flag = "--" + _strings.make_field_name(
Expand All @@ -309,7 +331,7 @@ def _rule_positional_special_handling(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
) -> LoweredArgumentDefinition:
if not arg.field.positional:
if not arg.field.is_positional():
return lowered

metavar = _strings.make_field_name([arg.prefix, arg.field.name]).upper()
Expand Down
22 changes: 7 additions & 15 deletions dcargs/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Any, Callable, Dict, List, Sequence, Set, Tuple, TypeVar, Union

from typing_extensions import get_args, get_origin
from typing_extensions import get_args

from . import _arguments, _fields, _parsers, _resolver, _strings

Expand All @@ -24,7 +24,6 @@ def call_from_args(
default_instance: Union[T, _fields.NonpropagatingMissingType],
value_from_prefixed_field_name: Dict[str, Any],
field_name_prefix: str,
avoid_subparsers: bool,
) -> Tuple[T, Set[str]]:
"""Call `f` with arguments specified by a dictionary of values from argparse.
Expand Down Expand Up @@ -56,11 +55,7 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
prefixed_field_name = _strings.make_field_name([field_name_prefix, field.name])

# Resolve field type.
field_type = (
type_from_typevar[field.typ] # type: ignore
if field.typ in type_from_typevar
else field.typ
)
field_type = _resolver.apply_type_from_typevar(field.typ, type_from_typevar) # type: ignore

if prefixed_field_name in arg_from_prefixed_field_name:
assert prefixed_field_name not in consumed_keywords
Expand Down Expand Up @@ -100,16 +95,14 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
in parser_definition.helptext_from_nested_class_field_name
):
# Nested callable.
if get_origin(field_type) is Union:
assert avoid_subparsers
if _resolver.unwrap_origin_strip_extras(field_type) is Union:
field_type = type(field.default)
value, consumed_keywords_child = call_from_args(
field_type,
parser_definition,
field.default,
value_from_prefixed_field_name,
field_name_prefix=prefixed_field_name,
avoid_subparsers=avoid_subparsers,
)
consumed_keywords |= consumed_keywords_child
else:
Expand Down Expand Up @@ -148,8 +141,8 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
value = None
else:
options = map(
lambda x: x if x not in type_from_typevar else type_from_typevar[x],
get_args(field_type),
lambda x: _resolver.apply_type_from_typevar(x, type_from_typevar),
get_args(_resolver.unwrap_annotated(field_type)[0]),
)
chosen_f = None
for option in options:
Expand All @@ -166,19 +159,18 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
field.default if type(field.default) is chosen_f else None,
value_from_prefixed_field_name,
field_name_prefix=prefixed_field_name,
avoid_subparsers=avoid_subparsers,
)
consumed_keywords |= consumed_keywords_child

if value is not _fields.EXCLUDE_FROM_CALL:
if field.positional:
if field.is_positional():
args.append(value)
else:
kwargs[
field.name if field.name_override is None else field.name_override
] = value

unwrapped_f = _resolver.unwrap_origin(f)
unwrapped_f = _resolver.unwrap_origin_strip_extras(f)
unwrapped_f = list if unwrapped_f is Sequence else unwrapped_f # type: ignore
unwrapped_f = _resolver.narrow_type(unwrapped_f, default_instance)
if unwrapped_f in (tuple, list, set):
Expand Down
Loading

0 comments on commit 943f93e

Please sign in to comment.