diff --git a/src/tyro/_resolver.py b/src/tyro/_resolver.py index 3fa2ea06..2473447f 100644 --- a/src/tyro/_resolver.py +++ b/src/tyro/_resolver.py @@ -18,6 +18,7 @@ Sequence, Set, Tuple, + Type, TypeVar, Union, cast, @@ -473,7 +474,8 @@ def narrow_union_type(typ: TypeOrCallable, default_instance: Any) -> TypeOrCalla try: if default_instance not in MISSING_SINGLETONS and not any( - isinstance(default_instance, o) for o in options_unwrapped + isinstance_with_fuzzy_numeric_tower(default_instance, o) is not False + for o in options_unwrapped ): warnings.warn( f"{type(default_instance)} does not match any type in Union:" @@ -486,6 +488,40 @@ def narrow_union_type(typ: TypeOrCallable, default_instance: Any) -> TypeOrCalla return typ +def isinstance_with_fuzzy_numeric_tower( + obj: Any, classinfo: Type +) -> Union[bool, Literal["~"]]: + """ + Enhanced version of isinstance() that returns: + - True: if object is exactly of the specified type + - "~": if object follows numeric tower rules but isn't exact type + - False: if object is not of the specified type or numeric tower rules don't apply + + Examples: + >>> enhanced_isinstance(3, int) # Returns True + >>> enhanced_isinstance(3, float) # Returns "~" + >>> enhanced_isinstance(True, int) # Returns "~" + >>> enhanced_isinstance(3, bool) # Returns False + >>> enhanced_isinstance(True, bool) # Returns True + """ + # Handle exact match first + if isinstance(obj, classinfo): + return True + + # Handle numeric tower cases + if isinstance(obj, bool): + if classinfo in (int, float, complex): + return "~" + elif isinstance(obj, int) and not isinstance(obj, bool): # explicit bool check + if classinfo in (float, complex): + return "~" + elif isinstance(obj, float): + if classinfo is complex: + return "~" + + return False + + NoneType = type(None) diff --git a/src/tyro/constructors/_primitive_spec.py b/src/tyro/constructors/_primitive_spec.py index 729c2999..8c113d6e 100644 --- a/src/tyro/constructors/_primitive_spec.py +++ b/src/tyro/constructors/_primitive_spec.py @@ -99,10 +99,15 @@ class PrimitiveConstructorSpec(Generic[T]): instance_from_str: Callable[[list[str]], T] """Given a list of string arguments, construct an instance of the type. The length of the list will match the value of nargs.""" - is_instance: Callable[[Any], bool] + is_instance: Callable[[Any], bool | Literal["~"]] """Given an object instance, does it match this primitive type? This is used for specific help messages when both a union type is present and a - default is provided.""" + default is provided. + + Can return "~" to signify that an instance is a "fuzzy" match, and should + only be used if there are no other matches. This is used for numeric tower + support. + """ str_from_instance: Callable[[T], list[str]] """Convert an instance to a list of string arguments that would construct the instance. This is used for help messages when a default is provided.""" @@ -124,11 +129,12 @@ def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: return None raise UnsupportedTypeAnnotationError("`Any` is not a parsable type.") - # HACK: this is for code that uses `tyro.conf.arg(constructor=json.loads)`. - # We're going to deprecate this syntax (the constructor= argument in - # tyro.conf.arg), but there is code that lives in the wild that relies - # on the behavior so we'll do our best not to break it. - vanilla_types = (int, str, float, bytes, json.loads) + # HACK (json.loads): this is for code that uses + # `tyro.conf.arg(constructor=json.loads)`. We're going to deprecate this + # syntax (the constructor= argument in tyro.conf.arg), but there is code + # that lives in the wild that relies on the behavior so we'll do our best + # not to break it. + vanilla_types = (int, str, float, complex, bytes, bytearray, json.loads) @registry.primitive_rule def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: @@ -142,10 +148,11 @@ def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None if type_info.type is bytes else type_info.type(args[0]) ), - # Numeric tower in Python is weird... - is_instance=lambda x: isinstance(x, (int, float)) - if type_info.type is float - else isinstance(x, type_info.type), + # issubclass(type(x), y) here is preferable over isinstance(x, y) + # due to quirks in the numeric tower. + is_instance=lambda x: _resolver.isinstance_with_fuzzy_numeric_tower( + x, type_info.type + ), str_from_instance=lambda instance: [str(instance)], ) @@ -582,7 +589,7 @@ def union_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: # General unions, eg Union[int, bool]. We'll try to convert these from left to # right. - option_specs = [] + option_specs: list[PrimitiveConstructorSpec] = [] choices: tuple[str, ...] | None = () nargs: int | Literal["*"] = 1 first = True @@ -646,9 +653,18 @@ def union_instantiator(strings: List[str]) -> Any: ) def str_from_instance(instance: Any) -> List[str]: + fuzzy_match = None for option_spec in option_specs: - if option_spec.is_instance(instance): + is_instance = option_spec.is_instance(instance) + if is_instance is True: return option_spec.str_from_instance(instance) + elif is_instance == "~": + fuzzy_match = option_spec + + # If we get here, we have a fuzzy match. + if fuzzy_match is not None: + return fuzzy_match.str_from_instance(instance) + assert False, f"could not match default value {instance} with any types in union {options}" return PrimitiveConstructorSpec( diff --git a/tests/test_dcargs.py b/tests/test_dcargs.py index 43214088..db11d1a1 100644 --- a/tests/test_dcargs.py +++ b/tests/test_dcargs.py @@ -942,3 +942,24 @@ def main(dt: datetime.time) -> datetime.time: # Invalid hour value. with pytest.raises(SystemExit): tyro.cli(main, args=["--dt", "25:00:00"]) + + +def test_numeric_tower() -> None: + @dataclasses.dataclass(frozen=True) + class NumericTower: + a: complex | str = 3.0 + b: bytearray | str = dataclasses.field( + default_factory=lambda: bytearray(b"123") + ) + c: complex | str = True + d: int | complex = False + e: float | str = 3 + + assert tyro.cli(NumericTower, args=[]) == NumericTower(3.0) + assert tyro.cli(NumericTower, args="--a 1+3j".split(" ")) == NumericTower(1 + 3j) + assert tyro.cli(NumericTower, args="--c False".split(" ")) == NumericTower( + c="False" + ) + assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2) + with pytest.raises(SystemExit): + tyro.cli(NumericTower, args="--d False".split(" ")) diff --git a/tests/test_py311_generated/test_collections_generated.py b/tests/test_py311_generated/test_collections_generated.py index 9aca7e8a..5810a021 100644 --- a/tests/test_py311_generated/test_collections_generated.py +++ b/tests/test_py311_generated/test_collections_generated.py @@ -494,7 +494,7 @@ def main(x: Dict = {"int": 5, "str": "5"}): def test_dict_optional() -> None: # In this case, the `None` is ignored. - def main(x: Optional[Dict[str, int]] = {"three": 3, "five": 5}): + def main(x: Optional[Dict[str, float]] = {"three": 3, "five": 5}): return x assert tyro.cli(main, args=[]) == {"three": 3, "five": 5} diff --git a/tests/test_py311_generated/test_dcargs_generated.py b/tests/test_py311_generated/test_dcargs_generated.py index 8952be3e..7f45093b 100644 --- a/tests/test_py311_generated/test_dcargs_generated.py +++ b/tests/test_py311_generated/test_dcargs_generated.py @@ -944,3 +944,24 @@ def main(dt: datetime.time) -> datetime.time: # Invalid hour value. with pytest.raises(SystemExit): tyro.cli(main, args=["--dt", "25:00:00"]) + + +def test_numeric_tower() -> None: + @dataclasses.dataclass(frozen=True) + class NumericTower: + a: complex | str = 3.0 + b: bytearray | str = dataclasses.field( + default_factory=lambda: bytearray(b"123") + ) + c: complex | str = True + d: int | complex = False + e: float | str = 3 + + assert tyro.cli(NumericTower, args=[]) == NumericTower(3.0) + assert tyro.cli(NumericTower, args="--a 1+3j".split(" ")) == NumericTower(1 + 3j) + assert tyro.cli(NumericTower, args="--c False".split(" ")) == NumericTower( + c="False" + ) + assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2) + with pytest.raises(SystemExit): + tyro.cli(NumericTower, args="--d False".split(" "))