Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 19, 2024
1 parent 3d661fc commit 2a74252
Show file tree
Hide file tree
Showing 9 changed files with 360 additions and 327 deletions.
4 changes: 1 addition & 3 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Expand All @@ -27,8 +26,7 @@
import shtab

from . import _argparse as argparse
from . import _fields, _instantiators, _resolver, _strings
from ._typing import TypeForm
from . import _fields, _instantiators, _strings
from .conf import _markers

if TYPE_CHECKING:
Expand Down
195 changes: 101 additions & 94 deletions src/tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import shtab
from typing_extensions import Literal

from tyro._resolver import TypeParamResolver

from . import _argparse as argparse
from . import (
_argparse_formatter,
Expand Down Expand Up @@ -316,104 +318,109 @@ def _cli_impl(
stacklevel=2,
)

# Internally, we distinguish between two concepts:
# - "default", which is used for individual arguments.
# - "default_instance", which is used for _fields_ (which may be broken down into
# one or many arguments, depending on various factors).
#
# This could be revisited.
default_instance_internal: Union[_fields.NonpropagatingMissingType, OutT] = (
_fields.MISSING_NONPROP if default is None else default
)

# We wrap our type with a dummy dataclass if it can't be treated as a nested type.
# For example: passing in f=int will result in a dataclass with a single field
# typed as int.
if not _fields.is_nested_type(cast(type, f), default_instance_internal):
dummy_field = cast(
dataclasses.Field,
dataclasses.field(),
resolve_context = TypeParamResolver.get_assignment_context(f)
with resolve_context:
f = resolve_context.origin_type

# Internally, we distinguish between two concepts:
# - "default", which is used for individual arguments.
# - "default_instance", which is used for _fields_ (which may be broken down into
# one or many arguments, depending on various factors).
#
# This could be revisited.
default_instance_internal: Union[_fields.NonpropagatingMissingType, OutT] = (
_fields.MISSING_NONPROP if default is None else default
)
f = dataclasses.make_dataclass(
cls_name="dummy",
fields=[(_strings.dummy_field_name, cast(type, f), dummy_field)],
frozen=True,
)
default_instance_internal = f(default_instance_internal) # type: ignore
dummy_wrapped = True
else:
dummy_wrapped = False

# Read and fix arguments. If the user passes in --field_name instead of
# --field-name, correct for them.
args = list(sys.argv[1:]) if args is None else list(args)

# Fix arguments. This will modify all option-style arguments replacing
# underscores with hyphens, or vice versa if use_underscores=True.
# If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error.
modified_args: Dict[str, str] = {}
for index, arg in enumerate(args):
if not arg.startswith("--"):
continue

if "=" in arg:
arg, _, val = arg.partition("=")
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) + "=" + val

# We wrap our type with a dummy dataclass if it can't be treated as a nested type.
# For example: passing in f=int will result in a dataclass with a single field
# typed as int.
if not _fields.is_nested_type(cast(type, f), default_instance_internal):
dummy_field = cast(
dataclasses.Field,
dataclasses.field(),
)
f = dataclasses.make_dataclass(
cls_name="dummy",
fields=[(_strings.dummy_field_name, cast(type, f), dummy_field)],
frozen=True,
)
default_instance_internal = f(default_instance_internal) # type: ignore
dummy_wrapped = True
else:
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:])
if (
return_unknown_args
and fixed in modified_args
and modified_args[fixed] != arg
):
raise RuntimeError(
"Ambiguous arguments: " + modified_args[fixed] + " and " + arg
dummy_wrapped = False

# Read and fix arguments. If the user passes in --field_name instead of
# --field-name, correct for them.
args = list(sys.argv[1:]) if args is None else list(args)

# Fix arguments. This will modify all option-style arguments replacing
# underscores with hyphens, or vice versa if use_underscores=True.
# If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error.
modified_args: Dict[str, str] = {}
for index, arg in enumerate(args):
if not arg.startswith("--"):
continue

if "=" in arg:
arg, _, val = arg.partition("=")
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:]) + "=" + val
else:
fixed = "--" + _strings.replace_delimeter_in_part(arg[2:])
if (
return_unknown_args
and fixed in modified_args
and modified_args[fixed] != arg
):
raise RuntimeError(
"Ambiguous arguments: " + modified_args[fixed] + " and " + arg
)
modified_args[fixed] = arg
args[index] = fixed

# If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn
# formatting tags, and get the shell we want to generate a completion script for
# (bash/zsh/tcsh).
#
# shtab also offers an add_argument_to() functions that fulfills a similar goal, but
# manual parsing of argv is convenient for turning off formatting.
#
# Note: --tyro-print-completion is deprecated! --tyro-write-completion is less prone
# to errors from accidental logging, print statements, etc.
print_completion = False
write_completion = False
if len(args) >= 2:
# We replace underscores with hyphens to accomodate for `use_undercores`.
print_completion = args[0].replace("_", "-") == "--tyro-print-completion"
write_completion = (
len(args) >= 3
and args[0].replace("_", "-") == "--tyro-write-completion"
)
modified_args[fixed] = arg
args[index] = fixed

# If we pass in the --tyro-print-completion or --tyro-write-completion flags: turn
# formatting tags, and get the shell we want to generate a completion script for
# (bash/zsh/tcsh).
#
# shtab also offers an add_argument_to() functions that fulfills a similar goal, but
# manual parsing of argv is convenient for turning off formatting.
#
# Note: --tyro-print-completion is deprecated! --tyro-write-completion is less prone
# to errors from accidental logging, print statements, etc.
print_completion = False
write_completion = False
if len(args) >= 2:
# We replace underscores with hyphens to accomodate for `use_undercores`.
print_completion = args[0].replace("_", "-") == "--tyro-print-completion"
write_completion = (
len(args) >= 3 and args[0].replace("_", "-") == "--tyro-write-completion"
)

# Note: setting USE_RICH must happen before the parser specification is generated.
# TODO: revisit this. Ideally we should be able to eliminate the global state
# changes.
completion_shell = None
completion_target_path = None
if print_completion or write_completion:
completion_shell = args[1]
if write_completion:
completion_target_path = pathlib.Path(args[2])
if print_completion or write_completion or return_parser:
_arguments.USE_RICH = False
else:
_arguments.USE_RICH = True

# Map a callable to the relevant CLI arguments + subparsers.
parser_spec = _parsers.ParserSpecification.from_callable_or_type(
f,
description=description,
parent_classes=set(), # Used for recursive calls.
default_instance=default_instance_internal, # Overrides for default values.
intern_prefix="", # Used for recursive calls.
extern_prefix="", # Used for recursive calls.
subcommand_prefix="", # Used for recursive calls.
)
# Note: setting USE_RICH must happen before the parser specification is generated.
# TODO: revisit this. Ideally we should be able to eliminate the global state
# changes.
completion_shell = None
completion_target_path = None
if print_completion or write_completion:
completion_shell = args[1]
if write_completion:
completion_target_path = pathlib.Path(args[2])
if print_completion or write_completion or return_parser:
_arguments.USE_RICH = False
else:
_arguments.USE_RICH = True

# Map a callable to the relevant CLI arguments + subparsers.
parser_spec = _parsers.ParserSpecification.from_callable_or_type(
f,
description=description,
parent_classes=set(), # Used for recursive calls.
default_instance=default_instance_internal, # Overrides for default values.
intern_prefix="", # Used for recursive calls.
extern_prefix="", # Used for recursive calls.
subcommand_prefix="", # Used for recursive calls.
)

# Generate parser!
with _argparse_formatter.ansi_context():
Expand Down
1 change: 1 addition & 0 deletions src/tyro/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def get_callable_description(f: Callable) -> str:
the fields of the class if a docstring is not specified; this helper will ignore
these docstrings."""

f, _ = _resolver.resolve_generic_types(f)
f = _resolver.unwrap_origin_strip_extras(f)
if f in _callable_description_blocklist:
return ""
Expand Down
Loading

0 comments on commit 2a74252

Please sign in to comment.