diff --git a/pyproject.toml b/pyproject.toml index a35f7a24..dc8989ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tyro" -version = "0.3.26" +version = "0.3.27" description = "Strongly typed, zero-effort CLI interfaces" authors = ["brentyi "] include = ["./tyro/**/*"] diff --git a/tyro/_parsers.py b/tyro/_parsers.py index ad5aa97b..62cb0a18 100644 --- a/tyro/_parsers.py +++ b/tyro/_parsers.py @@ -83,7 +83,7 @@ def from_callable_or_type( field = dataclasses.replace( field, # Resolve generic types. - typ=_resolver.narrow_type( + typ=_resolver.narrow_container_types( _resolver.type_from_typevar_constraints( # type: ignore _resolver.apply_type_from_typevar( field.typ, diff --git a/tyro/_resolver.py b/tyro/_resolver.py index ec46f106..4c3c2ea1 100644 --- a/tyro/_resolver.py +++ b/tyro/_resolver.py @@ -2,7 +2,6 @@ import collections.abc import copy import dataclasses -import pathlib import sys from typing import ( Any, @@ -118,14 +117,10 @@ def narrow_type(typ: TypeT, default_instance: Any) -> TypeT: """Type narrowing: if we annotate as Animal but specify a default instance of Cat, we should parse as Cat. - Note that Union types are intentionally excluded here.""" - - # Don't apply narrowing for pathlib.PosixPath, pathlib.WindowsPath, etc. - # This is mostly an aesthetic decision, and is needed because pathlib.Path() hacks - # __new__ to dynamically choose between path types. - if typ is pathlib.Path: - return typ - + This should generally only be applied to fields used as nested structures, not + individual arguments/fields. (if a field is annotated as Union[int, str], and a + string default is passed in, we don't want to narrow the type to always be + strings!)""" try: potential_subclass = type(default_instance) @@ -135,6 +130,11 @@ def narrow_type(typ: TypeT, default_instance: Any) -> TypeT: return typ superclass = unwrap_annotated(typ)[0] + + # For Python 3.10: don't narrow union types. + if get_origin(superclass) is Union: + return typ + if superclass is Any or issubclass(potential_subclass, superclass): # type: ignore if get_origin(typ) is Annotated: return Annotated.__class_getitem__( # type: ignore @@ -144,13 +144,17 @@ def narrow_type(typ: TypeT, default_instance: Any) -> TypeT: except TypeError: pass + return typ + + +def narrow_container_types(typ: TypeT, default_instance: Any) -> TypeT: + """Type narrowing for containers. Infers types of container contents.""" if typ is list and isinstance(default_instance, list): typ = List.__getitem__(Union.__getitem__(tuple(map(type, default_instance)))) # type: ignore elif typ is set and isinstance(default_instance, set): typ = Set.__getitem__(Union.__getitem__(tuple(map(type, default_instance)))) # type: ignore elif typ is tuple and isinstance(default_instance, tuple): typ = Tuple.__getitem__(tuple(map(type, default_instance))) # type: ignore - return typ