From 8c8fd90f6f30897601616c1d8f46500691564061 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 5 Nov 2024 03:05:45 -0800 Subject: [PATCH 1/5] Docs pass, primitive rule priority --- docs/source/conf.py | 2 + .../04_additional/06_generics_py312.rst | 2 +- .../04_additional/11_custom_constructors.rst | 84 -------------- .../12_custom_constructors_registry.rst | 92 ---------------- .../04_additional/16_type_statement.rst | 50 --------- .../examples/04_additional/17_aliases.rst | 96 ---------------- docs/source/goals_and_alternatives.md | 51 +++++---- docs/source/index.md | 86 ++++++++++----- examples/04_additional/06_generics_py312.py | 2 +- .../{17_aliases.py => 11_aliases.py} | 14 +-- ...type_statement.py => 12_type_statement.py} | 2 +- .../01_primitive_annotation.py} | 19 ++-- .../02_primitive_registry.py} | 20 ++-- src/tyro/_arguments.py | 27 ++--- src/tyro/_calling.py | 2 +- src/tyro/_cli.py | 1 - src/tyro/_fields.py | 104 ++++++++++-------- src/tyro/_parsers.py | 38 +++++-- src/tyro/_subcommand_matching.py | 2 +- src/tyro/conf/_confstruct.py | 8 +- src/tyro/constructors/__init__.py | 8 ++ src/tyro/constructors/_primitive_spec.py | 40 ++++--- src/tyro/constructors/_registry.py | 36 ++++-- src/tyro/constructors/_struct_spec.py | 27 ++++- src/tyro/extras/__init__.py | 3 +- tests/test_custom_primitive.py | 46 -------- tests/test_errors.py | 2 +- .../test_collections_generated.py | 2 +- .../test_custom_primitive_generated.py | 44 -------- .../test_errors_generated.py | 2 +- 30 files changed, 316 insertions(+), 596 deletions(-) delete mode 100644 docs/source/examples/04_additional/11_custom_constructors.rst delete mode 100644 docs/source/examples/04_additional/12_custom_constructors_registry.rst delete mode 100644 docs/source/examples/04_additional/16_type_statement.rst delete mode 100644 docs/source/examples/04_additional/17_aliases.rst rename examples/04_additional/{17_aliases.py => 11_aliases.py} (68%) rename examples/04_additional/{16_type_statement.py => 12_type_statement.py} (94%) rename examples/{04_additional/11_custom_constructors.py => 05_custom_constructors/01_primitive_annotation.py} (51%) rename examples/{04_additional/12_custom_constructors_registry.py => 05_custom_constructors/02_primitive_registry.py} (60%) delete mode 100644 tests/test_custom_primitive.py delete mode 100644 tests/test_py311_generated/test_custom_primitive_generated.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 1ac56224..a22e665f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -395,6 +395,8 @@ def docstring(app, what, name, obj, options, lines): rst = m2r2.convert(md) lines.clear() lines += rst.splitlines() # type: ignore + lines.append("") + lines.append("") def setup(app): diff --git a/docs/source/examples/04_additional/06_generics_py312.rst b/docs/source/examples/04_additional/06_generics_py312.rst index 44dbb9d8..23edc94c 100644 --- a/docs/source/examples/04_additional/06_generics_py312.rst +++ b/docs/source/examples/04_additional/06_generics_py312.rst @@ -1,7 +1,7 @@ .. Comment: this file is automatically generated by `update_example_docs.py`. It should not be modified manually. -Generic Types (Python 3.12+ syntax) +Generic Types (Python 3.12+) ========================================== Example of parsing for generic dataclasses using syntax introduced in Python diff --git a/docs/source/examples/04_additional/11_custom_constructors.rst b/docs/source/examples/04_additional/11_custom_constructors.rst deleted file mode 100644 index 24cea387..00000000 --- a/docs/source/examples/04_additional/11_custom_constructors.rst +++ /dev/null @@ -1,84 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Custom Constructors -========================================== - -For additional flexibility, :module:`tyro.constructors` exposes -tyro's API for defining behavior for different types. This is the same -API that tyro relies on for the built-in types. - - -.. code-block:: python - :linenos: - - - import json - - from typing_extensions import Annotated - - import tyro - - # A dictionary type, but `tyro` will expect a JSON string from the CLI. - JsonDict = Annotated[ - dict, - tyro.constructors.PrimitiveConstructorSpec( - nargs=1, - metavar="JSON", - instance_from_str=lambda args: json.loads(args[0]), - is_instance=lambda instance: isinstance(instance, dict), - str_from_instance=lambda instance: [json.dumps(instance)], - ), - ] - - - def main( - dict1: JsonDict, - dict2: JsonDict = {"default": None}, - ) -> None: - print(f"{dict1=}") - print(f"{dict2=}") - - - if __name__ == "__main__": - tyro.cli(main) - ------------- - -.. raw:: html - - python 04_additional/11_custom_constructors.py --help - -.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --help - ------------- - -.. raw:: html - - python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' - ------------- - -.. raw:: html - - python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' - ------------- - -.. raw:: html - - python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' - ------------- - -.. raw:: html - - python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' diff --git a/docs/source/examples/04_additional/12_custom_constructors_registry.rst b/docs/source/examples/04_additional/12_custom_constructors_registry.rst deleted file mode 100644 index 942ec2b9..00000000 --- a/docs/source/examples/04_additional/12_custom_constructors_registry.rst +++ /dev/null @@ -1,92 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Custom Constructors (Registry) -========================================== - -For additional flexibility, :module:`tyro.constructors` exposes -tyro's API for defining behavior for different types. This is the same -API that tyro relies on for the built-in types. - - -.. code-block:: python - :linenos: - - - import json - from typing import Any - - import tyro - - custom_registry = tyro.constructors.PrimitiveConstructorRegistry() - - - @custom_registry.define_rule - def _( - type_info: tyro.constructors.PrimitiveTypeInfo, - ) -> tyro.constructors.PrimitiveConstructorSpec | None: - # We return `None` if the rule does not apply. - if type_info.type != dict[str, Any]: - return None - - # If the rule applies, we return the constructor spec. - return tyro.constructors.PrimitiveConstructorSpec( - nargs=1, - metavar="JSON", - instance_from_str=lambda args: json.loads(args[0]), - is_instance=lambda instance: isinstance(instance, dict), - str_from_instance=lambda instance: [json.dumps(instance)], - ) - - - def main( - dict1: dict[str, Any], - dict2: dict[str, Any] = {"default": None}, - ) -> None: - print(f"{dict1=}") - print(f"{dict2=}") - - - if __name__ == "__main__": - with custom_registry: - tyro.cli(main) - ------------- - -.. raw:: html - - python 04_additional/12_custom_constructors_registry.py --help - -.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --help - ------------- - -.. raw:: html - - python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' - ------------- - -.. raw:: html - - python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' - ------------- - -.. raw:: html - - python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' - ------------- - -.. raw:: html - - python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' - -.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}' diff --git a/docs/source/examples/04_additional/16_type_statement.rst b/docs/source/examples/04_additional/16_type_statement.rst deleted file mode 100644 index 8fe3d666..00000000 --- a/docs/source/examples/04_additional/16_type_statement.rst +++ /dev/null @@ -1,50 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Type Aliases (Python 3.12+) -========================================== - -In Python 3.12, the :code:`type` statement is introduced to create type aliases. - - -.. code-block:: python - :linenos: - - - import dataclasses - - import tyro - - # Lazily-evaluated type alias. - type Field1Type = Inner - - - @dataclasses.dataclass - class Inner: - a: int - b: str - - - @dataclasses.dataclass - class Args: - """Description. - This should show up in the helptext!""" - - field1: Field1Type - """A field.""" - - field2: int = 3 - """A numeric field, with a default value.""" - - - if __name__ == "__main__": - args = tyro.cli(Args) - print(args) - ------------- - -.. raw:: html - - python 04_additional/16_type_statement.py --help - -.. program-output:: python ../../examples/04_additional/16_type_statement.py --help diff --git a/docs/source/examples/04_additional/17_aliases.rst b/docs/source/examples/04_additional/17_aliases.rst deleted file mode 100644 index e53f2842..00000000 --- a/docs/source/examples/04_additional/17_aliases.rst +++ /dev/null @@ -1,96 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Argument Aliases -========================================== - -:func:`tyro.conf.arg()` can be used to attach aliases to arguments. - - -.. code-block:: python - :linenos: - - - from typing_extensions import Annotated - - import tyro - - - def checkout( - branch: Annotated[str, tyro.conf.arg(aliases=["-b"])], - ) -> None: - """Check out a branch.""" - print(f"{branch=}") - - - def commit( - message: Annotated[str, tyro.conf.arg(aliases=["-m"])], - all: Annotated[bool, tyro.conf.arg(aliases=["-a"])] = False, - ) -> None: - """Make a commit.""" - print(f"{message=} {all=}") - - - if __name__ == "__main__": - tyro.extras.subcommand_cli_from_dict( - { - "checkout": checkout, - "commit": commit, - } - ) - ------------- - -.. raw:: html - - python 04_additional/17_aliases.py --help - -.. program-output:: python ../../examples/04_additional/17_aliases.py --help - ------------- - -.. raw:: html - - python 04_additional/17_aliases.py commit --help - -.. program-output:: python ../../examples/04_additional/17_aliases.py commit --help - ------------- - -.. raw:: html - - python 04_additional/17_aliases.py commit --message hello --all - -.. program-output:: python ../../examples/04_additional/17_aliases.py commit --message hello --all - ------------- - -.. raw:: html - - python 04_additional/17_aliases.py commit -m hello -a - -.. program-output:: python ../../examples/04_additional/17_aliases.py commit -m hello -a - ------------- - -.. raw:: html - - python 04_additional/17_aliases.py checkout --help - -.. program-output:: python ../../examples/04_additional/17_aliases.py checkout --help - ------------- - -.. raw:: html - - python 04_additional/17_aliases.py checkout --branch main - -.. program-output:: python ../../examples/04_additional/17_aliases.py checkout --branch main - ------------- - -.. raw:: html - - python 04_additional/17_aliases.py checkout -b main - -.. program-output:: python ../../examples/04_additional/17_aliases.py checkout -b main diff --git a/docs/source/goals_and_alternatives.md b/docs/source/goals_and_alternatives.md index e5f36e0f..d8bd5ffc 100644 --- a/docs/source/goals_and_alternatives.md +++ b/docs/source/goals_and_alternatives.md @@ -10,7 +10,7 @@ Usage distinctions are the result of two API goals: `tyro` should reduce to learning to write type-annotated Python. For example, types are specified using standard annotations, helptext using docstrings, choices using the standard `typing.Literal` type, subcommands with - `typing.Union` of nested types, and positional arguments with `/`. + `typing.Union` of struct types, and positional arguments with `/`. - In contrast, similar libraries have more expansive APIs , and require more library-specific structures, decorators, or metadata formats for configuring parsing behavior. @@ -21,23 +21,31 @@ Usage distinctions are the result of two API goals: dynamic argparse-style namespaces, or string-based accessors that can't be statically checked. + + +.. warning:: + This survey was conducted in late 2022. It may be out of date. + + + More concretely, we can also compare specific features. A noncomprehensive set: -| | Dataclasses | Functions | Literals | Docstrings as helptext | Nested structures | Unions over primitives | Unions over nested types | Lists, tuples | Dictionaries | Generics | -| -------------------------------------------- | ----------- | --------- | -------------------- | ---------------------- | ----------------- | ---------------------- | ------------------------- | -------------------- | ------------ | -------- | -| [argparse-dataclass][argparse-dataclass] | ✓ | | | | | | | | | | -| [argparse-dataclasses][argparse-dataclasses] | ✓ | | | | | | | | | | -| [datargs][datargs] | ✓ | | ✓[^datargs_literals] | | | | ✓[^datargs_unions_nested] | ✓ | | | -| [tap][tap] | | | ✓ | ✓ | | ✓ | ~[^tap_unions_nested] | ✓ | | | -| [simple-parsing][simple-parsing] | ✓ | | ✓[^simp_literals] | ✓ | ✓ | ✓ | ✓[^simp_unions_nested] | ✓ | ✓ | | -| [dataclass-cli][dataclass-cli] | ✓ | | | | | | | | | | -| [clout][clout] | ✓ | | | | ✓ | | | | | | -| [hf_argparser][hf_argparser] | ✓ | | | | | | | ✓ | ✓ | | -| [typer][typer] | | ✓ | | | | | ~[^typer_unions_nested] | ~[^typer_containers] | | | -| [pyrallis][pyrallis] | ✓ | | | ✓ | ✓ | | | ✓ | | | -| [yahp][yahp] | ✓ | | | ~[^yahp_docstrings] | ✓ | ✓ | ~[^yahp_unions_nested] | ✓ | | | -| [omegaconf][omegaconf] | ✓ | | | | ✓ | | | ✓ | ✓ | | -| **tyro** | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | +| | Dataclasses | Functions | Literals | Docstrings as helptext | Nested structs | Unions over primitives | Unions over structs | Lists, tuples | Dicts | Generics | +| -------------------------------------------- | ----------- | --------- | -------------------- | ---------------------- | -------------- | ---------------------- | ------------------------- | -------------------- | ----- | -------- | +| [argparse-dataclass][argparse-dataclass] | ✓ | | | | | | | | | | +| [argparse-dataclasses][argparse-dataclasses] | ✓ | | | | | | | | | | +| [datargs][datargs] | ✓ | | ✓[^datargs_literals] | | | | ✓[^datargs_unions_struct] | ✓ | | | +| [tap][tap] | | | ✓ | ✓ | | ✓ | ~[^tap_unions_struct] | ✓ | | | +| [simple-parsing][simple-parsing] | ✓ | | ✓[^simp_literals] | ✓ | ✓ | ✓ | ✓[^simp_unions_struct] | ✓ | ✓ | | +| [dataclass-cli][dataclass-cli] | ✓ | | | | | | | | | | +| [clout][clout] | ✓ | | | | ✓ | | | | | | +| [hf_argparser][hf_argparser] | ✓ | | | | | | | ✓ | ✓ | | +| [typer][typer] | | ✓ | | | | | ~[^typer_unions_struct] | ~[^typer_containers] | | | +| [pyrallis][pyrallis] | ✓ | | | ✓ | ✓ | | | ✓ | | | +| [yahp][yahp] | ✓ | | | ~[^yahp_docstrings] | ✓ | ✓ | ~[^yahp_unions_struct] | ✓ | | | +| [omegaconf][omegaconf] | ✓ | | | | ✓ | | | ✓ | ✓ | | +| [defopt][defopt] | | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | | | +| **tyro** | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | @@ -53,12 +61,13 @@ More concretely, we can also compare specific features. A noncomprehensive set: [typer]: https://typer.tiangolo.com/ [yahp]: https://github.com/mosaicml/yahp [omegaconf]: https://omegaconf.readthedocs.io/en/2.1_branch/structured_config.html +[defopt]: https://github.com/anntzer/defopt/ -[^datargs_unions_nested]: One allowed per class. -[^tap_unions_nested]: Not supported, but API exists for creating subcommands that accomplish a similar goal. -[^simp_unions_nested]: One allowed per class. -[^yahp_unions_nested]: Not supported, but similar functionality available via ["registries"](https://docs.mosaicml.com/projects/yahp/en/stable/examples/registry.html). -[^typer_unions_nested]: Not supported, but API exists for creating subcommands that accomplish a similar goal. +[^datargs_unions_struct]: One allowed per class. +[^tap_unions_struct]: Not supported, but API exists for creating subcommands that accomplish a similar goal. +[^simp_unions_struct]: One allowed per class. +[^yahp_unions_struct]: Not supported, but similar functionality available via ["registries"](https://docs.mosaicml.com/projects/yahp/en/stable/examples/registry.html). +[^typer_unions_struct]: Not supported, but API exists for creating subcommands that accomplish a similar goal. [^simp_literals]: Not supported for mixed (eg `Literal[5, "five"]`) or in container (eg `List[Literal[1, 2]]`) types. [^datargs_literals]: Not supported for mixed types (eg `Literal[5, "five"]`). [^typer_containers]: `typer` uses positional arguments for all required fields, which means that only one variable-length argument (such as `List[int]`) without a default is supported per argument parser. diff --git a/docs/source/index.md b/docs/source/index.md index 1144b609..dbcbae06 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -2,42 +2,67 @@ |build| |nbsp| |ruff| |nbsp| |mypy| |nbsp| |pyright| |nbsp| |coverage| |nbsp| |versions| -:code:`tyro` is a tool for generating command-line interfaces and configuration -objects in Python. +:func:`tyro.cli()` is a tool for generating CLI +interfaces. -Our core API, :func:`tyro.cli()`, +We can define configurable scripts using functions: -- **Generates CLI interfaces** from a comprehensive set of Python type - constructs. -- **Populates helptext automatically** from defaults, annotations, and - docstrings. -- **Understands nesting** of `dataclasses`, `pydantic`, and `attrs` structures. -- **Prioritizes static analysis** for type checking and autocompletion with - tools like - [Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance), - [Pyright](https://github.com/microsoft/pyright), and - [mypy](https://github.com/python/mypy). +```python +"""A command-line interface defined using a function signature. -For advanced users, it also supports: +Usage: python script_name.py --foo INT [--bar STR] +""" -- **Subcommands**, as well as choosing between and overriding values in - configuration objects. -- **Completion script generation** for `bash`, `zsh`, and `tcsh`. -- **Fine-grained configuration** via [PEP - 593](https://peps.python.org/pep-0593/) annotations (`tyro.conf.*`). +import tyro -To get started, we recommend browsing the examples to the left. +def main( + foo: int, + bar: str = "default", +) -> None: + ... # Main body of a script. -### Why `tyro`? +if __name__ == "__main__": + # Generate a CLI and call `main` with its two arguments: `foo` and `bar`. + tyro.cli(main) +``` -1. **Strong typing.** +Or instantiate config objects defined using tools like `dataclasses`, `pydantic`, and `attrs`: + +```python +"""A command-line interface defined using a class signature. + +Usage: python script_name.py --foo INT [--bar STR] +""" + +from dataclasses import dataclass +import tyro + +@dataclass +class Config: + foo: int + bar: str = "default" + +if __name__ == "__main__": + # Generate a CLI and instantiate `Config` with its two arguments: `foo` and `bar`. + config = tyro.cli(Config) + + # Rest of script. + assert isinstance(config, Config) # Should pass. +``` + +Other features include helptext generation, nested structures, subcommands, and +shell completion. + +#### Why `tyro`? + +1. **Types.** Unlike tools dependent on dictionaries, YAML, or dynamic namespaces, arguments populated by `tyro` benefit from IDE and language server-supported operations — think tab completion, rename, jump-to-def, docstrings on hover — as well as static checking tools like `pyright` and `mypy`. -2. **Minimal overhead.** +2. **Define things once.** Standard Python type annotations, docstrings, and default values are parsed to automatically generate command-line interfaces with informative helptext. @@ -55,11 +80,6 @@ To get started, we recommend browsing the examples to the left. distribute definitions, defaults, and documentation of configurable fields across modules or source files. -4. **Tab completion.** - - By extending [shtab](https://github.com/iterative/shtab), `tyro` - automatically generates tab completion scripts for bash, zsh, and tcsh. - .. toctree:: @@ -112,6 +132,16 @@ To get started, we recommend browsing the examples to the left. examples/04_additional/* +.. toctree:: + :caption: Custom Constructors + :hidden: + :maxdepth: 1 + :titlesonly: + :glob: + + examples/05_custom_constructors/* + + .. toctree:: :caption: Notes :hidden: diff --git a/examples/04_additional/06_generics_py312.py b/examples/04_additional/06_generics_py312.py index 2fd4b458..72f69571 100644 --- a/examples/04_additional/06_generics_py312.py +++ b/examples/04_additional/06_generics_py312.py @@ -1,7 +1,7 @@ # mypy: ignore-errors # # PEP 695 isn't yet supported in mypy. (April 4, 2024) -"""Generic Types (Python 3.12+ syntax) +"""Generic Types (Python 3.12+) Example of parsing for generic dataclasses using syntax introduced in Python 3.12 (`PEP 695 `_). diff --git a/examples/04_additional/17_aliases.py b/examples/04_additional/11_aliases.py similarity index 68% rename from examples/04_additional/17_aliases.py rename to examples/04_additional/11_aliases.py index 16c5b3b3..5ee33de9 100644 --- a/examples/04_additional/17_aliases.py +++ b/examples/04_additional/11_aliases.py @@ -3,13 +3,13 @@ :func:`tyro.conf.arg()` can be used to attach aliases to arguments. Usage: -`python ./12_aliases.py --help` -`python ./12_aliases.py commit --help` -`python ./12_aliases.py commit --message hello --all` -`python ./12_aliases.py commit -m hello -a` -`python ./12_aliases.py checkout --help` -`python ./12_aliases.py checkout --branch main` -`python ./12_aliases.py checkout -b main` +`python ./11_aliases.py --help` +`python ./11_aliases.py commit --help` +`python ./11_aliases.py commit --message hello --all` +`python ./11_aliases.py commit -m hello -a` +`python ./11_aliases.py checkout --help` +`python ./11_aliases.py checkout --branch main` +`python ./11_aliases.py checkout -b main` """ from typing_extensions import Annotated diff --git a/examples/04_additional/16_type_statement.py b/examples/04_additional/12_type_statement.py similarity index 94% rename from examples/04_additional/16_type_statement.py rename to examples/04_additional/12_type_statement.py index 9aa08cf9..fa814029 100644 --- a/examples/04_additional/16_type_statement.py +++ b/examples/04_additional/12_type_statement.py @@ -6,7 +6,7 @@ In Python 3.12, the :code:`type` statement is introduced to create type aliases. Usage: -`python ./16_type_statement.py --help` +`python ./12_type_statement.py --help` """ import dataclasses diff --git a/examples/04_additional/11_custom_constructors.py b/examples/05_custom_constructors/01_primitive_annotation.py similarity index 51% rename from examples/04_additional/11_custom_constructors.py rename to examples/05_custom_constructors/01_primitive_annotation.py index 4d8cff56..505e8b9d 100644 --- a/examples/04_additional/11_custom_constructors.py +++ b/examples/05_custom_constructors/01_primitive_annotation.py @@ -1,15 +1,16 @@ -"""Custom Constructors +"""Custom Primitive -For additional flexibility, :module:`tyro.constructors` exposes -tyro's API for defining behavior for different types. This is the same -API that tyro relies on for the built-in types. +For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for +defining behavior for different types. There are two categories of types: +primitive types can be instantiated from a single commandline argument, while +struct types are broken down into multiple. + +In this example, we attach a custom constructor via a runtime annotation. Usage: -`python ./10_custom_constructors.py --help` -`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}'` -`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}"` -`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'` -`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}" --dict2.json "{\"hello\": \"world\"}"` +`python ./01_primitive_annotation.py --help` +`python ./01_primitive_annotation.py --dict1 '{"hello": "world"}'` +`python ./01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'` """ import json diff --git a/examples/04_additional/12_custom_constructors_registry.py b/examples/05_custom_constructors/02_primitive_registry.py similarity index 60% rename from examples/04_additional/12_custom_constructors_registry.py rename to examples/05_custom_constructors/02_primitive_registry.py index 88eb1382..767765f0 100644 --- a/examples/04_additional/12_custom_constructors_registry.py +++ b/examples/05_custom_constructors/02_primitive_registry.py @@ -1,15 +1,17 @@ -"""Custom Constructors (Registry) +"""Custom Primitive (Registry) +For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for +defining behavior for different types. There are two categories of types: +primitive types can be instantiated from a single commandline argument, while +struct types are broken down into multiple. -For additional flexibility, :module:`tyro.constructors` exposes -tyro's API for defining behavior for different types. This is the same -API that tyro relies on for the built-in types. + +In this example, we attach a custom constructor by defining a rule that applies +to all types that match ``dict[str, Any]``. Usage: -`python ./10_custom_constructors.py --help` -`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}'` -`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}"` -`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'` -`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}" --dict2.json "{\"hello\": \"world\"}"` +`python ./02_primitive_registry.py --help` +`python ./02_primitive_registry.py --dict1 '{"hello": "world"}'` +`python ./02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'` """ import json diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py index a9f6489d..d0b4aef9 100644 --- a/src/tyro/_arguments.py +++ b/src/tyro/_arguments.py @@ -174,9 +174,9 @@ def add_argument( ) complete_as_path = ( # Catch types like Path, List[Path], Tuple[Path, ...] etc. - "Path" in str(self.field.type_or_callable) + "Path" in str(self.field.type_stripped) # For string types, we require more evidence. - or ("str" in str(self.field.type_or_callable) and name_suggests_path) + or ("str" in str(self.field.type_stripped) and name_suggests_path) ) if complete_as_path: arg.complete = shtab.DIRECTORY if name_suggests_dir else shtab.FILE # type: ignore @@ -233,7 +233,7 @@ def _rule_handle_boolean_flags( arg: ArgumentDefinition, lowered: LoweredArgumentDefinition, ) -> None: - if arg.field.type_or_callable is not bool: + if arg.field.type_stripped is not bool: return if ( @@ -281,15 +281,12 @@ def _rule_apply_primitive_specs( return try: - if arg.field.primitive_spec is not None: - spec = arg.field.primitive_spec - else: - spec = ConstructorRegistry._get_active_registry().get_primitive_spec( - PrimitiveTypeInfo.make( - cast(type, arg.field.type_or_callable), - arg.field.markers, - ) + spec = ConstructorRegistry._get_active_registry().get_primitive_spec( + PrimitiveTypeInfo.make( + cast(type, arg.field.type), + arg.field.markers, ) + ) except UnsupportedTypeAnnotationError as e: if arg.field.default in _singleton.MISSING_SINGLETONS: field_name = _strings.make_field_name( @@ -335,11 +332,11 @@ def _rule_apply_primitive_specs( def append_instantiator(x: list[list[str]]) -> Any: """Handle UseAppendAction effects.""" # We'll assume that the type is annotated as Dict[...], Tuple[...], List[...], etc. - container_type = get_origin(arg.field.type_or_callable) + container_type = get_origin(arg.field.type_stripped) if container_type is None: # Raw annotation, like `UseAppendAction[list]`. It's unlikely # that a user would use this but we can handle it. - container_type = arg.field.type_or_callable + container_type = arg.field.type_stripped # Instantiate initial output. out = ( @@ -407,7 +404,7 @@ def _rule_counters( """Handle counters, like -vvv for level-3 verbosity.""" if ( _markers.UseCounterAction in arg.field.markers - and arg.field.type_or_callable is int + and arg.field.type_stripped is int and not arg.field.is_positional() ): lowered.metavar = None @@ -459,7 +456,7 @@ def _rule_generate_helptext( if arg.field.argconf.constructor_factory is not None: default_label = ( str(default) - if arg.field.type_or_callable is not json.loads + if arg.field.type_stripped is not json.loads else json.dumps(arg.field.default) ) elif type(default) in (tuple, list, set): diff --git a/src/tyro/_calling.py b/src/tyro/_calling.py index b714c20e..53c091a1 100644 --- a/src/tyro/_calling.py +++ b/src/tyro/_calling.py @@ -70,7 +70,7 @@ def get_value_from_arg(prefixed_field_name: str) -> Any: ) # Resolve field type. - field_type = field.type_or_callable + field_type = field.type_stripped if prefixed_field_name in arg_from_prefixed_field_name: assert prefixed_field_name not in consumed_keywords diff --git a/src/tyro/_cli.py b/src/tyro/_cli.py index d24ec32e..1b741f5a 100644 --- a/src/tyro/_cli.py +++ b/src/tyro/_cli.py @@ -408,7 +408,6 @@ def _cli_impl( default_instance=default_instance_internal, # Overrides for default values. intern_prefix="", # Used for recursive calls. extern_prefix="", # Used for recursive calls. - subcommand_prefix="", # Used for recursive calls. ) # Generate parser! diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index b1f50680..2cc973c9 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -12,19 +12,22 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import docstring_parser -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from . import _docstrings, _resolver, _strings, _unsafe_cache from ._singleton import DEFAULT_SENTINEL_SINGLETONS, MISSING_SINGLETONS from ._typing import TypeForm from .conf import _confstruct, _markers from .constructors._primitive_spec import ( - PrimitiveConstructorSpec, PrimitiveTypeInfo, UnsupportedTypeAnnotationError, ) from .constructors._registry import ConstructorRegistry -from .constructors._struct_spec import StructTypeInfo, UnsupportedStructTypeMessage +from .constructors._struct_spec import ( + StructFieldSpec, + StructTypeInfo, + UnsupportedStructTypeMessage, +) global_context_markers: List[Tuple[_markers.Marker, ...]] = [] @@ -33,7 +36,9 @@ class FieldDefinition: intern_name: str extern_name: str - type_or_callable: Union[TypeForm[Any], Callable] + type: TypeForm[Any] | Callable + """Full type, including runtime annotations.""" + type_stripped: TypeForm[Any] | Callable default: Any # We need to record whether defaults are from default instances to # determine if they should override the default in @@ -44,7 +49,6 @@ class FieldDefinition: custom_constructor: bool argconf: _confstruct._ArgConfig - primitive_spec: PrimitiveConstructorSpec | None # Override the name in our kwargs. Useful whenever the user-facing argument name # doesn't match the keyword expected by our callable. @@ -67,6 +71,17 @@ def marker_context(markers: Tuple[_markers.Marker, ...]): yield global_context_markers.pop() + @staticmethod + def from_field_spec(field_spec: StructFieldSpec) -> FieldDefinition: + return FieldDefinition.make( + name=field_spec.name, + typ=field_spec.type, + default=field_spec.default, + is_default_from_default_instance=field_spec.is_default_overridden, + helptext=field_spec.helptext, + call_argname_override=field_spec._call_argname, + ) + @staticmethod def make( name: str, @@ -114,18 +129,31 @@ def make( if argconf.help is not None: helptext = argconf.help - _, primitive_specs = _resolver.unwrap_annotated(typ, PrimitiveConstructorSpec) - if len(primitive_specs) > 0: - primitive_spec = primitive_specs[0] - else: - primitive_spec = None - - typ, markers = _resolver.unwrap_annotated(typ, _markers._Marker) + type_stripped, markers = _resolver.unwrap_annotated(typ, _markers._Marker) # Include markers set via context manager. for context_markers in global_context_markers: markers += context_markers + out = FieldDefinition( + intern_name=name, + extern_name=name if argconf.name is None else argconf.name, + type=typ, + type_stripped=type_stripped, + default=default, + is_default_from_default_instance=is_default_from_default_instance, + helptext=helptext, + markers=set(markers), + custom_constructor=argconf.constructor_factory is not None, + argconf=argconf, + call_argname=( + call_argname_override if call_argname_override is not None else name + ), + ) + + if argconf.constructor_factory is not None: + out = out.with_new_type_stripped(argconf.constructor_factory()) + # Check that the default value matches the final resolved type. # There's some similar Union-specific logic for this in narrow_union_type(). We # may be able to consolidate this. @@ -133,8 +161,8 @@ def make( # Be relatively conservative: isinstance() can be checked on non-type # types (like unions in Python >=3.10), but we'll only consider single types # for now. - type(typ) is type - and not isinstance(default, typ) # type: ignore + type(out.type_stripped) is type + and not isinstance(default, out.type_stripped) # type: ignore # If a custom constructor is set, static_type may not be # matched to the annotated type. and argconf.constructor_factory is None @@ -150,29 +178,25 @@ def make( f"but the default value {default} has type {type(default)}. " f"We'll try to handle this gracefully, but it may cause unexpected behavior." ) - typ = Union[typ, type(default)] # type: ignore + out = out.with_new_type_stripped(Union[out.type_stripped, type(default)]) # type: ignore - out = FieldDefinition( - intern_name=name, - extern_name=name if argconf.name is None else argconf.name, - type_or_callable=( - typ - if argconf.constructor_factory is None - else argconf.constructor_factory() - ), - default=default, - is_default_from_default_instance=is_default_from_default_instance, - helptext=helptext, - markers=set(markers), - custom_constructor=argconf.constructor_factory is not None, - argconf=argconf, - call_argname=( - call_argname_override if call_argname_override is not None else name - ), - primitive_spec=primitive_spec, - ) return out + def with_new_type_stripped( + self, new_type_stripped: TypeForm[Any] | Callable + ) -> FieldDefinition: + if get_origin(self.type) is Annotated: + new_type = Annotated.__class_getitem__( # type: ignore + (new_type_stripped, *get_args(self.type)[1:]) + ) + else: + new_type = new_type_stripped + return dataclasses.replace( + self, + type=new_type, + type_stripped=new_type_stripped, + ) + def is_positional(self) -> bool: """Returns True if the argument should be positional in the commandline.""" return ( @@ -240,17 +264,7 @@ def field_list_from_type_or_callable( with FieldDefinition.marker_context(type_info.markers): if spec is not None: - return f, [ - FieldDefinition.make( - f.name, - f.type, - f.default, - f.is_default_overridden, - f.helptext, - call_argname_override=f._call_argname, - ) - for f in spec.fields - ] + return f, [FieldDefinition.from_field_spec(f) for f in spec.fields] try: registry.get_primitive_spec(PrimitiveTypeInfo.make(f, set())) diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 068b1ac8..06c7cb79 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -19,6 +19,7 @@ from typing_extensions import Annotated, get_args, get_origin +from tyro.constructors._registry import ConstructorRegistry from tyro.constructors._struct_spec import UnsupportedStructTypeMessage from . import _argparse as argparse @@ -34,7 +35,11 @@ ) from ._typing import TypeForm from .conf import _confstruct, _markers -from .constructors._primitive_spec import UnsupportedTypeAnnotationError +from .constructors._primitive_spec import ( + PrimitiveConstructorSpec, + PrimitiveTypeInfo, + UnsupportedTypeAnnotationError, +) T = TypeVar("T") @@ -315,8 +320,20 @@ def handle_field( ]: """Determine what to do with a single field definition.""" + registry = ConstructorRegistry._get_active_registry() + + # Force primitive if (1) the field is annotated with a primitive constructor spec, or (2) if + force_primitive = len( + _resolver.unwrap_annotated(field.type, PrimitiveConstructorSpec)[1] + ) > 0 or ( + len(registry._custom_primitive_rules) > 0 + and registry.get_primitive_spec( + PrimitiveTypeInfo.make(field.type, field.markers), rule_mode="custom" + ) + is not None + ) if ( - field.primitive_spec is None + not force_primitive and _markers.Fixed not in field.markers and _markers.Suppress not in field.markers ): @@ -333,21 +350,22 @@ def handle_field( and _markers.AvoidSubcommands in field.markers ): # Don't make a subparser. - field = dataclasses.replace(field, type_or_callable=type(field.default)) + field = field.with_new_type_stripped(type(field.default)) else: return subparsers_attempt # (2) Handle nested callables. - if _fields.is_struct_type(field.type_or_callable, field.default): - field = dataclasses.replace( - field, - type_or_callable=_resolver.narrow_subtypes( - field.type_or_callable, + if force_primitive == "struct" or _fields.is_struct_type( + field.type_stripped, field.default + ): + field = field.with_new_type_stripped( + _resolver.narrow_subtypes( + field.type_stripped, field.default, ), ) return ParserSpecification.from_callable_or_type( - field.type_or_callable, + field.type_stripped, markers=field.markers, description=None, parent_classes=parent_classes, @@ -393,7 +411,7 @@ def from_field( extern_prefix: str, ) -> Optional[SubparsersSpecification]: # Union of classes should create subparsers. - typ = _resolver.unwrap_annotated(field.type_or_callable) + typ = _resolver.unwrap_annotated(field.type_stripped) if get_origin(typ) not in (Union, _resolver.UnionType): return None diff --git a/src/tyro/_subcommand_matching.py b/src/tyro/_subcommand_matching.py index 82e8d20b..ed7bed7e 100644 --- a/src/tyro/_subcommand_matching.py +++ b/src/tyro/_subcommand_matching.py @@ -90,7 +90,7 @@ def make( return _TypeTree( typ_unwrap, { - field.intern_name: _TypeTree.make(field.type_or_callable, field.default) + field.intern_name: _TypeTree.make(field.type_stripped, field.default) for field in field_list }, ) diff --git a/src/tyro/conf/_confstruct.py b/src/tyro/conf/_confstruct.py index 0104c4d0..758182ab 100644 --- a/src/tyro/conf/_confstruct.py +++ b/src/tyro/conf/_confstruct.py @@ -93,10 +93,10 @@ def subcommand( is in a nested structure. constructor: A constructor type or function. This will be used in place of the argument's type for parsing arguments. For more - configurability, see :module:`tyro.constructors`. + configurability, see :mod:`tyro.constructors`. constructor_factory: A function that returns a constructor type. This will be used in place of the argument's type for parsing arguments. - For more configurability, see :module:`tyro.constructors`. + For more configurability, see :mod:`tyro.constructors`. """ assert not ( constructor is not None and constructor_factory is not None @@ -191,10 +191,10 @@ def arg( it is in a nested structure. Arguments are prefixed by default. constructor: A constructor type or function. This will be used in place of the argument's type for parsing arguments. For more - configurability, see :module:`tyro.constructors`. + configurability, see :mod:`tyro.constructors`. constructor_factory: A function that returns a constructor type. This will be used in place of the argument's type for parsing arguments. - For more configurability, see :module:`tyro.constructors`. + For more configurability, see :mod:`tyro.constructors`. Returns: Object to attach via `typing.Annotated[]`. diff --git a/src/tyro/constructors/__init__.py b/src/tyro/constructors/__init__.py index d93bfe0a..e54a8ee5 100644 --- a/src/tyro/constructors/__init__.py +++ b/src/tyro/constructors/__init__.py @@ -1,3 +1,11 @@ +"""The :mod:`tyro.constructors` submodule exposes tyro's API for defining +behavior for different types. + +.. warning:: + This submodule exposes advanced functionality, and is not needed for the + majority of users. +""" + from ._primitive_spec import PrimitiveConstructorSpec as PrimitiveConstructorSpec from ._primitive_spec import PrimitiveTypeInfo as PrimitiveTypeInfo from ._primitive_spec import ( diff --git a/src/tyro/constructors/_primitive_spec.py b/src/tyro/constructors/_primitive_spec.py index 729c2999..d8ebf74d 100644 --- a/src/tyro/constructors/_primitive_spec.py +++ b/src/tyro/constructors/_primitive_spec.py @@ -62,12 +62,19 @@ class PrimitiveTypeInfo: """The output of get_origin() on the static type.""" markers: set[_markers.Marker] """Set of tyro markers used to configure this field.""" + _primitive_spec: PrimitiveConstructorSpec | None + """Primitive constructor spec that was scraped from runtime annotations.""" @staticmethod def make( raw_annotation: TypeForm | Callable, parent_markers: set[_markers.Marker], ) -> PrimitiveTypeInfo: + _, primitive_specs = _resolver.unwrap_annotated( + raw_annotation, search_type=PrimitiveConstructorSpec + ) + primitive_spec = primitive_specs[0] if len(primitive_specs) > 0 else None + typ, extra_markers = _resolver.unwrap_annotated( raw_annotation, search_type=_markers._Marker ) @@ -75,6 +82,7 @@ def make( type=cast(TypeForm, typ), type_origin=get_origin(typ), markers=parent_markers | set(extra_markers), + _primitive_spec=primitive_spec, ) @@ -84,11 +92,11 @@ class PrimitiveConstructorSpec(Generic[T]): There are two ways to use this class: - First, we can include it in a type signature via `typing.Annotated`. + First, we can include it in a type signature via :class:`typing.Annotated`. This is the simplest for making local modifications to parsing behavior for individual fields. - Alternatively, it can be returned by a rule in a `PrimitiveConstructorRegistry`. + Alternatively, it can be returned by a rule in a :class:`ConstructorRegistry`. """ nargs: int | Literal["*"] @@ -118,7 +126,7 @@ def apply_default_primitive_rules(registry: ConstructorRegistry) -> None: from ._registry import ConstructorRegistry - @registry.primitive_rule + @registry._default_primitive_rules.append def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type is not Any: return None @@ -130,7 +138,7 @@ def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: # on the behavior so we'll do our best not to break it. vanilla_types = (int, str, float, bytes, json.loads) - @registry.primitive_rule + @registry._default_primitive_rules.append def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type not in vanilla_types: return None @@ -152,7 +160,7 @@ def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None if "torch" in sys.modules.keys(): import torch - @registry.primitive_rule + @registry._default_primitive_rules.append def torch_device_rule( type_info: PrimitiveTypeInfo, ) -> PrimitiveConstructorSpec | None: @@ -166,7 +174,7 @@ def torch_device_rule( str_from_instance=lambda instance: [str(instance)], ) - @registry.primitive_rule + @registry._default_primitive_rules.append def bool_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type is not bool: return None @@ -179,7 +187,7 @@ def bool_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: str_from_instance=lambda instance: ["True" if instance else "False"], ) - @registry.primitive_rule + @registry._default_primitive_rules.append def nonetype_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type is not type(None): return None @@ -192,7 +200,7 @@ def nonetype_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | No str_from_instance=lambda instance: ["None"], ) - @registry.primitive_rule + @registry._default_primitive_rules.append def path_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if not ( type_info.type in (os.PathLike, pathlib.Path) @@ -210,7 +218,7 @@ def path_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: str_from_instance=lambda instance: [str(instance)], ) - @registry.primitive_rule + @registry._default_primitive_rules.append def enum_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if not ( inspect.isclass(type_info.type) and issubclass(type_info.type, enum.Enum) @@ -243,7 +251,7 @@ def enum_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: choices=choices, ) - @registry.primitive_rule + @registry._default_primitive_rules.append def datetime_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type not in (datetime.datetime, datetime.date, datetime.time): return None @@ -264,7 +272,7 @@ def datetime_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | No str_from_instance=lambda instance: [instance.isoformat()], ) - @registry.primitive_rule + @registry._default_primitive_rules.append def vague_container_rule( type_info: PrimitiveTypeInfo, ) -> PrimitiveConstructorSpec | None: @@ -299,7 +307,7 @@ def vague_container_rule( ) ) - @registry.primitive_rule + @registry._default_primitive_rules.append def sequence_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in ( collections.abc.Sequence, @@ -384,7 +392,7 @@ def str_from_instance(instance: Sequence) -> list[str]: choices=inner_spec.choices, ) - @registry.primitive_rule + @registry._default_primitive_rules.append def tuple_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin is not tuple: return None @@ -449,7 +457,7 @@ def str_from_instance(instance: tuple) -> list[str]: and all(spec.is_instance(member) for member, spec in zip(x, inner_specs)), ) - @registry.primitive_rule + @registry._default_primitive_rules.append def dict_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in (dict, collections.abc.Mapping): return None @@ -542,7 +550,7 @@ def str_from_instance(instance: dict) -> list[str]: str_from_instance=str_from_instance, ) - @registry.primitive_rule + @registry._default_primitive_rules.append def literal_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in (Literal, LiteralAlternate): return None @@ -568,7 +576,7 @@ def literal_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | Non choices=str_choices, ) - @registry.primitive_rule + @registry._default_primitive_rules.append def union_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in (Union, _resolver.UnionType): return None diff --git a/src/tyro/constructors/_registry.py b/src/tyro/constructors/_registry.py index 654379fc..7da3f081 100644 --- a/src/tyro/constructors/_registry.py +++ b/src/tyro/constructors/_registry.py @@ -2,6 +2,8 @@ from typing import Any, Callable, ClassVar, Union +from typing_extensions import Literal + from ._primitive_spec import ( PrimitiveConstructorSpec, PrimitiveTypeInfo, @@ -58,7 +60,8 @@ class ConstructorRegistry: _old_registry: ConstructorRegistry | None = None def __init__(self) -> None: - self._primitive_rules: list[PrimitiveSpecRule] = [] + self._default_primitive_rules: list[PrimitiveSpecRule] = [] + self._custom_primitive_rules: list[PrimitiveSpecRule] = [] self._struct_rules: list[StructSpecRule] = [] # Apply the default primitive-handling rules. @@ -67,9 +70,13 @@ def __init__(self) -> None: def primitive_rule(self, rule: PrimitiveSpecRule) -> PrimitiveSpecRule: """Define a rule for constructing a primitive type from a string. The - most recently added rule will be applied first.""" + most recently added rule will be applied first. + + Custom primitive rules will take precedence over both default primitive + rules and struct rules + """ - self._primitive_rules.append(rule) + self._custom_primitive_rules.append(rule) return rule def struct_rule(self, rule: StructSpecRule) -> StructSpecRule: @@ -80,13 +87,25 @@ def struct_rule(self, rule: StructSpecRule) -> StructSpecRule: return rule def get_primitive_spec( - self, type_info: PrimitiveTypeInfo + self, + type_info: PrimitiveTypeInfo, + rule_mode: Literal["default", "custom", "all"] = "all", ) -> PrimitiveConstructorSpec: """Get a constructor specification for a given type.""" - for spec_factory in self._primitive_rules[::-1]: - maybe_spec = spec_factory(type_info) - if maybe_spec is not None: - return maybe_spec + + if type_info._primitive_spec is not None: + return type_info._primitive_spec + + if rule_mode in ("custom", "all"): + for spec_factory in self._custom_primitive_rules[::-1]: + maybe_spec = spec_factory(type_info) + if maybe_spec is not None: + return maybe_spec + if rule_mode in ("default", "all"): + for spec_factory in self._default_primitive_rules[::-1]: + maybe_spec = spec_factory(type_info) + if maybe_spec is not None: + return maybe_spec raise UnsupportedTypeAnnotationError( f"Unsupported type annotation: {type_info.type}" @@ -97,6 +116,7 @@ def get_struct_spec( ) -> StructConstructorSpec | None: """Get a constructor specification for a given type. Returns `None` if unsuccessful.""" + for spec_factory in self._struct_rules[::-1]: maybe_spec = spec_factory(type_info) if maybe_spec is not None: diff --git a/src/tyro/constructors/_struct_spec.py b/src/tyro/constructors/_struct_spec.py index 06f0776f..8f54c21c 100644 --- a/src/tyro/constructors/_struct_spec.py +++ b/src/tyro/constructors/_struct_spec.py @@ -43,18 +43,40 @@ class StructFieldSpec: """Behavior specification for a single field in our callable.""" name: str + """The name of the field. This will be used as a keyword argument for the + struct's associated `instantiate(**kwargs)` function.""" type: TypeForm + """The type of the field. Can be either a primitive or a nested struct type.""" default: Any - is_default_overridden: bool - helptext: str | None + """The default value of the field.""" + is_default_overridden: bool = False + """Whether the default value was overridden by the default instance. Should + be set to False if the default value was assigned by the field itself.""" + helptext: str | None = None + """Helpjext for the field.""" # TODO: it's theoretically possible to override the argname with `None`. _call_argname: Any = None + """Private: the name of the argument to pass to the callable. This is used + for dictionary types.""" @dataclasses.dataclass(frozen=True) class StructConstructorSpec: + """Specification for a struct type, which is broken down into multiple + fields. + + Each struct type is instantiated by calling an `instantiate(**kwargs)` + function with keyword a set of keyword arguments. + + Unlike `PrimitiveConstructorSpec`, there is only one way to use this class. + It must be returned by a rule in `ConstructorRegistry`. + """ + instantiate: Callable[..., Any] + """Function to call to instantiate the struct.""" fields: tuple[StructFieldSpec, ...] + """Fields used to construct the callable. Each field is used as a keyword + argument for the `instantiate(**kwargs)` function.""" @dataclasses.dataclass(frozen=True) @@ -81,6 +103,7 @@ def make(f: TypeForm | Callable, default: Any) -> StructTypeInfo: f = typevar_context.origin_type f = _resolver.narrow_subtypes(f, default) f = _resolver.narrow_collection_types(f, default) + return StructTypeInfo( cast(TypeForm, f), parent_markers, default, typevar_context ) diff --git a/src/tyro/extras/__init__.py b/src/tyro/extras/__init__.py index 5341bc3d..e547a39e 100644 --- a/src/tyro/extras/__init__.py +++ b/src/tyro/extras/__init__.py @@ -1,6 +1,7 @@ """The :mod:`tyro.extras` submodule contains helpers that complement :func:`tyro.cli()`. -Compared to the core interface, APIs here are more likely to be changed or deprecated. +.. warning:: + Compared to the core interface, APIs here are more likely to be changed or deprecated. """ from .._argparse_formatter import set_accent_color as set_accent_color diff --git a/tests/test_custom_primitive.py b/tests/test_custom_primitive.py deleted file mode 100644 index a2758999..00000000 --- a/tests/test_custom_primitive.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any, Dict - -from typing_extensions import Annotated, get_args - -import tyro - -json_constructor_spec = tyro.constructors.PrimitiveConstructorSpec( - nargs=1, - metavar="JSON", - instance_from_str=lambda args: json.loads(args[0]), - is_instance=lambda x: isinstance(x, dict), - str_from_instance=lambda x: [json.dumps(x)], -) - - -def test_custom_primitive_registry(): - """Test that we can use a custom primitive registry to parse a custom type.""" - primitive_registry = tyro.constructors.ConstructorRegistry() - - @primitive_registry.primitive_rule - def json_dict_spec( - type_info: tyro.constructors.PrimitiveTypeInfo, - ) -> tyro.constructors.PrimitiveConstructorSpec | None: - if not ( - type_info.type_origin is dict and get_args(type_info.type) == (str, Any) - ): - return None - return json_constructor_spec - - def main(x: Dict[str, Any]) -> Dict[str, Any]: - return x - - with primitive_registry: - assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} - - -def test_custom_primitive_annotated(): - """Test that we can use typing.Annotated to specify custom constructors.""" - - def main(x: Annotated[Dict[str, Any], json_constructor_spec]) -> Dict[str, Any]: - return x - - assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} diff --git a/tests/test_errors.py b/tests/test_errors.py index 1e4c9fd5..c2690ddc 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -572,7 +572,7 @@ def main2() -> None: def test_wrong_annotation() -> None: @dataclasses.dataclass class Args: - x: dict = None # type: ignore + x: dict | int = None # type: ignore with pytest.warns(UserWarning): assert tyro.cli(Args, args=[]).x is None diff --git a/tests/test_py311_generated/test_collections_generated.py b/tests/test_py311_generated/test_collections_generated.py index 9aca7e8a..5810a021 100644 --- a/tests/test_py311_generated/test_collections_generated.py +++ b/tests/test_py311_generated/test_collections_generated.py @@ -494,7 +494,7 @@ def main(x: Dict = {"int": 5, "str": "5"}): def test_dict_optional() -> None: # In this case, the `None` is ignored. - def main(x: Optional[Dict[str, int]] = {"three": 3, "five": 5}): + def main(x: Optional[Dict[str, float]] = {"three": 3, "five": 5}): return x assert tyro.cli(main, args=[]) == {"three": 3, "five": 5} diff --git a/tests/test_py311_generated/test_custom_primitive_generated.py b/tests/test_py311_generated/test_custom_primitive_generated.py deleted file mode 100644 index 4244759a..00000000 --- a/tests/test_py311_generated/test_custom_primitive_generated.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -import json -from typing import Annotated, Any, Dict, get_args - -import tyro - -json_constructor_spec = tyro.constructors.PrimitiveConstructorSpec( - nargs=1, - metavar="JSON", - instance_from_str=lambda args: json.loads(args[0]), - is_instance=lambda x: isinstance(x, dict), - str_from_instance=lambda x: [json.dumps(x)], -) - - -def test_custom_primitive_registry(): - """Test that we can use a custom primitive registry to parse a custom type.""" - primitive_registry = tyro.constructors.ConstructorRegistry() - - @primitive_registry.primitive_rule - def json_dict_spec( - type_info: tyro.constructors.PrimitiveTypeInfo, - ) -> tyro.constructors.PrimitiveConstructorSpec | None: - if not ( - type_info.type_origin is dict and get_args(type_info.type) == (str, Any) - ): - return None - return json_constructor_spec - - def main(x: Dict[str, Any]) -> Dict[str, Any]: - return x - - with primitive_registry: - assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} - - -def test_custom_primitive_annotated(): - """Test that we can use typing.Annotated to specify custom constructors.""" - - def main(x: Annotated[Dict[str, Any], json_constructor_spec]) -> Dict[str, Any]: - return x - - assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} diff --git a/tests/test_py311_generated/test_errors_generated.py b/tests/test_py311_generated/test_errors_generated.py index 81fdd150..7cbe5a54 100644 --- a/tests/test_py311_generated/test_errors_generated.py +++ b/tests/test_py311_generated/test_errors_generated.py @@ -583,7 +583,7 @@ def main2() -> None: def test_wrong_annotation() -> None: @dataclasses.dataclass class Args: - x: dict = None # type: ignore + x: dict | int = None # type: ignore with pytest.warns(UserWarning): assert tyro.cli(Args, args=[]).x is None From 022de9bae81514105660502e8c4e59c85fd7a117 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 5 Nov 2024 03:08:03 -0800 Subject: [PATCH 2/5] Fix default primitive rules --- src/tyro/constructors/_primitive_spec.py | 28 ++++++++++++------------ src/tyro/constructors/_registry.py | 4 ++++ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/tyro/constructors/_primitive_spec.py b/src/tyro/constructors/_primitive_spec.py index d8ebf74d..5e516e08 100644 --- a/src/tyro/constructors/_primitive_spec.py +++ b/src/tyro/constructors/_primitive_spec.py @@ -126,7 +126,7 @@ def apply_default_primitive_rules(registry: ConstructorRegistry) -> None: from ._registry import ConstructorRegistry - @registry._default_primitive_rules.append + @registry._default_primitive_rule def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type is not Any: return None @@ -138,7 +138,7 @@ def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: # on the behavior so we'll do our best not to break it. vanilla_types = (int, str, float, bytes, json.loads) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type not in vanilla_types: return None @@ -160,7 +160,7 @@ def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None if "torch" in sys.modules.keys(): import torch - @registry._default_primitive_rules.append + @registry._default_primitive_rule def torch_device_rule( type_info: PrimitiveTypeInfo, ) -> PrimitiveConstructorSpec | None: @@ -174,7 +174,7 @@ def torch_device_rule( str_from_instance=lambda instance: [str(instance)], ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def bool_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type is not bool: return None @@ -187,7 +187,7 @@ def bool_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: str_from_instance=lambda instance: ["True" if instance else "False"], ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def nonetype_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type is not type(None): return None @@ -200,7 +200,7 @@ def nonetype_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | No str_from_instance=lambda instance: ["None"], ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def path_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if not ( type_info.type in (os.PathLike, pathlib.Path) @@ -218,7 +218,7 @@ def path_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: str_from_instance=lambda instance: [str(instance)], ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def enum_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if not ( inspect.isclass(type_info.type) and issubclass(type_info.type, enum.Enum) @@ -251,7 +251,7 @@ def enum_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: choices=choices, ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def datetime_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type not in (datetime.datetime, datetime.date, datetime.time): return None @@ -272,7 +272,7 @@ def datetime_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | No str_from_instance=lambda instance: [instance.isoformat()], ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def vague_container_rule( type_info: PrimitiveTypeInfo, ) -> PrimitiveConstructorSpec | None: @@ -307,7 +307,7 @@ def vague_container_rule( ) ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def sequence_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in ( collections.abc.Sequence, @@ -392,7 +392,7 @@ def str_from_instance(instance: Sequence) -> list[str]: choices=inner_spec.choices, ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def tuple_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin is not tuple: return None @@ -457,7 +457,7 @@ def str_from_instance(instance: tuple) -> list[str]: and all(spec.is_instance(member) for member, spec in zip(x, inner_specs)), ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def dict_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in (dict, collections.abc.Mapping): return None @@ -550,7 +550,7 @@ def str_from_instance(instance: dict) -> list[str]: str_from_instance=str_from_instance, ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def literal_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in (Literal, LiteralAlternate): return None @@ -576,7 +576,7 @@ def literal_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | Non choices=str_choices, ) - @registry._default_primitive_rules.append + @registry._default_primitive_rule def union_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None: if type_info.type_origin not in (Union, _resolver.UnionType): return None diff --git a/src/tyro/constructors/_registry.py b/src/tyro/constructors/_registry.py index 7da3f081..781d406a 100644 --- a/src/tyro/constructors/_registry.py +++ b/src/tyro/constructors/_registry.py @@ -79,6 +79,10 @@ def primitive_rule(self, rule: PrimitiveSpecRule) -> PrimitiveSpecRule: self._custom_primitive_rules.append(rule) return rule + def _default_primitive_rule(self, rule: PrimitiveSpecRule) -> PrimitiveSpecRule: + self._default_primitive_rules.append(rule) + return rule + def struct_rule(self, rule: StructSpecRule) -> StructSpecRule: """Define a rule for constructing a primitive type from a string. The most recently added rule will be applied first.""" From dabfa990eea98b2777c656a392bdd227a17f594a Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 5 Nov 2024 03:09:35 -0800 Subject: [PATCH 3/5] tests --- tests/test_custom_constructors.py | 65 +++++++++++++++++++ tests/test_py311_generated/ok.py | 12 ---- .../test_custom_constructors_generated.py | 63 ++++++++++++++++++ 3 files changed, 128 insertions(+), 12 deletions(-) create mode 100644 tests/test_custom_constructors.py delete mode 100644 tests/test_py311_generated/ok.py create mode 100644 tests/test_py311_generated/test_custom_constructors_generated.py diff --git a/tests/test_custom_constructors.py b/tests/test_custom_constructors.py new file mode 100644 index 00000000..4c2a1cbe --- /dev/null +++ b/tests/test_custom_constructors.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import json +from typing import Any, Dict, Union + +from typing_extensions import Annotated, get_args + +import tyro + +json_constructor_spec = tyro.constructors.PrimitiveConstructorSpec( + nargs=1, + metavar="JSON", + instance_from_str=lambda args: json.loads(args[0]), + is_instance=lambda x: isinstance(x, dict), + str_from_instance=lambda x: [json.dumps(x)], +) + + +def test_custom_primitive_registry(): + """Test that we can use a custom primitive registry to parse a custom type.""" + primitive_registry = tyro.constructors.ConstructorRegistry() + + @primitive_registry.primitive_rule + def json_dict_spec( + type_info: tyro.constructors.PrimitiveTypeInfo, + ) -> tyro.constructors.PrimitiveConstructorSpec | None: + if not ( + type_info.type_origin is dict and get_args(type_info.type) == (str, Any) + ): + return None + return json_constructor_spec + + def main(x: Dict[str, Any]) -> Dict[str, Any]: + return x + + with primitive_registry: + assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} + + def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]: + return x + + with primitive_registry: + assert tyro.cli(main_with_default, args=[]) == {"hello": 5} + assert tyro.cli(main_with_default, args=["--x", '{"a": 1}']) == {"a": 1} + + +def test_custom_primitive_annotated(): + """Test that we can use typing.Annotated to specify custom constructors.""" + + def main(x: Annotated[Dict[str, Any], json_constructor_spec]) -> Dict[str, Any]: + return x + + assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} + + +def test_custom_primitive_union(): + """Test that we can use typing.Annotated to specify custom constructors.""" + + def main( + x: Union[int, Annotated[Dict[str, Any], json_constructor_spec]], + ) -> Union[int, Dict[str, Any]]: + return x + + assert tyro.cli(main, args=["--x", "3"]) == 3 + assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} diff --git a/tests/test_py311_generated/ok.py b/tests/test_py311_generated/ok.py deleted file mode 100644 index 73e5b33b..00000000 --- a/tests/test_py311_generated/ok.py +++ /dev/null @@ -1,12 +0,0 @@ -from dataclasses import dataclass -from typing import Literal - -import tyro - - -@dataclass(frozen=True) -class Container[T]: - a: T - - -tyro.cli(Container[Container[bool] | Container[Literal["1", "2"]]]) diff --git a/tests/test_py311_generated/test_custom_constructors_generated.py b/tests/test_py311_generated/test_custom_constructors_generated.py new file mode 100644 index 00000000..6e0cd35e --- /dev/null +++ b/tests/test_py311_generated/test_custom_constructors_generated.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import json +from typing import Annotated, Any, Dict, get_args + +import tyro + +json_constructor_spec = tyro.constructors.PrimitiveConstructorSpec( + nargs=1, + metavar="JSON", + instance_from_str=lambda args: json.loads(args[0]), + is_instance=lambda x: isinstance(x, dict), + str_from_instance=lambda x: [json.dumps(x)], +) + + +def test_custom_primitive_registry(): + """Test that we can use a custom primitive registry to parse a custom type.""" + primitive_registry = tyro.constructors.ConstructorRegistry() + + @primitive_registry.primitive_rule + def json_dict_spec( + type_info: tyro.constructors.PrimitiveTypeInfo, + ) -> tyro.constructors.PrimitiveConstructorSpec | None: + if not ( + type_info.type_origin is dict and get_args(type_info.type) == (str, Any) + ): + return None + return json_constructor_spec + + def main(x: Dict[str, Any]) -> Dict[str, Any]: + return x + + with primitive_registry: + assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} + + def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]: + return x + + with primitive_registry: + assert tyro.cli(main_with_default, args=[]) == {"hello": 5} + assert tyro.cli(main_with_default, args=["--x", '{"a": 1}']) == {"a": 1} + + +def test_custom_primitive_annotated(): + """Test that we can use typing.Annotated to specify custom constructors.""" + + def main(x: Annotated[Dict[str, Any], json_constructor_spec]) -> Dict[str, Any]: + return x + + assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} + + +def test_custom_primitive_union(): + """Test that we can use typing.Annotated to specify custom constructors.""" + + def main( + x: int | Annotated[Dict[str, Any], json_constructor_spec], + ) -> int | Dict[str, Any]: + return x + + assert tyro.cli(main, args=["--x", "3"]) == 3 + assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} From 53c7727f54209973ed4bfb890f533e3739ed1891 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 5 Nov 2024 03:10:43 -0800 Subject: [PATCH 4/5] sync docs --- .../examples/04_additional/11_aliases.rst | 96 ++++++++++++++++++ .../04_additional/12_type_statement.rst | 50 ++++++++++ .../01_primitive_annotation.rst | 87 +++++++++++++++++ .../02_primitive_registry.rst | 97 +++++++++++++++++++ 4 files changed, 330 insertions(+) create mode 100644 docs/source/examples/04_additional/11_aliases.rst create mode 100644 docs/source/examples/04_additional/12_type_statement.rst create mode 100644 docs/source/examples/05_custom_constructors/01_primitive_annotation.rst create mode 100644 docs/source/examples/05_custom_constructors/02_primitive_registry.rst diff --git a/docs/source/examples/04_additional/11_aliases.rst b/docs/source/examples/04_additional/11_aliases.rst new file mode 100644 index 00000000..ad728274 --- /dev/null +++ b/docs/source/examples/04_additional/11_aliases.rst @@ -0,0 +1,96 @@ +.. Comment: this file is automatically generated by `update_example_docs.py`. + It should not be modified manually. + +Argument Aliases +========================================== + +:func:`tyro.conf.arg()` can be used to attach aliases to arguments. + + +.. code-block:: python + :linenos: + + + from typing_extensions import Annotated + + import tyro + + + def checkout( + branch: Annotated[str, tyro.conf.arg(aliases=["-b"])], + ) -> None: + """Check out a branch.""" + print(f"{branch=}") + + + def commit( + message: Annotated[str, tyro.conf.arg(aliases=["-m"])], + all: Annotated[bool, tyro.conf.arg(aliases=["-a"])] = False, + ) -> None: + """Make a commit.""" + print(f"{message=} {all=}") + + + if __name__ == "__main__": + tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + } + ) + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py --help + +.. program-output:: python ../../examples/04_additional/11_aliases.py --help + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py commit --help + +.. program-output:: python ../../examples/04_additional/11_aliases.py commit --help + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py commit --message hello --all + +.. program-output:: python ../../examples/04_additional/11_aliases.py commit --message hello --all + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py commit -m hello -a + +.. program-output:: python ../../examples/04_additional/11_aliases.py commit -m hello -a + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py checkout --help + +.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --help + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py checkout --branch main + +.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --branch main + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py checkout -b main + +.. program-output:: python ../../examples/04_additional/11_aliases.py checkout -b main diff --git a/docs/source/examples/04_additional/12_type_statement.rst b/docs/source/examples/04_additional/12_type_statement.rst new file mode 100644 index 00000000..63975f1e --- /dev/null +++ b/docs/source/examples/04_additional/12_type_statement.rst @@ -0,0 +1,50 @@ +.. Comment: this file is automatically generated by `update_example_docs.py`. + It should not be modified manually. + +Type Aliases (Python 3.12+) +========================================== + +In Python 3.12, the :code:`type` statement is introduced to create type aliases. + + +.. code-block:: python + :linenos: + + + import dataclasses + + import tyro + + # Lazily-evaluated type alias. + type Field1Type = Inner + + + @dataclasses.dataclass + class Inner: + a: int + b: str + + + @dataclasses.dataclass + class Args: + """Description. + This should show up in the helptext!""" + + field1: Field1Type + """A field.""" + + field2: int = 3 + """A numeric field, with a default value.""" + + + if __name__ == "__main__": + args = tyro.cli(Args) + print(args) + +------------ + +.. raw:: html + + python 04_additional/12_type_statement.py --help + +.. program-output:: python ../../examples/04_additional/12_type_statement.py --help diff --git a/docs/source/examples/05_custom_constructors/01_primitive_annotation.rst b/docs/source/examples/05_custom_constructors/01_primitive_annotation.rst new file mode 100644 index 00000000..fbe7e192 --- /dev/null +++ b/docs/source/examples/05_custom_constructors/01_primitive_annotation.rst @@ -0,0 +1,87 @@ +.. Comment: this file is automatically generated by `update_example_docs.py`. + It should not be modified manually. + +Custom Primitive +========================================== + +For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for +defining behavior for different types. There are two categories of types: +primitive types can be instantiated from a single commandline argument, while +struct types are broken down into multiple. + +In this example, we attach a custom constructor via a runtime annotation. + + +.. code-block:: python + :linenos: + + + import json + + from typing_extensions import Annotated + + import tyro + + # A dictionary type, but `tyro` will expect a JSON string from the CLI. + JsonDict = Annotated[ + dict, + tyro.constructors.PrimitiveConstructorSpec( + nargs=1, + metavar="JSON", + instance_from_str=lambda args: json.loads(args[0]), + is_instance=lambda instance: isinstance(instance, dict), + str_from_instance=lambda instance: [json.dumps(instance)], + ), + ] + + + def main( + dict1: JsonDict, + dict2: JsonDict = {"default": None}, + ) -> None: + print(f"{dict1=}") + print(f"{dict2=}") + + + if __name__ == "__main__": + tyro.cli(main) + +------------ + +.. raw:: html + + python 05_custom_constructors/01_primitive_annotation.py --help + +.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --help + +------------ + +.. raw:: html + + python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' + +------------ + +.. raw:: html + + python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' + +------------ + +.. raw:: html + + python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' + +------------ + +.. raw:: html + + python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' diff --git a/docs/source/examples/05_custom_constructors/02_primitive_registry.rst b/docs/source/examples/05_custom_constructors/02_primitive_registry.rst new file mode 100644 index 00000000..de3b210f --- /dev/null +++ b/docs/source/examples/05_custom_constructors/02_primitive_registry.rst @@ -0,0 +1,97 @@ +.. Comment: this file is automatically generated by `update_example_docs.py`. + It should not be modified manually. + +Custom Primitive (Registry) +========================================== + +For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for +defining behavior for different types. There are two categories of types: +primitive types can be instantiated from a single commandline argument, while +struct types are broken down into multiple. + + +In this example, we attach a custom constructor by defining a rule that applies +to all types that match `dict[str, Any]`. + + +.. code-block:: python + :linenos: + + + import json + from typing import Any + + import tyro + + custom_registry = tyro.constructors.ConstructorRegistry() + + + @custom_registry.primitive_rule + def _( + type_info: tyro.constructors.PrimitiveTypeInfo, + ) -> tyro.constructors.PrimitiveConstructorSpec | None: + # We return `None` if the rule does not apply. + if type_info.type != dict[str, Any]: + return None + + # If the rule applies, we return the constructor spec. + return tyro.constructors.PrimitiveConstructorSpec( + nargs=1, + metavar="JSON", + instance_from_str=lambda args: json.loads(args[0]), + is_instance=lambda instance: isinstance(instance, dict), + str_from_instance=lambda instance: [json.dumps(instance)], + ) + + + def main( + dict1: dict[str, Any], + dict2: dict[str, Any] = {"default": None}, + ) -> None: + print(f"{dict1=}") + print(f"{dict2=}") + + + if __name__ == "__main__": + with custom_registry: + tyro.cli(main) + +------------ + +.. raw:: html + + python 05_custom_constructors/02_primitive_registry.py --help + +.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --help + +------------ + +.. raw:: html + + python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' + +------------ + +.. raw:: html + + python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' + +------------ + +.. raw:: html + + python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' + +------------ + +.. raw:: html + + python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' + +.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' From 55b860131e3a5be483a6a125d3502b5a59259b0f Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 5 Nov 2024 03:13:32 -0800 Subject: [PATCH 5/5] union fix --- tests/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index c2690ddc..09ea7b53 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -572,7 +572,7 @@ def main2() -> None: def test_wrong_annotation() -> None: @dataclasses.dataclass class Args: - x: dict | int = None # type: ignore + x: Union[dict, int] = None # type: ignore with pytest.warns(UserWarning): assert tyro.cli(Args, args=[]).x is None