diff --git a/src/tyro/constructors/_primitive_spec.py b/src/tyro/constructors/_primitive_spec.py index 80575e37..729c2999 100644 --- a/src/tyro/constructors/_primitive_spec.py +++ b/src/tyro/constructors/_primitive_spec.py @@ -142,7 +142,10 @@ def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None if type_info.type is bytes else type_info.type(args[0]) ), - is_instance=lambda x: isinstance(x, type_info.type), + # Numeric tower in Python is weird... + is_instance=lambda x: isinstance(x, (int, float)) + if type_info.type is float + else isinstance(x, type_info.type), str_from_instance=lambda instance: [str(instance)], ) @@ -502,13 +505,11 @@ def instance_from_str(args: list[str]) -> dict: return out def str_from_instance(instance: dict) -> list[str]: + # TODO: this may be strange right now for the append action. out: list[str] = [] - assert ( - len(instance) == 0 - ), "When parsed as a primitive, we currrently assume all defaults are length=0. Dictionaries with non-zero-length defaults are interpreted as struct types." - # for key, value in instance.items(): - # out.extend(key_spec.str_from_instance(key)) - # out.extend(val_spec.str_from_instance(value)) + for key, value in instance.items(): + out.extend(key_spec.str_from_instance(key)) + out.extend(val_spec.str_from_instance(value)) return out if _markers.UseAppendAction in type_info.markers: diff --git a/src/tyro/constructors/_struct_spec.py b/src/tyro/constructors/_struct_spec.py index 6a9aeecd..06f0776f 100644 --- a/src/tyro/constructors/_struct_spec.py +++ b/src/tyro/constructors/_struct_spec.py @@ -274,6 +274,7 @@ def dict_rule(info: StructTypeInfo) -> StructConstructorSpec | None: if is_typeddict(info.type) or ( info.type not in ( + Dict, dict, collections.abc.Mapping, ) diff --git a/tests/test_collections.py b/tests/test_collections.py index 2f84ffef..300d1481 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -482,6 +482,26 @@ def main(x: Dict[str, Any] = {"int": 5, "str": "5"}): } +def test_dict_no_annotation_2() -> None: + def main(x: Dict = {"int": 5, "str": "5"}): + return x + + assert tyro.cli(main, args=[]) == {"int": 5, "str": "5"} + assert tyro.cli(main, args="--x.int 3 --x.str 7".split(" ")) == { + "int": 3, + "str": "7", + } + + +def test_dict_optional() -> None: + # In this case, the `None` is ignored. + def main(x: Optional[Dict[str, float]] = {"three": 3, "five": 5}): + return x + + assert tyro.cli(main, args=[]) == {"three": 3, "five": 5} + assert tyro.cli(main, args="--x 3 3 5 5".split(" ")) == {"3": 3, "5": 5} + + def test_double_dict_no_annotation() -> None: def main( x: Dict[str, Any] = { diff --git a/tests/test_py311_generated/test_collections_generated.py b/tests/test_py311_generated/test_collections_generated.py index 5b6e3d81..9aca7e8a 100644 --- a/tests/test_py311_generated/test_collections_generated.py +++ b/tests/test_py311_generated/test_collections_generated.py @@ -481,6 +481,26 @@ def main(x: Dict[str, Any] = {"int": 5, "str": "5"}): } +def test_dict_no_annotation_2() -> None: + def main(x: Dict = {"int": 5, "str": "5"}): + return x + + assert tyro.cli(main, args=[]) == {"int": 5, "str": "5"} + assert tyro.cli(main, args="--x.int 3 --x.str 7".split(" ")) == { + "int": 3, + "str": "7", + } + + +def test_dict_optional() -> None: + # In this case, the `None` is ignored. + def main(x: Optional[Dict[str, int]] = {"three": 3, "five": 5}): + return x + + assert tyro.cli(main, args=[]) == {"three": 3, "five": 5} + assert tyro.cli(main, args="--x 3 3 5 5".split(" ")) == {"3": 3, "5": 5} + + def test_double_dict_no_annotation() -> None: def main( x: Dict[str, Any] = {