diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index dd503b53..a9e59b14 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -109,6 +109,9 @@ def make( *, markers: Tuple[_markers.Marker, ...] = (), ): + # Resolve generic aliases. + type_or_callable = _resolver.apply_type_from_typevar(type_or_callable, {}) + # Try to extract argconf overrides from type. _, argconfs = _resolver.unwrap_annotated_and_aliases( type_or_callable, _confstruct._ArgConfiguration diff --git a/src/tyro/_resolver.py b/src/tyro/_resolver.py index 25f23135..b1b76467 100644 --- a/src/tyro/_resolver.py +++ b/src/tyro/_resolver.py @@ -392,6 +392,17 @@ def unwrap_annotated_and_aliases( def apply_type_from_typevar( typ: TypeOrCallable, type_from_typevar: Dict[TypeVar, TypeForm[Any]] ) -> TypeOrCallable: + GenericAlias = getattr(types, "GenericAlias", None) + if ( + GenericAlias is not None + and isinstance(typ, GenericAlias) + and len(getattr(typ, "__type_params__", ())) > 0 + ): + type_from_typevar = type_from_typevar.copy() + for k, v in zip(typ.__type_params__, typ.__args__): # type: ignore + type_from_typevar[k] = v # type: ignore + typ = typ.__value__ # type: ignore + if typ in type_from_typevar: return type_from_typevar[typ] # type: ignore @@ -405,7 +416,7 @@ def apply_type_from_typevar( args = tuple(args[0]) + args[1:] # Convert Python 3.9 and 3.10 types to their typing library equivalents, which - # support `.copy_with()`. + # support `.copy_with()`. This is not really the right place for this logic... if sys.version_info[:2] >= (3, 9): shim_table = { # PEP 585. Requires Python 3.9. @@ -434,7 +445,7 @@ def apply_type_from_typevar( assert hasattr(origin, "__class_getitem__") return origin.__class_getitem__(new_args) # type: ignore - return typ + return typ # type: ignore @_unsafe_cache.unsafe_cache(maxsize=1024) diff --git a/tests/test_new_style_annotations_min_py312.py b/tests/test_new_style_annotations_min_py312.py index 9189e585..120d5d09 100644 --- a/tests/test_new_style_annotations_min_py312.py +++ b/tests/test_new_style_annotations_min_py312.py @@ -64,3 +64,16 @@ class Container: a: AnnotatedBasic assert tyro.cli(Container, args="--basic 1".split(" ")) == Container(1) + + +type TT[T] = Annotated[T, tyro.conf.arg(name="", constructor=lambda: True)] + + +def test_pep695_generic_alias() -> None: + """Adapted from: https://github.com/brentyi/tyro/issues/177""" + + @dataclass(frozen=True) + class Config: + arg: TT[bool] + + assert tyro.cli(Config, args=[]) == Config(arg=True)