diff --git a/examples/01_basics/07_unions.py b/examples/01_basics/07_unions.py index f99ffb04..a0999bba 100644 --- a/examples/01_basics/07_unions.py +++ b/examples/01_basics/07_unions.py @@ -23,7 +23,7 @@ class Color(enum.Enum): @dataclasses.dataclass(frozen=True) class Args: # Unions can be used to specify multiple allowable types. - union_over_types: int | str = 0 + union_over_types: int | str string_or_enum: Literal["red", "green"] | Color = "red" # Unions also work over more complex nested types. diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py index 40bce674..99204863 100644 --- a/src/tyro/_arguments.py +++ b/src/tyro/_arguments.py @@ -108,7 +108,6 @@ class ArgumentDefinition: extern_prefix: str # User-facing prefix. subcommand_prefix: str # Prefix for nesting. field: _fields.FieldDefinition - type_from_typevar: Dict[TypeVar, TypeForm[Any]] def add_argument( self, parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup] @@ -254,12 +253,7 @@ def _rule_handle_boolean_flags( arg: ArgumentDefinition, lowered: LoweredArgumentDefinition, ) -> None: - if ( - _resolver.apply_type_from_typevar( - arg.field.type_or_callable, arg.type_from_typevar - ) - is not bool - ): + if arg.field.type_or_callable is not bool: return if ( @@ -305,7 +299,6 @@ def _rule_recursive_instantiator_from_type( try: instantiator, metadata = _instantiators.instantiator_from_type( arg.field.type_or_callable, - arg.type_from_typevar, arg.field.markers, ) except _instantiators.UnsupportedTypeAnnotationError as e: diff --git a/src/tyro/_docstrings.py b/src/tyro/_docstrings.py index 3d461927..dd10168e 100644 --- a/src/tyro/_docstrings.py +++ b/src/tyro/_docstrings.py @@ -301,7 +301,6 @@ def get_callable_description(f: Callable) -> str: the fields of the class if a docstring is not specified; this helper will ignore these docstrings.""" - f, _unused = _resolver.resolve_generic_types(f) f = _resolver.unwrap_origin_strip_extras(f) if f in _callable_description_blocklist: return "" diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index a9e59b14..bb098b1f 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -44,6 +44,7 @@ is_typeddict, ) +from . import conf # Avoid circular import. from . import ( _docstrings, _instantiators, @@ -51,7 +52,6 @@ _singleton, _strings, _unsafe_cache, - conf, # Avoid circular import. ) from ._typing import TypeForm from .conf import _confstruct, _markers @@ -109,9 +109,6 @@ def make( *, markers: Tuple[_markers.Marker, ...] = (), ): - # Resolve generic aliases. - type_or_callable = _resolver.apply_type_from_typevar(type_or_callable, {}) - # Try to extract argconf overrides from type. _, argconfs = _resolver.unwrap_annotated_and_aliases( type_or_callable, _confstruct._ArgConfiguration @@ -147,6 +144,38 @@ def make( for context_markers in _field_context_markers: markers += context_markers + # Type resolution. + type_or_callable = _resolver.type_from_typevar_constraints(type_or_callable) + type_or_callable = _resolver.narrow_collection_types(type_or_callable, default) + type_or_callable = _resolver.narrow_union_type(type_or_callable, default) + type_or_callable = _resolver.apply_type_shims(type_or_callable) + + # Check that the default value matches the final resolved type. + # There's some similar Union-specific logic for this in narrow_union_type(). We + # may be able to consolidate this. + if ( + # Be relatively conservative: isinstance() can be checked on non-type + # types (like unions in Python >=3.10), but we'll only consider single types + # for now. + type(type_or_callable) is type + and not isinstance(default, type_or_callable) # type: ignore + # If a custom constructor is set, type_or_callable may not be + # matched to the annotated type. + and argconf.constructor_factory is None + and default not in DEFAULT_SENTINEL_SINGLETONS + # The numeric tower in Python is wacky. This logic is non-critical, so + # we'll just skip it (+the complexity) for numbers. + and not isinstance(default, numbers.Number) + ): + # If the default value doesn't match the resolved type, we expand the + # type. This is inspired by https://github.com/brentyi/tyro/issues/88. + warnings.warn( + f"The field {name} is annotated with type {type_or_callable}, " + f"but the default value {default} has type {type(default)}. " + f"We'll try to handle this gracefully, but it may cause unexpected behavior." + ) + type_or_callable = Union[type_or_callable, type(default)] # type: ignore + out = FieldDefinition( intern_name=name, extern_name=name if argconf.name is None else argconf.name, @@ -271,9 +300,7 @@ def field_list_from_callable( f: Union[Callable, TypeForm[Any]], default_instance: DefaultInstance, support_single_arg_types: bool, -) -> Tuple[ - Union[Callable, TypeForm[Any]], Dict[TypeVar, TypeForm], List[FieldDefinition] -]: +) -> Tuple[Union[Callable, TypeForm[Any]], List[FieldDefinition]]: """Generate a list of generic 'field' objects corresponding to the inputs of some annotated callable. @@ -283,7 +310,6 @@ def field_list_from_callable( A list of field definitions. """ # Resolve generic types. - f, type_from_typevar = _resolver.resolve_generic_types(f) f = _resolver.unwrap_newtype_and_narrow_subtypes(f, default_instance) # Try to generate field list. @@ -296,7 +322,6 @@ def field_list_from_callable( if support_single_arg_types: return ( f, - type_from_typevar, [ FieldDefinition( intern_name="value", @@ -317,46 +342,7 @@ def field_list_from_callable( else: raise _instantiators.UnsupportedTypeAnnotationError(field_list.message) - # Try to resolve types in our list of fields. - def resolve(field: FieldDefinition) -> FieldDefinition: - typ = field.type_or_callable - typ = _resolver.apply_type_from_typevar(typ, type_from_typevar) - typ = _resolver.type_from_typevar_constraints(typ) - typ = _resolver.narrow_collection_types(typ, field.default) - typ = _resolver.narrow_union_type(typ, field.default) - - # Check that the default value matches the final resolved type. - # There's some similar Union-specific logic for this in narrow_union_type(). We - # may be able to consolidate this. - if ( - # Be relatively conservative: isinstance() can be checked on non-type - # types (like unions in Python >=3.10), but we'll only consider single types - # for now. - type(typ) is type - and not isinstance(field.default, typ) # type: ignore - # If a custom constructor is set, field.type_or_callable may not be - # matched to the annotated type. - and not field.custom_constructor - and field.default not in DEFAULT_SENTINEL_SINGLETONS - # The numeric tower in Python is wacky. This logic is non-critical, so - # we'll just skip it (+the complexity) for numbers. - and not isinstance(field.default, numbers.Number) - ): - # If the default value doesn't match the resolved type, we expand the - # type. This is inspired by https://github.com/brentyi/tyro/issues/88. - warnings.warn( - f"The field {field.intern_name} is annotated with type {field.type_or_callable}, " - f"but the default value {field.default} has type {type(field.default)}. " - f"We'll try to handle this gracefully, but it may cause unexpected behavior." - ) - typ = Union[typ, type(field.default)] # type: ignore - - field = dataclasses.replace(field, type_or_callable=typ) - return field - - field_list = list(map(resolve, field_list)) - - return f, type_from_typevar, field_list + return f, field_list # Implementation details below. @@ -396,8 +382,6 @@ def _try_field_list_from_callable( default_instance = found_subcommand_configs[0].default # Unwrap generics. - f, type_from_typevar = _resolver.resolve_generic_types(f) - f = _resolver.apply_type_from_typevar(f, type_from_typevar) f = _resolver.unwrap_newtype_and_narrow_subtypes(f, default_instance) f = _resolver.narrow_collection_types(f, default_instance) f_origin = _resolver.unwrap_origin_strip_extras(cast(TypeForm, f)) @@ -445,7 +429,12 @@ def _try_field_list_from_callable( dict, ): return _field_list_from_dict(f, default_instance) - elif f_origin in (list, set, typing.Sequence, collections.abc.Sequence) or cls in ( + elif f_origin in ( + list, + set, + typing.Sequence, + collections.abc.Sequence, + ) or cls in ( list, set, typing.Sequence, diff --git a/src/tyro/_instantiators.py b/src/tyro/_instantiators.py index af1b40b7..5b0cb369 100644 --- a/src/tyro/_instantiators.py +++ b/src/tyro/_instantiators.py @@ -129,10 +129,6 @@ def is_type_string_converter(typ: Union[Callable, TypeForm[Any]]) -> bool: # Some checks we can do if the signature is available! for i, param in enumerate(signature.parameters.values()): annotation = type_annotations.get(param.name, param.annotation) - - # Hack: apply_type_from_typevar applies shims, like UnionType => Union - # conversion. - annotation = _resolver.apply_type_from_typevar(annotation, {}) if i == 0 and not ( (get_origin(annotation) is Union and str in get_args(annotation)) or annotation in (str, inspect.Parameter.empty) @@ -153,7 +149,6 @@ def is_type_string_converter(typ: Union[Callable, TypeForm[Any]]) -> bool: def instantiator_from_type( typ: Union[TypeForm[Any], Callable], - type_from_typevar: Dict[TypeVar, TypeForm[Any]], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: """Recursive helper for parsing type annotations. @@ -201,9 +196,7 @@ def instantiator(strings: List[str]) -> None: # Address container types. If a matching container is found, this will recursively # call instantiator_from_type(). - container_out = _instantiator_from_container_type( - cast(TypeForm[Any], typ), type_from_typevar, markers - ) + container_out = _instantiator_from_container_type(cast(TypeForm[Any], typ), markers) if container_out is not None: return container_out @@ -335,7 +328,6 @@ def instantiator_base_case(strings: List[str]) -> Any: @overload def _instantiator_from_type_inner( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], allow_sequences: Literal["fixed_length"], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: ... @@ -344,7 +336,6 @@ def _instantiator_from_type_inner( @overload def _instantiator_from_type_inner( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], allow_sequences: Literal[False], markers: Set[_markers.Marker], ) -> Tuple[_StandardInstantiator, InstantiatorMetadata]: ... @@ -353,7 +344,6 @@ def _instantiator_from_type_inner( @overload def _instantiator_from_type_inner( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], allow_sequences: Literal[True], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: ... @@ -361,13 +351,12 @@ def _instantiator_from_type_inner( def _instantiator_from_type_inner( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], allow_sequences: Literal["fixed_length", True, False], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: """Thin wrapper over instantiator_from_type, with some extra asserts for catching errors.""" - out = instantiator_from_type(typ, type_from_typevar, markers) + out = instantiator_from_type(typ, markers) if out[1].nargs == "*": # We currently only use allow_sequences=False for options in Literal types, # which are evaluated using `type()`. It should not be possible to hit this @@ -384,7 +373,6 @@ def _instantiator_from_type_inner( def _instantiator_from_container_type( typ: TypeForm[Any], - type_from_typevar: Dict[TypeVar, TypeForm[Any]], markers: Set[_markers.Marker], ) -> Optional[Tuple[Instantiator, InstantiatorMetadata]]: """Attempt to create an instantiator from a container type. Returns `None` if no @@ -418,13 +406,12 @@ def _instantiator_from_container_type( _instantiator_from_literal: (Literal, LiteralAlternate), }.items(): if type_origin in matched_origins: - return make(typ, type_from_typevar, markers) + return make(typ, markers) return None def _instantiator_from_tuple( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: types = get_args(typ) @@ -435,7 +422,7 @@ def _instantiator_from_tuple( # Ellipsis: variable argument counts. When an ellipsis is used, tuples must # contain only one type. assert len(typeset_no_ellipsis) == 1 - return _instantiator_from_sequence(typ, type_from_typevar, markers) + return _instantiator_from_sequence(typ, markers) else: instantiators: List[_StandardInstantiator] = [] @@ -443,7 +430,7 @@ def _instantiator_from_tuple( nargs = 0 for t in types: a, b = _instantiator_from_type_inner( - t, type_from_typevar, allow_sequences="fixed_length", markers=markers + t, allow_sequences="fixed_length", markers=markers ) instantiators.append(a) # type: ignore metas.append(b) @@ -509,7 +496,6 @@ def _join_union_metavars(metavars: Iterable[str]) -> str: def _instantiator_from_union( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: options = list(get_args(typ)) @@ -528,9 +514,7 @@ def _instantiator_from_union( nargs: Optional[Union[int, Literal["*"]]] = 1 first = True for t in options: - a, b = _instantiator_from_type_inner( - t, type_from_typevar, allow_sequences=True, markers=markers - ) + a, b = _instantiator_from_type_inner(t, allow_sequences=True, markers=markers) instantiators.append(a) metas.append(b) if b.choices is None: @@ -591,18 +575,16 @@ def union_instantiator(strings: List[str]) -> Any: def _instantiator_from_dict( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: key_type, val_type = get_args(typ) key_instantiator, key_meta = _instantiator_from_type_inner( - key_type, type_from_typevar, allow_sequences="fixed_length", markers=markers + key_type, allow_sequences="fixed_length", markers=markers ) if _markers.UseAppendAction in markers: val_instantiator, val_meta = _instantiator_from_type_inner( val_type, - type_from_typevar, allow_sequences=True, markers=markers - {_markers.UseAppendAction}, ) @@ -625,7 +607,7 @@ def append_dict_instantiator(strings: List[List[str]]) -> Any: ) else: val_instantiator, val_meta = _instantiator_from_type_inner( - val_type, type_from_typevar, allow_sequences="fixed_length", markers=markers + val_type, allow_sequences="fixed_length", markers=markers ) pair_metavar = f"{key_meta.metavar} {val_meta.metavar}" key_nargs = cast(int, key_meta.nargs) # Casts needed for mypy but not pyright! @@ -673,7 +655,6 @@ def dict_instantiator(strings: List[str]) -> Any: def _instantiator_from_sequence( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], markers: Set[_markers.Marker], ) -> Tuple[Instantiator, InstantiatorMetadata]: """Instantiator for variable-length sequences: list, sets, Tuple[T, ...], etc.""" @@ -691,7 +672,6 @@ def _instantiator_from_sequence( if _markers.UseAppendAction in markers: make, inner_meta = _instantiator_from_type_inner( contained_type, - type_from_typevar, allow_sequences=True, markers=markers - {_markers.UseAppendAction}, ) @@ -709,7 +689,6 @@ def append_sequence_instantiator(strings: List[List[str]]) -> Any: else: make, inner_meta = _instantiator_from_type_inner( contained_type, - type_from_typevar, allow_sequences="fixed_length", markers=markers, ) @@ -743,7 +722,6 @@ def sequence_instantiator(strings: List[str]) -> Any: def _instantiator_from_literal( typ: TypeForm, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], markers: Set[_markers.Marker], ) -> Tuple[_StandardInstantiator, InstantiatorMetadata]: choices = get_args(typ) diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 47b46147..59364049 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -80,7 +80,7 @@ def from_callable_or_type( # Resolve the type of `f`, generate a field list. with _fields.FieldDefinition.marker_context(markers): - f, type_from_typevar, field_list = _fields.field_list_from_callable( + f, field_list = _fields.field_list_from_callable( f=f, default_instance=default_instance, support_single_arg_types=support_single_arg_types, @@ -112,7 +112,6 @@ def from_callable_or_type( for field in field_list: field_out = handle_field( field, - type_from_typevar=type_from_typevar, parent_classes=parent_classes, intern_prefix=intern_prefix, extern_prefix=extern_prefix, @@ -287,7 +286,6 @@ def format_group_name(prefix: str) -> str: def handle_field( field: _fields.FieldDefinition, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], parent_classes: Set[Type[Any]], intern_prefix: str, extern_prefix: str, @@ -308,7 +306,6 @@ def handle_field( # (1) Handle Unions over callables; these result in subparsers. subparsers_attempt = SubparsersSpecification.from_field( field, - type_from_typevar=type_from_typevar, parent_classes=parent_classes, intern_prefix=_strings.make_field_name([intern_prefix, field.intern_name]), extern_prefix=_strings.make_field_name([extern_prefix, field.extern_name]), @@ -362,7 +359,6 @@ def handle_field( extern_prefix=extern_prefix, subcommand_prefix=subcommand_prefix, field=field, - type_from_typevar=type_from_typevar, ) @@ -381,7 +377,6 @@ class SubparsersSpecification: @staticmethod def from_field( field: _fields.FieldDefinition, - type_from_typevar: Dict[TypeVar, TypeForm[Any]], parent_classes: Set[Type[Any]], intern_prefix: str, extern_prefix: str, @@ -393,10 +388,7 @@ def from_field( # We don't use sets here to retain order of subcommands. options: List[Union[type, Callable]] - options = [ - _resolver.apply_type_from_typevar(typ, type_from_typevar) - for typ in get_args(typ) - ] + options = [typ for typ in get_args(typ)] options = [ ( # Cast seems unnecessary but needed in mypy... (1.4.1) diff --git a/src/tyro/_resolver.py b/src/tyro/_resolver.py index b1b76467..6bf07b58 100644 --- a/src/tyro/_resolver.py +++ b/src/tyro/_resolver.py @@ -1,6 +1,7 @@ """Utilities for resolving types and forward references.""" import collections.abc +import contextlib import copy import dataclasses import inspect @@ -61,63 +62,6 @@ def is_dataclass(cls: Union[TypeForm, Callable]) -> bool: return dataclasses.is_dataclass(unwrap_origin_strip_extras(cls)) # type: ignore -def resolve_generic_types( - cls: TypeOrCallable, -) -> Tuple[TypeOrCallable, Dict[TypeVar, TypeForm[Any]]]: - """If the input is a class: no-op. If it's a generic alias: returns the origin - class, and a mapping from typevars to concrete types.""" - - annotations: Tuple[Any, ...] = () - if get_origin(cls) is Annotated: - # ^We need this `if` statement for an obscure edge case: when `cls` is a - # function with `__tyro_markers__` set, we don't want/need to return - # Annotated[func, markers]. - cls, annotations = unwrap_annotated_and_aliases(cls, "all") - - # We'll ignore NewType when getting the origin + args for generics. - origin_cls = get_origin(unwrap_newtype_and_aliases(cls)[0]) - type_from_typevar: Dict[TypeVar, TypeForm[Any]] = {} - - # Support typing.Self. - # We'll do this by pretending that `Self` is a TypeVar... - if hasattr(cls, "__self__"): - self_type = getattr(cls, "__self__") - if inspect.isclass(self_type): - type_from_typevar[cast(TypeVar, Self)] = self_type # type: ignore - else: - type_from_typevar[cast(TypeVar, Self)] = self_type.__class__ # type: ignore - - if ( - # Apply some heuristics for generic types. Should revisit this. - origin_cls is not None - and hasattr(origin_cls, "__parameters__") - and hasattr(origin_cls.__parameters__, "__len__") - ): - typevars = origin_cls.__parameters__ - typevar_values = get_args(unwrap_newtype_and_aliases(cls)[0]) - assert len(typevars) == len(typevar_values) - cls = origin_cls - type_from_typevar.update(dict(zip(typevars, typevar_values))) - - if hasattr(cls, "__orig_bases__"): - bases = getattr(cls, "__orig_bases__") - for base in bases: - origin_base = unwrap_origin_strip_extras(base) - if origin_base is base or not hasattr(origin_base, "__parameters__"): - continue - typevars = origin_base.__parameters__ - typevar_values = get_args(base) - type_from_typevar.update(dict(zip(typevars, typevar_values))) - - if len(annotations) == 0: - return cls, type_from_typevar - else: - return ( - Annotated.__class_getitem__((cls, *annotations)), # type: ignore - type_from_typevar, - ) - - @_unsafe_cache.unsafe_cache(maxsize=1024) def resolved_fields(cls: TypeForm) -> List[dataclasses.Field]: """Similar to dataclasses.fields(), but includes dataclasses.InitVar types and @@ -389,63 +333,122 @@ def unwrap_annotated_and_aliases( return args[0], targets # type: ignore -def apply_type_from_typevar( - typ: TypeOrCallable, type_from_typevar: Dict[TypeVar, TypeForm[Any]] -) -> TypeOrCallable: - GenericAlias = getattr(types, "GenericAlias", None) - if ( - GenericAlias is not None - and isinstance(typ, GenericAlias) - and len(getattr(typ, "__type_params__", ())) > 0 - ): - type_from_typevar = type_from_typevar.copy() - for k, v in zip(typ.__type_params__, typ.__args__): # type: ignore - type_from_typevar[k] = v # type: ignore - typ = typ.__value__ # type: ignore +class TypeParameterResolver: + param_assignments: List[Dict[TypeVar, TypeForm[Any]]] = [] - if typ in type_from_typevar: - return type_from_typevar[typ] # type: ignore + @classmethod + @contextlib.contextmanager + def resolve_context(cls, typ: TypeOrCallable): + """Context manager for resolving type parameters.""" - origin = get_origin(typ) - args = get_args(typ) - if len(args) > 0: - if origin is Annotated: - args = args[:1] - if origin is collections.abc.Callable: - assert isinstance(args[0], list) - args = tuple(args[0]) + args[1:] - - # Convert Python 3.9 and 3.10 types to their typing library equivalents, which - # support `.copy_with()`. This is not really the right place for this logic... - if sys.version_info[:2] >= (3, 9): - shim_table = { - # PEP 585. Requires Python 3.9. - tuple: Tuple, - list: List, - dict: Dict, - set: Set, - frozenset: FrozenSet, - type: Type, - } - if hasattr(types, "UnionType"): # type: ignore - # PEP 604. Requires Python 3.10. - shim_table[types.UnionType] = Union # type: ignore - - for new, old in shim_table.items(): - if origin is new: # type: ignore - typ = old.__getitem__(args) # type: ignore - - new_args = tuple(apply_type_from_typevar(x, type_from_typevar) for x in args) - - # Standard generic aliases have a `copy_with()`! - if hasattr(typ, "copy_with"): - return typ.copy_with(new_args) # type: ignore - else: - # `collections` types, like collections.abc.Sequence. - assert hasattr(origin, "__class_getitem__") - return origin.__class_getitem__(new_args) # type: ignore + annotations: Tuple[Any, ...] = () + if get_origin(typ) is Annotated: + # ^We need this if statement for an obscure edge case: when `typ` is a + # function with `__tyro_markers__` set, we don't want/need to return + # Annotated[func, markers]. + typ, annotations = unwrap_annotated_and_aliases(typ, "all") + + # We'll ignore NewType when getting the origin + args for generics. + origin_cls = get_origin(unwrap_newtype_and_aliases(typ)[0]) + type_from_typevar: Dict[TypeVar, TypeForm[Any]] = {} + + # Support typing.Self. + # We'll do this by pretending that `Self` is a TypeVar... + if hasattr(typ, "__self__"): + self_type = getattr(typ, "__self__") + if inspect.isclass(self_type): + type_from_typevar[cast(TypeVar, Self)] = self_type # type: ignore + else: + type_from_typevar[cast(TypeVar, Self)] = self_type.__class__ # type: ignore - return typ # type: ignore + if ( + # Apply some heuristics for generic types. Should revisit this. + origin_cls is not None + and hasattr(origin_cls, "__parameters__") + and hasattr(origin_cls.__parameters__, "__len__") + ): + typevars = origin_cls.__parameters__ + typevar_values = get_args(unwrap_newtype_and_aliases(typ)[0]) + assert len(typevars) == len(typevar_values) + typ = origin_cls + type_from_typevar.update(dict(zip(typevars, typevar_values))) + + if hasattr(typ, "__orig_bases__"): + bases = getattr(typ, "__orig_bases__") + for base in bases: + origin_base = unwrap_origin_strip_extras(base) + if origin_base is base or not hasattr(origin_base, "__parameters__"): + continue + typevars = origin_base.__parameters__ + typevar_values = get_args(base) + type_from_typevar.update(dict(zip(typevars, typevar_values))) + + GenericAlias = getattr(types, "GenericAlias", None) + if ( + GenericAlias is not None + and isinstance(typ, GenericAlias) + and len(getattr(typ, "__type_params__", ())) > 0 + ): + for k, v in zip(typ.__type_params__, typ.__args__): # type: ignore + type_from_typevar[k] = v # type: ignore + typ = typ.__value__ # type: ignore + + cls.param_assignments.append(type_from_typevar) + + # Apply the TypeVar assignments. + typ = TypeParameterResolver.apply_param_assignments(typ) + assert type(typ) is not list + yield ( + typ + if len(annotations) == 0 + else Annotated.__class_getitem__( # type: ignore + ( + typ, + *annotations, + ) + ) + ) + + cls.param_assignments.pop() + + @staticmethod + def apply_param_assignments(typ: TypeOrCallable) -> TypeOrCallable: + for type_from_typevar in reversed(TypeParameterResolver.param_assignments): + if typ in type_from_typevar: + return type_from_typevar[typ] # type: ignore + + origin = get_origin(typ) + args = get_args(typ) + if len(args) > 0: + if origin is Annotated: + args = args[:1] + if origin is collections.abc.Callable: + assert isinstance(args[0], list) + args = tuple(args[0]) + args[1:] + + new_args = [] + for x in args: + for type_from_typevar in reversed( + TypeParameterResolver.param_assignments + ): + if x in type_from_typevar: + x = type_from_typevar[x] + break + new_args.append(x) + + new_args = tuple( + TypeParameterResolver.apply_param_assignments(x) for x in args + ) + + # Standard generic aliases have a `copy_with()`! + if hasattr(typ, "copy_with"): + return typ.copy_with(new_args) # type: ignore + else: + # `collections` types, like collections.abc.Sequence. + assert hasattr(origin, "__class_getitem__") + return origin.__class_getitem__(new_args) # type: ignore + + return typ # type: ignore @_unsafe_cache.unsafe_cache(maxsize=1024) @@ -483,6 +486,94 @@ def narrow_union_type(typ: TypeOrCallable, default_instance: Any) -> TypeOrCalla NoneType = type(None) +def resolve_generic_types( + cls: TypeOrCallable, +) -> Tuple[TypeOrCallable, Dict[TypeVar, TypeForm[Any]]]: + """If the input is a class: no-op. If it's a generic alias: returns the origin + class, and a mapping from typevars to concrete types.""" + + annotations: Tuple[Any, ...] = () + if get_origin(cls) is Annotated: + # ^We need this `if` statement for an obscure edge case: when `cls` is a + # function with `__tyro_markers__` set, we don't want/need to return + # Annotated[func, markers]. + cls, annotations = unwrap_annotated_and_aliases(cls, "all") + + # We'll ignore NewType when getting the origin + args for generics. + origin_cls = get_origin(unwrap_newtype_and_aliases(cls)[0]) + type_from_typevar: Dict[TypeVar, TypeForm[Any]] = {} + + # Support typing.Self. + # We'll do this by pretending that `Self` is a TypeVar... + if hasattr(cls, "__self__"): + self_type = getattr(cls, "__self__") + if inspect.isclass(self_type): + type_from_typevar[cast(TypeVar, Self)] = self_type # type: ignore + else: + type_from_typevar[cast(TypeVar, Self)] = self_type.__class__ # type: ignore + + if ( + # Apply some heuristics for generic types. Should revisit this. + origin_cls is not None + and hasattr(origin_cls, "__parameters__") + and hasattr(origin_cls.__parameters__, "__len__") + ): + typevars = origin_cls.__parameters__ + typevar_values = get_args(unwrap_newtype_and_aliases(cls)[0]) + assert len(typevars) == len(typevar_values) + cls = origin_cls + type_from_typevar.update(dict(zip(typevars, typevar_values))) + + if hasattr(cls, "__orig_bases__"): + bases = getattr(cls, "__orig_bases__") + for base in bases: + origin_base = unwrap_origin_strip_extras(base) + if origin_base is base or not hasattr(origin_base, "__parameters__"): + continue + typevars = origin_base.__parameters__ + typevar_values = get_args(base) + type_from_typevar.update(dict(zip(typevars, typevar_values))) + + if len(annotations) == 0: + return cls, type_from_typevar + else: + return ( + Annotated.__class_getitem__((cls, *annotations)), # type: ignore + type_from_typevar, + ) + + +def apply_type_shims(typ: TypeOrCallable) -> TypeOrCallable: + """Apply shims to types to support older Python versions.""" + origin = get_origin(typ) + args = get_args(typ) + + if origin is None or len(args) == 0: + return typ + + # Convert Python 3.9 and 3.10 types to their typing library equivalents, which + # support `.copy_with()`. This is not really the right place for this logic... + if sys.version_info[:2] >= (3, 9): + shim_table = { + # PEP 585. Requires Python 3.9. + tuple: Tuple, + list: List, + dict: Dict, + set: Set, + frozenset: FrozenSet, + type: Type, + } + if hasattr(types, "UnionType"): # type: ignore + # PEP 604. Requires Python 3.10. + shim_table[types.UnionType] = Union # type: ignore + + for new, old in shim_table.items(): + if origin is new: # type: ignore + typ = old.__getitem__(args) # type: ignore + continue + return typ + + def get_type_hints_with_backported_syntax( obj: Callable[..., Any], include_extras: bool = False ) -> Dict[str, Any]: diff --git a/src/tyro/_subcommand_matching.py b/src/tyro/_subcommand_matching.py index ff21e481..02daccc5 100644 --- a/src/tyro/_subcommand_matching.py +++ b/src/tyro/_subcommand_matching.py @@ -72,7 +72,7 @@ def make( ) -> _TypeTree: """From an object instance, return a data structure representing the types in the object.""" try: - typ, _type_from_typevar, field_list = _fields.field_list_from_callable( + typ, field_list = _fields.field_list_from_callable( typ, default_instance=default_instance, support_single_arg_types=False ) except _instantiators.UnsupportedTypeAnnotationError: diff --git a/src/tyro/extras/_serialization.py b/src/tyro/extras/_serialization.py index ae77abb4..7d70cf55 100644 --- a/src/tyro/extras/_serialization.py +++ b/src/tyro/extras/_serialization.py @@ -36,7 +36,6 @@ def _get_contained_special_types_from_type( ) cls = _resolver.unwrap_annotated_and_aliases(cls) - cls, type_from_typevar = _resolver.resolve_generic_types(cls) contained_special_types = {cls} @@ -56,13 +55,6 @@ def handle_type(typ: Type[Any]) -> Set[Type[Any]]: # Handle Union, Annotated, List, etc. No-op when there are no args. return functools.reduce(set.union, map(handle_type, get_args(typ)), set()) - # Handle generics. - for typ in type_from_typevar.values(): - contained_special_types |= handle_type(typ) - - if cls in parent_contained_dataclasses: - return contained_special_types - # Handle fields. for field in _resolver.resolved_fields(cls): # type: ignore assert not isinstance(field.type, str)