Skip to content

Commit

Permalink
Fix edge cases from new primitive API
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 4, 2024
1 parent 6552337 commit ab6914f
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/tyro/constructors/_primitive_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None
if type_info.type is bytes
else type_info.type(args[0])
),
is_instance=lambda x: isinstance(x, type_info.type),
# Numeric tower in Python is weird...
is_instance=lambda x: isinstance(x, (int, float))
if type_info.type is float
else isinstance(x, type_info.type),
str_from_instance=lambda instance: [str(instance)],
)

Expand Down Expand Up @@ -502,13 +505,11 @@ def instance_from_str(args: list[str]) -> dict:
return out

def str_from_instance(instance: dict) -> list[str]:
# TODO: this may be strange right now for the append action.
out: list[str] = []
assert (
len(instance) == 0
), "When parsed as a primitive, we currrently assume all defaults are length=0. Dictionaries with non-zero-length defaults are interpreted as struct types."
# for key, value in instance.items():
# out.extend(key_spec.str_from_instance(key))
# out.extend(val_spec.str_from_instance(value))
for key, value in instance.items():
out.extend(key_spec.str_from_instance(key))
out.extend(val_spec.str_from_instance(value))
return out

if _markers.UseAppendAction in type_info.markers:
Expand Down
1 change: 1 addition & 0 deletions src/tyro/constructors/_struct_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def dict_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
if is_typeddict(info.type) or (
info.type
not in (
Dict,
dict,
collections.abc.Mapping,
)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,26 @@ def main(x: Dict[str, Any] = {"int": 5, "str": "5"}):
}


def test_dict_no_annotation_2() -> None:
def main(x: Dict = {"int": 5, "str": "5"}):
return x

assert tyro.cli(main, args=[]) == {"int": 5, "str": "5"}
assert tyro.cli(main, args="--x.int 3 --x.str 7".split(" ")) == {
"int": 3,
"str": "7",
}


def test_dict_optional() -> None:
# In this case, the `None` is ignored.
def main(x: Optional[Dict[str, float]] = {"three": 3, "five": 5}):
return x

assert tyro.cli(main, args=[]) == {"three": 3, "five": 5}
assert tyro.cli(main, args="--x 3 3 5 5".split(" ")) == {"3": 3, "5": 5}


def test_double_dict_no_annotation() -> None:
def main(
x: Dict[str, Any] = {
Expand Down
20 changes: 20 additions & 0 deletions tests/test_py311_generated/test_collections_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,26 @@ def main(x: Dict[str, Any] = {"int": 5, "str": "5"}):
}


def test_dict_no_annotation_2() -> None:
def main(x: Dict = {"int": 5, "str": "5"}):
return x

assert tyro.cli(main, args=[]) == {"int": 5, "str": "5"}
assert tyro.cli(main, args="--x.int 3 --x.str 7".split(" ")) == {
"int": 3,
"str": "7",
}


def test_dict_optional() -> None:
# In this case, the `None` is ignored.
def main(x: Optional[Dict[str, int]] = {"three": 3, "five": 5}):
return x

assert tyro.cli(main, args=[]) == {"three": 3, "five": 5}
assert tyro.cli(main, args="--x 3 3 5 5".split(" ")) == {"3": 3, "5": 5}


def test_double_dict_no_annotation() -> None:
def main(
x: Dict[str, Any] = {
Expand Down

0 comments on commit ab6914f

Please sign in to comment.