Skip to content

Commit

Permalink
Fix subcommand override edge case (#117)
Browse files Browse the repository at this point in the history
* Fix subcommand override edge case

* Fix typo

* ruff

* Fix tests

* More tweaks, leave refactor TODO note
  • Loading branch information
brentyi authored Jan 19, 2024
1 parent bc55c77 commit 171b636
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "tyro"
authors = [
{name = "brentyi", email = "[email protected]"},
]
version = "0.6.5"
version = "0.6.6"
description = "Strongly typed, zero-effort CLI interfaces"
readme = "README.md"
license = { text="MIT" }
Expand Down
62 changes: 45 additions & 17 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class FieldDefinition:
"""Type or callable for this field. This should have all Annotated[] annotations
stripped."""
default: Any
# We need to record whether defaults are from default instances to
# determine if they should override the default in
# tyro.conf.subcommand(default=...).
is_default_from_default_instance: bool
helptext: Optional[str]
markers: FrozenSet[Any]
custom_constructor: bool
Expand All @@ -85,6 +89,7 @@ def make(
name: str,
type_or_callable: Union[TypeForm[Any], Callable],
default: Any,
is_default_from_default_instance: bool,
helptext: Optional[str],
call_argname_override: Optional[Any] = None,
*,
Expand Down Expand Up @@ -125,6 +130,7 @@ def make(
if argconf.constructor_factory is None
else argconf.constructor_factory(),
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=helptext,
markers=frozenset(inferred_markers).union(markers),
custom_constructor=argconf.constructor_factory is not None,
Expand Down Expand Up @@ -285,6 +291,7 @@ def field_list_from_callable(
extern_name="value", # Doesn't matter.
type_or_callable=f,
default=default_instance,
is_default_from_default_instance=True,
helptext="",
custom_constructor=False,
markers=frozenset(
Expand Down Expand Up @@ -370,6 +377,9 @@ def _try_field_list_from_callable(
f: Union[Callable, TypeForm[Any]],
default_instance: DefaultInstance,
) -> Union[List[FieldDefinition], UnsupportedNestedTypeMessage]:
# TODO: this is needed when field_list_from_callable() is called in _calling.py.
# It's basically duplicated from (completely separate!) logic in
# _parsers.py, which is risky and might caused edge cases.
f, found_subcommand_configs = _resolver.unwrap_annotated(
f, conf._confstruct._SubcommandConfiguration
)
Expand Down Expand Up @@ -463,8 +473,10 @@ def _field_list_from_typeddict(
assert not valid_default_instance or isinstance(default_instance, dict)
for name, typ in _resolver.get_type_hints(cls, include_extras=True).items():
typ_origin = get_origin(typ)
is_default_from_default_instance = False
if valid_default_instance and name in cast(dict, default_instance):
default = cast(dict, default_instance)[name]
is_default_from_default_instance = True
elif typ_origin is Required and total is False:
# Support total=False.
default = MISSING_PROP
Expand Down Expand Up @@ -500,6 +512,7 @@ def _field_list_from_typeddict(
name=name,
type_or_callable=typ,
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(cls, name),
)
)
Expand Down Expand Up @@ -531,6 +544,7 @@ def _field_list_from_namedtuple(
name=name,
type_or_callable=typ,
default=default,
is_default_from_default_instance=True,
helptext=_docstrings.get_field_docstring(cls, name),
)
)
Expand Down Expand Up @@ -562,7 +576,9 @@ def _field_list_from_dataclass(
if is_flax_module and dc_field.name in ("name", "parent"):
continue

default = _get_dataclass_field_default(dc_field, default_instance)
default, is_default_from_default_instance = _get_dataclass_field_default(
dc_field, default_instance
)

# Try to get helptext from field metadata. This is also intended to be
# compatible with HuggingFace-style config objects.
Expand All @@ -579,6 +595,7 @@ def _field_list_from_dataclass(
name=dc_field.name,
type_or_callable=dc_field.type,
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=helptext,
)
)
Expand Down Expand Up @@ -630,7 +647,7 @@ def _field_list_from_pydantic(
if helptext is None:
helptext = _docstrings.get_field_docstring(cls, pd1_field.name)

default = _get_pydantic_v1_field_default(
default, is_default_from_default_instance = _get_pydantic_v1_field_default(
pd1_field.name, pd1_field, default_instance
)

Expand All @@ -639,6 +656,7 @@ def _field_list_from_pydantic(
name=pd1_field.name,
type_or_callable=pd1_field.outer_type_,
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=helptext,
)
)
Expand All @@ -649,7 +667,9 @@ def _field_list_from_pydantic(
if helptext is None:
helptext = _docstrings.get_field_docstring(cls, name)

default = _get_pydantic_v2_field_default(name, pd2_field, default_instance)
default, is_default_from_default_instance = _get_pydantic_v2_field_default(
name, pd2_field, default_instance
)

field_list.append(
FieldDefinition.make(
Expand All @@ -660,6 +680,7 @@ def _field_list_from_pydantic(
if len(pd2_field.metadata) > 0
else pd2_field.annotation,
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=helptext,
)
)
Expand Down Expand Up @@ -687,9 +708,11 @@ def _field_list_from_attrs(
# Default handling.
name = attr_field.name
default = attr_field.default
is_default_from_default_instance = False
if default_instance not in MISSING_SINGLETONS:
if hasattr(default_instance, name):
default = getattr(default_instance, name)
is_default_from_default_instance = True
else:
warnings.warn(
f"Could not find field {name} in default instance"
Expand All @@ -708,6 +731,7 @@ def _field_list_from_attrs(
name=name,
type_or_callable=attr_field.type,
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=_docstrings.get_field_docstring(cls, name),
)
)
Expand Down Expand Up @@ -754,6 +778,7 @@ def _field_list_from_tuple(
name=str(i),
type_or_callable=child,
default=default_i,
is_default_from_default_instance=True,
helptext="",
# This should really set the positional marker, but the CLI is more
# intuitive for mixed nested/non-nested types in tuples when we stick
Expand Down Expand Up @@ -839,6 +864,7 @@ def _try_field_list_from_sequence_inner(
name=str(i),
type_or_callable=contained_type,
default=default_i,
is_default_from_default_instance=True,
helptext="",
)
)
Expand All @@ -860,6 +886,7 @@ def _field_list_from_dict(
name=str(k) if not isinstance(k, enum.Enum) else k.name,
type_or_callable=type(v),
default=v,
is_default_from_default_instance=True,
helptext=None,
# Dictionary specific key:
call_argname_override=k,
Expand Down Expand Up @@ -984,6 +1011,7 @@ def _field_list_from_params(
# Note that param.annotation doesn't resolve forward references.
type_or_callable=typ,
default=default,
is_default_from_default_instance=False,
helptext=helptext,
markers=markers,
)
Expand All @@ -1009,12 +1037,12 @@ def _ensure_dataclass_instance_used_as_default_is_frozen(

def _get_dataclass_field_default(
field: dataclasses.Field, parent_default_instance: Any
) -> Any:
) -> Tuple[Any, bool]:
"""Helper for getting the default instance for a dataclass field."""
# If the dataclass's parent is explicitly marked MISSING, mark this field as missing
# as well.
if parent_default_instance is MISSING_PROP:
return MISSING_PROP
return MISSING_PROP, False

# Try grabbing default from parent instance.
if (
Expand All @@ -1023,7 +1051,7 @@ def _get_dataclass_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, field.name):
return getattr(parent_default_instance, field.name)
return getattr(parent_default_instance, field.name), True
else:
warnings.warn(
f"Could not find field {field.name} in default instance"
Expand All @@ -1039,7 +1067,7 @@ def _get_dataclass_field_default(
# _types_, not just instances.
if type(default) is not type and dataclasses.is_dataclass(default):
_ensure_dataclass_instance_used_as_default_is_frozen(field, default)
return default
return default, False

# Populate default from `dataclasses.field(default_factory=...)`.
if field.default_factory is not dataclasses.MISSING and not (
Expand All @@ -1053,18 +1081,18 @@ def _get_dataclass_field_default(
# before this method is called.
dataclasses.is_dataclass(field.type) and field.default_factory is field.type
):
return field.default_factory()
return field.default_factory(), False

# Otherwise, no default. This is different from MISSING, because MISSING propagates
# to children. We could revisit this design to make it clearer.
return MISSING_NONPROP
return MISSING_NONPROP, False


def _get_pydantic_v1_field_default(
name: str,
field: pydantic_v1.fields.ModelField,
parent_default_instance: DefaultInstance,
) -> Any:
) -> Tuple[Any, bool]:
"""Helper for getting the default instance for a Pydantic field."""

# Try grabbing default from parent instance.
Expand All @@ -1074,7 +1102,7 @@ def _get_pydantic_v1_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, name):
return getattr(parent_default_instance, name)
return getattr(parent_default_instance, name), True
else:
warnings.warn(
f"Could not find field {name} in default instance"
Expand All @@ -1084,17 +1112,17 @@ def _get_pydantic_v1_field_default(
)

if not field.required:
return field.get_default()
return field.get_default(), False

# Otherwise, no default.
return MISSING_NONPROP
return MISSING_NONPROP, False


def _get_pydantic_v2_field_default(
name: str,
field: pydantic.fields.FieldInfo,
parent_default_instance: DefaultInstance,
) -> Any:
) -> Tuple[Any, bool]:
"""Helper for getting the default instance for a Pydantic field."""

# Try grabbing default from parent instance.
Expand All @@ -1104,7 +1132,7 @@ def _get_pydantic_v2_field_default(
):
# Populate default from some parent, eg `default=` in `tyro.cli()`.
if hasattr(parent_default_instance, name):
return getattr(parent_default_instance, name)
return getattr(parent_default_instance, name), True
else:
warnings.warn(
f"Could not find field {name} in default instance"
Expand All @@ -1114,7 +1142,7 @@ def _get_pydantic_v2_field_default(
)

if not field.is_required():
return field.get_default(call_default_factory=True)
return field.get_default(call_default_factory=True), False

# Otherwise, no default.
return MISSING_NONPROP
return MISSING_NONPROP, False
20 changes: 19 additions & 1 deletion src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,10 +493,28 @@ def from_field(
)

# If names match, borrow subcommand default from field default.
if default_name == subcommand_name:
if default_name == subcommand_name and (
field.is_default_from_default_instance
or subcommand_config.default in _fields.MISSING_SINGLETONS
):
subcommand_config = dataclasses.replace(
subcommand_config, default=field.default
)

# Strip the subcommand config from the option type.
option_origin, annotations = _resolver.unwrap_annotated(option)
annotations = tuple(
a
for a in annotations
if not isinstance(a, _confstruct._SubcommandConfiguration)
)
if len(annotations) == 0:
option = option_origin
else:
option = Annotated.__class_getitem__( # type: ignore
(option_origin,) + annotations
)

subparser = ParserSpecification.from_callable_or_type(
(
# Recursively apply markers.
Expand Down
32 changes: 32 additions & 0 deletions tests/test_base_configs_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,35 @@ def main(cfg: BaseConfig) -> BaseConfig:
),
DataConfig(2),
)


def test_pernicious_override():
"""From: https://github.com/nerfstudio-project/nerfstudio/issues/2789
Situation where we:
- have a default value in the config class
- override that default value with a subcommand annotation
- override it again with a default instance
"""
assert (
tyro.cli(
BaseConfig,
default=BaseConfig(
"test",
"test",
ExperimentConfig(
dataset="mnist",
optimizer=AdamOptimizer(),
batch_size=2048,
num_layers=4,
units=64,
train_steps=30_000,
seed=0,
activation=nn.ReLU,
),
DataConfig(0),
),
args="small small-data".split(" "),
).data_config.test
== 0
)

0 comments on commit 171b636

Please sign in to comment.