diff --git a/pyproject.toml b/pyproject.toml index d19615b5..73d98ad7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "tyro" authors = [ {name = "brentyi", email = "brentyi@berkeley.edu"}, ] -version = "0.6.5" +version = "0.6.6" description = "Strongly typed, zero-effort CLI interfaces" readme = "README.md" license = { text="MIT" } diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index 4f7ebdf0..c0d814a8 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -62,6 +62,10 @@ class FieldDefinition: """Type or callable for this field. This should have all Annotated[] annotations stripped.""" default: Any + # We need to record whether defaults are from default instances to + # determine if they should override the default in + # tyro.conf.subcommand(default=...). + is_default_from_default_instance: bool helptext: Optional[str] markers: FrozenSet[Any] custom_constructor: bool @@ -85,6 +89,7 @@ def make( name: str, type_or_callable: Union[TypeForm[Any], Callable], default: Any, + is_default_from_default_instance: bool, helptext: Optional[str], call_argname_override: Optional[Any] = None, *, @@ -125,6 +130,7 @@ def make( if argconf.constructor_factory is None else argconf.constructor_factory(), default=default, + is_default_from_default_instance=is_default_from_default_instance, helptext=helptext, markers=frozenset(inferred_markers).union(markers), custom_constructor=argconf.constructor_factory is not None, @@ -285,6 +291,7 @@ def field_list_from_callable( extern_name="value", # Doesn't matter. type_or_callable=f, default=default_instance, + is_default_from_default_instance=True, helptext="", custom_constructor=False, markers=frozenset( @@ -370,6 +377,9 @@ def _try_field_list_from_callable( f: Union[Callable, TypeForm[Any]], default_instance: DefaultInstance, ) -> Union[List[FieldDefinition], UnsupportedNestedTypeMessage]: + # TODO: this is needed when field_list_from_callable() is called in _calling.py. + # It's basically duplicated from (completely separate!) logic in + # _parsers.py, which is risky and might caused edge cases. f, found_subcommand_configs = _resolver.unwrap_annotated( f, conf._confstruct._SubcommandConfiguration ) @@ -463,8 +473,10 @@ def _field_list_from_typeddict( assert not valid_default_instance or isinstance(default_instance, dict) for name, typ in _resolver.get_type_hints(cls, include_extras=True).items(): typ_origin = get_origin(typ) + is_default_from_default_instance = False if valid_default_instance and name in cast(dict, default_instance): default = cast(dict, default_instance)[name] + is_default_from_default_instance = True elif typ_origin is Required and total is False: # Support total=False. default = MISSING_PROP @@ -500,6 +512,7 @@ def _field_list_from_typeddict( name=name, type_or_callable=typ, default=default, + is_default_from_default_instance=is_default_from_default_instance, helptext=_docstrings.get_field_docstring(cls, name), ) ) @@ -531,6 +544,7 @@ def _field_list_from_namedtuple( name=name, type_or_callable=typ, default=default, + is_default_from_default_instance=True, helptext=_docstrings.get_field_docstring(cls, name), ) ) @@ -562,7 +576,9 @@ def _field_list_from_dataclass( if is_flax_module and dc_field.name in ("name", "parent"): continue - default = _get_dataclass_field_default(dc_field, default_instance) + default, is_default_from_default_instance = _get_dataclass_field_default( + dc_field, default_instance + ) # Try to get helptext from field metadata. This is also intended to be # compatible with HuggingFace-style config objects. @@ -579,6 +595,7 @@ def _field_list_from_dataclass( name=dc_field.name, type_or_callable=dc_field.type, default=default, + is_default_from_default_instance=is_default_from_default_instance, helptext=helptext, ) ) @@ -630,7 +647,7 @@ def _field_list_from_pydantic( if helptext is None: helptext = _docstrings.get_field_docstring(cls, pd1_field.name) - default = _get_pydantic_v1_field_default( + default, is_default_from_default_instance = _get_pydantic_v1_field_default( pd1_field.name, pd1_field, default_instance ) @@ -639,6 +656,7 @@ def _field_list_from_pydantic( name=pd1_field.name, type_or_callable=pd1_field.outer_type_, default=default, + is_default_from_default_instance=is_default_from_default_instance, helptext=helptext, ) ) @@ -649,7 +667,9 @@ def _field_list_from_pydantic( if helptext is None: helptext = _docstrings.get_field_docstring(cls, name) - default = _get_pydantic_v2_field_default(name, pd2_field, default_instance) + default, is_default_from_default_instance = _get_pydantic_v2_field_default( + name, pd2_field, default_instance + ) field_list.append( FieldDefinition.make( @@ -660,6 +680,7 @@ def _field_list_from_pydantic( if len(pd2_field.metadata) > 0 else pd2_field.annotation, default=default, + is_default_from_default_instance=is_default_from_default_instance, helptext=helptext, ) ) @@ -687,9 +708,11 @@ def _field_list_from_attrs( # Default handling. name = attr_field.name default = attr_field.default + is_default_from_default_instance = False if default_instance not in MISSING_SINGLETONS: if hasattr(default_instance, name): default = getattr(default_instance, name) + is_default_from_default_instance = True else: warnings.warn( f"Could not find field {name} in default instance" @@ -708,6 +731,7 @@ def _field_list_from_attrs( name=name, type_or_callable=attr_field.type, default=default, + is_default_from_default_instance=is_default_from_default_instance, helptext=_docstrings.get_field_docstring(cls, name), ) ) @@ -754,6 +778,7 @@ def _field_list_from_tuple( name=str(i), type_or_callable=child, default=default_i, + is_default_from_default_instance=True, helptext="", # This should really set the positional marker, but the CLI is more # intuitive for mixed nested/non-nested types in tuples when we stick @@ -839,6 +864,7 @@ def _try_field_list_from_sequence_inner( name=str(i), type_or_callable=contained_type, default=default_i, + is_default_from_default_instance=True, helptext="", ) ) @@ -860,6 +886,7 @@ def _field_list_from_dict( name=str(k) if not isinstance(k, enum.Enum) else k.name, type_or_callable=type(v), default=v, + is_default_from_default_instance=True, helptext=None, # Dictionary specific key: call_argname_override=k, @@ -984,6 +1011,7 @@ def _field_list_from_params( # Note that param.annotation doesn't resolve forward references. type_or_callable=typ, default=default, + is_default_from_default_instance=False, helptext=helptext, markers=markers, ) @@ -1009,12 +1037,12 @@ def _ensure_dataclass_instance_used_as_default_is_frozen( def _get_dataclass_field_default( field: dataclasses.Field, parent_default_instance: Any -) -> Any: +) -> Tuple[Any, bool]: """Helper for getting the default instance for a dataclass field.""" # If the dataclass's parent is explicitly marked MISSING, mark this field as missing # as well. if parent_default_instance is MISSING_PROP: - return MISSING_PROP + return MISSING_PROP, False # Try grabbing default from parent instance. if ( @@ -1023,7 +1051,7 @@ def _get_dataclass_field_default( ): # Populate default from some parent, eg `default=` in `tyro.cli()`. if hasattr(parent_default_instance, field.name): - return getattr(parent_default_instance, field.name) + return getattr(parent_default_instance, field.name), True else: warnings.warn( f"Could not find field {field.name} in default instance" @@ -1039,7 +1067,7 @@ def _get_dataclass_field_default( # _types_, not just instances. if type(default) is not type and dataclasses.is_dataclass(default): _ensure_dataclass_instance_used_as_default_is_frozen(field, default) - return default + return default, False # Populate default from `dataclasses.field(default_factory=...)`. if field.default_factory is not dataclasses.MISSING and not ( @@ -1053,18 +1081,18 @@ def _get_dataclass_field_default( # before this method is called. dataclasses.is_dataclass(field.type) and field.default_factory is field.type ): - return field.default_factory() + return field.default_factory(), False # Otherwise, no default. This is different from MISSING, because MISSING propagates # to children. We could revisit this design to make it clearer. - return MISSING_NONPROP + return MISSING_NONPROP, False def _get_pydantic_v1_field_default( name: str, field: pydantic_v1.fields.ModelField, parent_default_instance: DefaultInstance, -) -> Any: +) -> Tuple[Any, bool]: """Helper for getting the default instance for a Pydantic field.""" # Try grabbing default from parent instance. @@ -1074,7 +1102,7 @@ def _get_pydantic_v1_field_default( ): # Populate default from some parent, eg `default=` in `tyro.cli()`. if hasattr(parent_default_instance, name): - return getattr(parent_default_instance, name) + return getattr(parent_default_instance, name), True else: warnings.warn( f"Could not find field {name} in default instance" @@ -1084,17 +1112,17 @@ def _get_pydantic_v1_field_default( ) if not field.required: - return field.get_default() + return field.get_default(), False # Otherwise, no default. - return MISSING_NONPROP + return MISSING_NONPROP, False def _get_pydantic_v2_field_default( name: str, field: pydantic.fields.FieldInfo, parent_default_instance: DefaultInstance, -) -> Any: +) -> Tuple[Any, bool]: """Helper for getting the default instance for a Pydantic field.""" # Try grabbing default from parent instance. @@ -1104,7 +1132,7 @@ def _get_pydantic_v2_field_default( ): # Populate default from some parent, eg `default=` in `tyro.cli()`. if hasattr(parent_default_instance, name): - return getattr(parent_default_instance, name) + return getattr(parent_default_instance, name), True else: warnings.warn( f"Could not find field {name} in default instance" @@ -1114,7 +1142,7 @@ def _get_pydantic_v2_field_default( ) if not field.is_required(): - return field.get_default(call_default_factory=True) + return field.get_default(call_default_factory=True), False # Otherwise, no default. - return MISSING_NONPROP + return MISSING_NONPROP, False diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 1669aeed..98da4fb2 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -493,10 +493,28 @@ def from_field( ) # If names match, borrow subcommand default from field default. - if default_name == subcommand_name: + if default_name == subcommand_name and ( + field.is_default_from_default_instance + or subcommand_config.default in _fields.MISSING_SINGLETONS + ): subcommand_config = dataclasses.replace( subcommand_config, default=field.default ) + + # Strip the subcommand config from the option type. + option_origin, annotations = _resolver.unwrap_annotated(option) + annotations = tuple( + a + for a in annotations + if not isinstance(a, _confstruct._SubcommandConfiguration) + ) + if len(annotations) == 0: + option = option_origin + else: + option = Annotated.__class_getitem__( # type: ignore + (option_origin,) + annotations + ) + subparser = ParserSpecification.from_callable_or_type( ( # Recursively apply markers. diff --git a/tests/test_base_configs_nested.py b/tests/test_base_configs_nested.py index 1933617a..7500df32 100644 --- a/tests/test_base_configs_nested.py +++ b/tests/test_base_configs_nested.py @@ -178,3 +178,35 @@ def main(cfg: BaseConfig) -> BaseConfig: ), DataConfig(2), ) + + +def test_pernicious_override(): + """From: https://github.com/nerfstudio-project/nerfstudio/issues/2789 + + Situation where we: + - have a default value in the config class + - override that default value with a subcommand annotation + - override it again with a default instance + """ + assert ( + tyro.cli( + BaseConfig, + default=BaseConfig( + "test", + "test", + ExperimentConfig( + dataset="mnist", + optimizer=AdamOptimizer(), + batch_size=2048, + num_layers=4, + units=64, + train_steps=30_000, + seed=0, + activation=nn.ReLU, + ), + DataConfig(0), + ), + args="small small-data".split(" "), + ).data_config.test + == 0 + )