Skip to content

Commit

Permalink
Start refactor (broken)
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 19, 2024
1 parent 89ff57a commit 3d661fc
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 222 deletions.
2 changes: 1 addition & 1 deletion examples/01_basics/07_unions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Color(enum.Enum):
@dataclasses.dataclass(frozen=True)
class Args:
# Unions can be used to specify multiple allowable types.
union_over_types: int | str = 0
union_over_types: int | str
string_or_enum: Literal["red", "green"] | Color = "red"

# Unions also work over more complex nested types.
Expand Down
9 changes: 1 addition & 8 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ class ArgumentDefinition:
extern_prefix: str # User-facing prefix.
subcommand_prefix: str # Prefix for nesting.
field: _fields.FieldDefinition
type_from_typevar: Dict[TypeVar, TypeForm[Any]]

def add_argument(
self, parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup]
Expand Down Expand Up @@ -254,12 +253,7 @@ def _rule_handle_boolean_flags(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
) -> None:
if (
_resolver.apply_type_from_typevar(
arg.field.type_or_callable, arg.type_from_typevar
)
is not bool
):
if arg.field.type_or_callable is not bool:
return

if (
Expand Down Expand Up @@ -305,7 +299,6 @@ def _rule_recursive_instantiator_from_type(
try:
instantiator, metadata = _instantiators.instantiator_from_type(
arg.field.type_or_callable,
arg.type_from_typevar,
arg.field.markers,
)
except _instantiators.UnsupportedTypeAnnotationError as e:
Expand Down
1 change: 0 additions & 1 deletion src/tyro/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def get_callable_description(f: Callable) -> str:
the fields of the class if a docstring is not specified; this helper will ignore
these docstrings."""

f, _unused = _resolver.resolve_generic_types(f)
f = _resolver.unwrap_origin_strip_extras(f)
if f in _callable_description_blocklist:
return ""
Expand Down
93 changes: 41 additions & 52 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@
is_typeddict,
)

from . import conf # Avoid circular import.
from . import (
_docstrings,
_instantiators,
_resolver,
_singleton,
_strings,
_unsafe_cache,
conf, # Avoid circular import.
)
from ._typing import TypeForm
from .conf import _confstruct, _markers
Expand Down Expand Up @@ -109,9 +109,6 @@ def make(
*,
markers: Tuple[_markers.Marker, ...] = (),
):
# Resolve generic aliases.
type_or_callable = _resolver.apply_type_from_typevar(type_or_callable, {})

# Try to extract argconf overrides from type.
_, argconfs = _resolver.unwrap_annotated_and_aliases(
type_or_callable, _confstruct._ArgConfiguration
Expand Down Expand Up @@ -147,6 +144,38 @@ def make(
for context_markers in _field_context_markers:
markers += context_markers

# Type resolution.
type_or_callable = _resolver.type_from_typevar_constraints(type_or_callable)
type_or_callable = _resolver.narrow_collection_types(type_or_callable, default)
type_or_callable = _resolver.narrow_union_type(type_or_callable, default)
type_or_callable = _resolver.apply_type_shims(type_or_callable)

# Check that the default value matches the final resolved type.
# There's some similar Union-specific logic for this in narrow_union_type(). We
# may be able to consolidate this.
if (
# Be relatively conservative: isinstance() can be checked on non-type
# types (like unions in Python >=3.10), but we'll only consider single types
# for now.
type(type_or_callable) is type
and not isinstance(default, type_or_callable) # type: ignore
# If a custom constructor is set, type_or_callable may not be
# matched to the annotated type.
and argconf.constructor_factory is None
and default not in DEFAULT_SENTINEL_SINGLETONS
# The numeric tower in Python is wacky. This logic is non-critical, so
# we'll just skip it (+the complexity) for numbers.
and not isinstance(default, numbers.Number)
):
# If the default value doesn't match the resolved type, we expand the
# type. This is inspired by https://github.com/brentyi/tyro/issues/88.
warnings.warn(
f"The field {name} is annotated with type {type_or_callable}, "
f"but the default value {default} has type {type(default)}. "
f"We'll try to handle this gracefully, but it may cause unexpected behavior."
)
type_or_callable = Union[type_or_callable, type(default)] # type: ignore

out = FieldDefinition(
intern_name=name,
extern_name=name if argconf.name is None else argconf.name,
Expand Down Expand Up @@ -271,9 +300,7 @@ def field_list_from_callable(
f: Union[Callable, TypeForm[Any]],
default_instance: DefaultInstance,
support_single_arg_types: bool,
) -> Tuple[
Union[Callable, TypeForm[Any]], Dict[TypeVar, TypeForm], List[FieldDefinition]
]:
) -> Tuple[Union[Callable, TypeForm[Any]], List[FieldDefinition]]:
"""Generate a list of generic 'field' objects corresponding to the inputs of some
annotated callable.
Expand All @@ -283,7 +310,6 @@ def field_list_from_callable(
A list of field definitions.
"""
# Resolve generic types.
f, type_from_typevar = _resolver.resolve_generic_types(f)
f = _resolver.unwrap_newtype_and_narrow_subtypes(f, default_instance)

# Try to generate field list.
Expand All @@ -296,7 +322,6 @@ def field_list_from_callable(
if support_single_arg_types:
return (
f,
type_from_typevar,
[
FieldDefinition(
intern_name="value",
Expand All @@ -317,46 +342,7 @@ def field_list_from_callable(
else:
raise _instantiators.UnsupportedTypeAnnotationError(field_list.message)

# Try to resolve types in our list of fields.
def resolve(field: FieldDefinition) -> FieldDefinition:
typ = field.type_or_callable
typ = _resolver.apply_type_from_typevar(typ, type_from_typevar)
typ = _resolver.type_from_typevar_constraints(typ)
typ = _resolver.narrow_collection_types(typ, field.default)
typ = _resolver.narrow_union_type(typ, field.default)

# Check that the default value matches the final resolved type.
# There's some similar Union-specific logic for this in narrow_union_type(). We
# may be able to consolidate this.
if (
# Be relatively conservative: isinstance() can be checked on non-type
# types (like unions in Python >=3.10), but we'll only consider single types
# for now.
type(typ) is type
and not isinstance(field.default, typ) # type: ignore
# If a custom constructor is set, field.type_or_callable may not be
# matched to the annotated type.
and not field.custom_constructor
and field.default not in DEFAULT_SENTINEL_SINGLETONS
# The numeric tower in Python is wacky. This logic is non-critical, so
# we'll just skip it (+the complexity) for numbers.
and not isinstance(field.default, numbers.Number)
):
# If the default value doesn't match the resolved type, we expand the
# type. This is inspired by https://github.com/brentyi/tyro/issues/88.
warnings.warn(
f"The field {field.intern_name} is annotated with type {field.type_or_callable}, "
f"but the default value {field.default} has type {type(field.default)}. "
f"We'll try to handle this gracefully, but it may cause unexpected behavior."
)
typ = Union[typ, type(field.default)] # type: ignore

field = dataclasses.replace(field, type_or_callable=typ)
return field

field_list = list(map(resolve, field_list))

return f, type_from_typevar, field_list
return f, field_list


# Implementation details below.
Expand Down Expand Up @@ -396,8 +382,6 @@ def _try_field_list_from_callable(
default_instance = found_subcommand_configs[0].default

# Unwrap generics.
f, type_from_typevar = _resolver.resolve_generic_types(f)
f = _resolver.apply_type_from_typevar(f, type_from_typevar)
f = _resolver.unwrap_newtype_and_narrow_subtypes(f, default_instance)
f = _resolver.narrow_collection_types(f, default_instance)
f_origin = _resolver.unwrap_origin_strip_extras(cast(TypeForm, f))
Expand Down Expand Up @@ -445,7 +429,12 @@ def _try_field_list_from_callable(
dict,
):
return _field_list_from_dict(f, default_instance)
elif f_origin in (list, set, typing.Sequence, collections.abc.Sequence) or cls in (
elif f_origin in (
list,
set,
typing.Sequence,
collections.abc.Sequence,
) or cls in (
list,
set,
typing.Sequence,
Expand Down
38 changes: 8 additions & 30 deletions src/tyro/_instantiators.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,6 @@ def is_type_string_converter(typ: Union[Callable, TypeForm[Any]]) -> bool:
# Some checks we can do if the signature is available!
for i, param in enumerate(signature.parameters.values()):
annotation = type_annotations.get(param.name, param.annotation)

# Hack: apply_type_from_typevar applies shims, like UnionType => Union
# conversion.
annotation = _resolver.apply_type_from_typevar(annotation, {})
if i == 0 and not (
(get_origin(annotation) is Union and str in get_args(annotation))
or annotation in (str, inspect.Parameter.empty)
Expand All @@ -153,7 +149,6 @@ def is_type_string_converter(typ: Union[Callable, TypeForm[Any]]) -> bool:

def instantiator_from_type(
typ: Union[TypeForm[Any], Callable],
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]:
"""Recursive helper for parsing type annotations.
Expand Down Expand Up @@ -201,9 +196,7 @@ def instantiator(strings: List[str]) -> None:

# Address container types. If a matching container is found, this will recursively
# call instantiator_from_type().
container_out = _instantiator_from_container_type(
cast(TypeForm[Any], typ), type_from_typevar, markers
)
container_out = _instantiator_from_container_type(cast(TypeForm[Any], typ), markers)
if container_out is not None:
return container_out

Expand Down Expand Up @@ -335,7 +328,6 @@ def instantiator_base_case(strings: List[str]) -> Any:
@overload
def _instantiator_from_type_inner(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
allow_sequences: Literal["fixed_length"],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]: ...
Expand All @@ -344,7 +336,6 @@ def _instantiator_from_type_inner(
@overload
def _instantiator_from_type_inner(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
allow_sequences: Literal[False],
markers: Set[_markers.Marker],
) -> Tuple[_StandardInstantiator, InstantiatorMetadata]: ...
Expand All @@ -353,21 +344,19 @@ def _instantiator_from_type_inner(
@overload
def _instantiator_from_type_inner(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
allow_sequences: Literal[True],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]: ...


def _instantiator_from_type_inner(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
allow_sequences: Literal["fixed_length", True, False],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]:
"""Thin wrapper over instantiator_from_type, with some extra asserts for catching
errors."""
out = instantiator_from_type(typ, type_from_typevar, markers)
out = instantiator_from_type(typ, markers)
if out[1].nargs == "*":
# We currently only use allow_sequences=False for options in Literal types,
# which are evaluated using `type()`. It should not be possible to hit this
Expand All @@ -384,7 +373,6 @@ def _instantiator_from_type_inner(

def _instantiator_from_container_type(
typ: TypeForm[Any],
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
markers: Set[_markers.Marker],
) -> Optional[Tuple[Instantiator, InstantiatorMetadata]]:
"""Attempt to create an instantiator from a container type. Returns `None` if no
Expand Down Expand Up @@ -418,13 +406,12 @@ def _instantiator_from_container_type(
_instantiator_from_literal: (Literal, LiteralAlternate),
}.items():
if type_origin in matched_origins:
return make(typ, type_from_typevar, markers)
return make(typ, markers)
return None


def _instantiator_from_tuple(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]:
types = get_args(typ)
Expand All @@ -435,15 +422,15 @@ def _instantiator_from_tuple(
# Ellipsis: variable argument counts. When an ellipsis is used, tuples must
# contain only one type.
assert len(typeset_no_ellipsis) == 1
return _instantiator_from_sequence(typ, type_from_typevar, markers)
return _instantiator_from_sequence(typ, markers)

else:
instantiators: List[_StandardInstantiator] = []
metas: List[InstantiatorMetadata] = []
nargs = 0
for t in types:
a, b = _instantiator_from_type_inner(
t, type_from_typevar, allow_sequences="fixed_length", markers=markers
t, allow_sequences="fixed_length", markers=markers
)
instantiators.append(a) # type: ignore
metas.append(b)
Expand Down Expand Up @@ -509,7 +496,6 @@ def _join_union_metavars(metavars: Iterable[str]) -> str:

def _instantiator_from_union(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]:
options = list(get_args(typ))
Expand All @@ -528,9 +514,7 @@ def _instantiator_from_union(
nargs: Optional[Union[int, Literal["*"]]] = 1
first = True
for t in options:
a, b = _instantiator_from_type_inner(
t, type_from_typevar, allow_sequences=True, markers=markers
)
a, b = _instantiator_from_type_inner(t, allow_sequences=True, markers=markers)
instantiators.append(a)
metas.append(b)
if b.choices is None:
Expand Down Expand Up @@ -591,18 +575,16 @@ def union_instantiator(strings: List[str]) -> Any:

def _instantiator_from_dict(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]:
key_type, val_type = get_args(typ)
key_instantiator, key_meta = _instantiator_from_type_inner(
key_type, type_from_typevar, allow_sequences="fixed_length", markers=markers
key_type, allow_sequences="fixed_length", markers=markers
)

if _markers.UseAppendAction in markers:
val_instantiator, val_meta = _instantiator_from_type_inner(
val_type,
type_from_typevar,
allow_sequences=True,
markers=markers - {_markers.UseAppendAction},
)
Expand All @@ -625,7 +607,7 @@ def append_dict_instantiator(strings: List[List[str]]) -> Any:
)
else:
val_instantiator, val_meta = _instantiator_from_type_inner(
val_type, type_from_typevar, allow_sequences="fixed_length", markers=markers
val_type, allow_sequences="fixed_length", markers=markers
)
pair_metavar = f"{key_meta.metavar} {val_meta.metavar}"
key_nargs = cast(int, key_meta.nargs) # Casts needed for mypy but not pyright!
Expand Down Expand Up @@ -673,7 +655,6 @@ def dict_instantiator(strings: List[str]) -> Any:

def _instantiator_from_sequence(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
markers: Set[_markers.Marker],
) -> Tuple[Instantiator, InstantiatorMetadata]:
"""Instantiator for variable-length sequences: list, sets, Tuple[T, ...], etc."""
Expand All @@ -691,7 +672,6 @@ def _instantiator_from_sequence(
if _markers.UseAppendAction in markers:
make, inner_meta = _instantiator_from_type_inner(
contained_type,
type_from_typevar,
allow_sequences=True,
markers=markers - {_markers.UseAppendAction},
)
Expand All @@ -709,7 +689,6 @@ def append_sequence_instantiator(strings: List[List[str]]) -> Any:
else:
make, inner_meta = _instantiator_from_type_inner(
contained_type,
type_from_typevar,
allow_sequences="fixed_length",
markers=markers,
)
Expand Down Expand Up @@ -743,7 +722,6 @@ def sequence_instantiator(strings: List[str]) -> Any:

def _instantiator_from_literal(
typ: TypeForm,
type_from_typevar: Dict[TypeVar, TypeForm[Any]],
markers: Set[_markers.Marker],
) -> Tuple[_StandardInstantiator, InstantiatorMetadata]:
choices = get_args(typ)
Expand Down
Loading

0 comments on commit 3d661fc

Please sign in to comment.