diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1ac56224..a22e665f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -395,6 +395,8 @@ def docstring(app, what, name, obj, options, lines):
rst = m2r2.convert(md)
lines.clear()
lines += rst.splitlines() # type: ignore
+ lines.append("")
+ lines.append("")
def setup(app):
diff --git a/docs/source/examples/04_additional/06_generics_py312.rst b/docs/source/examples/04_additional/06_generics_py312.rst
index 44dbb9d8..23edc94c 100644
--- a/docs/source/examples/04_additional/06_generics_py312.rst
+++ b/docs/source/examples/04_additional/06_generics_py312.rst
@@ -1,7 +1,7 @@
.. Comment: this file is automatically generated by `update_example_docs.py`.
It should not be modified manually.
-Generic Types (Python 3.12+ syntax)
+Generic Types (Python 3.12+)
==========================================
Example of parsing for generic dataclasses using syntax introduced in Python
diff --git a/docs/source/examples/04_additional/17_aliases.rst b/docs/source/examples/04_additional/11_aliases.rst
similarity index 61%
rename from docs/source/examples/04_additional/17_aliases.rst
rename to docs/source/examples/04_additional/11_aliases.rst
index e53f2842..ad728274 100644
--- a/docs/source/examples/04_additional/17_aliases.rst
+++ b/docs/source/examples/04_additional/11_aliases.rst
@@ -43,54 +43,54 @@ Argument Aliases
.. raw:: html
- python 04_additional/17_aliases.py --help
+ python 04_additional/11_aliases.py --help
-.. program-output:: python ../../examples/04_additional/17_aliases.py --help
+.. program-output:: python ../../examples/04_additional/11_aliases.py --help
------------
.. raw:: html
- python 04_additional/17_aliases.py commit --help
+ python 04_additional/11_aliases.py commit --help
-.. program-output:: python ../../examples/04_additional/17_aliases.py commit --help
+.. program-output:: python ../../examples/04_additional/11_aliases.py commit --help
------------
.. raw:: html
- python 04_additional/17_aliases.py commit --message hello --all
+ python 04_additional/11_aliases.py commit --message hello --all
-.. program-output:: python ../../examples/04_additional/17_aliases.py commit --message hello --all
+.. program-output:: python ../../examples/04_additional/11_aliases.py commit --message hello --all
------------
.. raw:: html
- python 04_additional/17_aliases.py commit -m hello -a
+ python 04_additional/11_aliases.py commit -m hello -a
-.. program-output:: python ../../examples/04_additional/17_aliases.py commit -m hello -a
+.. program-output:: python ../../examples/04_additional/11_aliases.py commit -m hello -a
------------
.. raw:: html
- python 04_additional/17_aliases.py checkout --help
+ python 04_additional/11_aliases.py checkout --help
-.. program-output:: python ../../examples/04_additional/17_aliases.py checkout --help
+.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --help
------------
.. raw:: html
- python 04_additional/17_aliases.py checkout --branch main
+ python 04_additional/11_aliases.py checkout --branch main
-.. program-output:: python ../../examples/04_additional/17_aliases.py checkout --branch main
+.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --branch main
------------
.. raw:: html
- python 04_additional/17_aliases.py checkout -b main
+ python 04_additional/11_aliases.py checkout -b main
-.. program-output:: python ../../examples/04_additional/17_aliases.py checkout -b main
+.. program-output:: python ../../examples/04_additional/11_aliases.py checkout -b main
diff --git a/docs/source/examples/04_additional/11_custom_constructors.rst b/docs/source/examples/04_additional/11_custom_constructors.rst
deleted file mode 100644
index 24cea387..00000000
--- a/docs/source/examples/04_additional/11_custom_constructors.rst
+++ /dev/null
@@ -1,84 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Custom Constructors
-==========================================
-
-For additional flexibility, :module:`tyro.constructors` exposes
-tyro's API for defining behavior for different types. This is the same
-API that tyro relies on for the built-in types.
-
-
-.. code-block:: python
- :linenos:
-
-
- import json
-
- from typing_extensions import Annotated
-
- import tyro
-
- # A dictionary type, but `tyro` will expect a JSON string from the CLI.
- JsonDict = Annotated[
- dict,
- tyro.constructors.PrimitiveConstructorSpec(
- nargs=1,
- metavar="JSON",
- instance_from_str=lambda args: json.loads(args[0]),
- is_instance=lambda instance: isinstance(instance, dict),
- str_from_instance=lambda instance: [json.dumps(instance)],
- ),
- ]
-
-
- def main(
- dict1: JsonDict,
- dict2: JsonDict = {"default": None},
- ) -> None:
- print(f"{dict1=}")
- print(f"{dict2=}")
-
-
- if __name__ == "__main__":
- tyro.cli(main)
-
-------------
-
-.. raw:: html
-
- python 04_additional/11_custom_constructors.py --help
-
-.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --help
-
-------------
-
-.. raw:: html
-
- python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}'
-
-------------
-
-.. raw:: html
-
- python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}'
-
-------------
-
-.. raw:: html
-
- python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
-
-------------
-
-.. raw:: html
-
- python 04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/11_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
diff --git a/docs/source/examples/04_additional/12_custom_constructors_registry.rst b/docs/source/examples/04_additional/12_custom_constructors_registry.rst
deleted file mode 100644
index 942ec2b9..00000000
--- a/docs/source/examples/04_additional/12_custom_constructors_registry.rst
+++ /dev/null
@@ -1,92 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Custom Constructors (Registry)
-==========================================
-
-For additional flexibility, :module:`tyro.constructors` exposes
-tyro's API for defining behavior for different types. This is the same
-API that tyro relies on for the built-in types.
-
-
-.. code-block:: python
- :linenos:
-
-
- import json
- from typing import Any
-
- import tyro
-
- custom_registry = tyro.constructors.PrimitiveConstructorRegistry()
-
-
- @custom_registry.define_rule
- def _(
- type_info: tyro.constructors.PrimitiveTypeInfo,
- ) -> tyro.constructors.PrimitiveConstructorSpec | None:
- # We return `None` if the rule does not apply.
- if type_info.type != dict[str, Any]:
- return None
-
- # If the rule applies, we return the constructor spec.
- return tyro.constructors.PrimitiveConstructorSpec(
- nargs=1,
- metavar="JSON",
- instance_from_str=lambda args: json.loads(args[0]),
- is_instance=lambda instance: isinstance(instance, dict),
- str_from_instance=lambda instance: [json.dumps(instance)],
- )
-
-
- def main(
- dict1: dict[str, Any],
- dict2: dict[str, Any] = {"default": None},
- ) -> None:
- print(f"{dict1=}")
- print(f"{dict2=}")
-
-
- if __name__ == "__main__":
- with custom_registry:
- tyro.cli(main)
-
-------------
-
-.. raw:: html
-
- python 04_additional/12_custom_constructors_registry.py --help
-
-.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --help
-
-------------
-
-.. raw:: html
-
- python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}'
-
-------------
-
-.. raw:: html
-
- python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}'
-
-------------
-
-.. raw:: html
-
- python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
-
-------------
-
-.. raw:: html
-
- python 04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
-
-.. program-output:: python ../../examples/04_additional/12_custom_constructors_registry.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'
diff --git a/docs/source/examples/04_additional/16_type_statement.rst b/docs/source/examples/04_additional/12_type_statement.rst
similarity index 88%
rename from docs/source/examples/04_additional/16_type_statement.rst
rename to docs/source/examples/04_additional/12_type_statement.rst
index 8fe3d666..63975f1e 100644
--- a/docs/source/examples/04_additional/16_type_statement.rst
+++ b/docs/source/examples/04_additional/12_type_statement.rst
@@ -45,6 +45,6 @@ In Python 3.12, the :code:`type` statement is introduced to create type aliases.
.. raw:: html
- python 04_additional/16_type_statement.py --help
+ python 04_additional/12_type_statement.py --help
-.. program-output:: python ../../examples/04_additional/16_type_statement.py --help
+.. program-output:: python ../../examples/04_additional/12_type_statement.py --help
diff --git a/docs/source/examples/05_custom_constructors/01_primitive_annotation.rst b/docs/source/examples/05_custom_constructors/01_primitive_annotation.rst
new file mode 100644
index 00000000..fbe7e192
--- /dev/null
+++ b/docs/source/examples/05_custom_constructors/01_primitive_annotation.rst
@@ -0,0 +1,87 @@
+.. Comment: this file is automatically generated by `update_example_docs.py`.
+ It should not be modified manually.
+
+Custom Primitive
+==========================================
+
+For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for
+defining behavior for different types. There are two categories of types:
+primitive types can be instantiated from a single commandline argument, while
+struct types are broken down into multiple.
+
+In this example, we attach a custom constructor via a runtime annotation.
+
+
+.. code-block:: python
+ :linenos:
+
+
+ import json
+
+ from typing_extensions import Annotated
+
+ import tyro
+
+ # A dictionary type, but `tyro` will expect a JSON string from the CLI.
+ JsonDict = Annotated[
+ dict,
+ tyro.constructors.PrimitiveConstructorSpec(
+ nargs=1,
+ metavar="JSON",
+ instance_from_str=lambda args: json.loads(args[0]),
+ is_instance=lambda instance: isinstance(instance, dict),
+ str_from_instance=lambda instance: [json.dumps(instance)],
+ ),
+ ]
+
+
+ def main(
+ dict1: JsonDict,
+ dict2: JsonDict = {"default": None},
+ ) -> None:
+ print(f"{dict1=}")
+ print(f"{dict2=}")
+
+
+ if __name__ == "__main__":
+ tyro.cli(main)
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/01_primitive_annotation.py --help
+
+.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --help
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}'
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}'
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
diff --git a/docs/source/examples/05_custom_constructors/02_primitive_registry.rst b/docs/source/examples/05_custom_constructors/02_primitive_registry.rst
new file mode 100644
index 00000000..de3b210f
--- /dev/null
+++ b/docs/source/examples/05_custom_constructors/02_primitive_registry.rst
@@ -0,0 +1,97 @@
+.. Comment: this file is automatically generated by `update_example_docs.py`.
+ It should not be modified manually.
+
+Custom Primitive (Registry)
+==========================================
+
+For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for
+defining behavior for different types. There are two categories of types:
+primitive types can be instantiated from a single commandline argument, while
+struct types are broken down into multiple.
+
+
+In this example, we attach a custom constructor by defining a rule that applies
+to all types that match `dict[str, Any]`.
+
+
+.. code-block:: python
+ :linenos:
+
+
+ import json
+ from typing import Any
+
+ import tyro
+
+ custom_registry = tyro.constructors.ConstructorRegistry()
+
+
+ @custom_registry.primitive_rule
+ def _(
+ type_info: tyro.constructors.PrimitiveTypeInfo,
+ ) -> tyro.constructors.PrimitiveConstructorSpec | None:
+ # We return `None` if the rule does not apply.
+ if type_info.type != dict[str, Any]:
+ return None
+
+ # If the rule applies, we return the constructor spec.
+ return tyro.constructors.PrimitiveConstructorSpec(
+ nargs=1,
+ metavar="JSON",
+ instance_from_str=lambda args: json.loads(args[0]),
+ is_instance=lambda instance: isinstance(instance, dict),
+ str_from_instance=lambda instance: [json.dumps(instance)],
+ )
+
+
+ def main(
+ dict1: dict[str, Any],
+ dict2: dict[str, Any] = {"default": None},
+ ) -> None:
+ print(f"{dict1=}")
+ print(f"{dict2=}")
+
+
+ if __name__ == "__main__":
+ with custom_registry:
+ tyro.cli(main)
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/02_primitive_registry.py --help
+
+.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --help
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}'
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}'
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
+
+------------
+
+.. raw:: html
+
+ python 05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
+
+.. program-output:: python ../../examples/05_custom_constructors/02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'
diff --git a/docs/source/goals_and_alternatives.md b/docs/source/goals_and_alternatives.md
index e5f36e0f..d8bd5ffc 100644
--- a/docs/source/goals_and_alternatives.md
+++ b/docs/source/goals_and_alternatives.md
@@ -10,7 +10,7 @@ Usage distinctions are the result of two API goals:
`tyro` should reduce to learning to write type-annotated Python. For example,
types are specified using standard annotations, helptext using docstrings,
choices using the standard `typing.Literal` type, subcommands with
- `typing.Union` of nested types, and positional arguments with `/`.
+ `typing.Union` of struct types, and positional arguments with `/`.
- In contrast, similar libraries have more expansive APIs , and require more
library-specific structures, decorators, or metadata formats for configuring
parsing behavior.
@@ -21,23 +21,31 @@ Usage distinctions are the result of two API goals:
dynamic argparse-style namespaces, or string-based accessors that can't be
statically checked.
+
+
+.. warning::
+ This survey was conducted in late 2022. It may be out of date.
+
+
+
More concretely, we can also compare specific features. A noncomprehensive set:
-| | Dataclasses | Functions | Literals | Docstrings as helptext | Nested structures | Unions over primitives | Unions over nested types | Lists, tuples | Dictionaries | Generics |
-| -------------------------------------------- | ----------- | --------- | -------------------- | ---------------------- | ----------------- | ---------------------- | ------------------------- | -------------------- | ------------ | -------- |
-| [argparse-dataclass][argparse-dataclass] | ✓ | | | | | | | | | |
-| [argparse-dataclasses][argparse-dataclasses] | ✓ | | | | | | | | | |
-| [datargs][datargs] | ✓ | | ✓[^datargs_literals] | | | | ✓[^datargs_unions_nested] | ✓ | | |
-| [tap][tap] | | | ✓ | ✓ | | ✓ | ~[^tap_unions_nested] | ✓ | | |
-| [simple-parsing][simple-parsing] | ✓ | | ✓[^simp_literals] | ✓ | ✓ | ✓ | ✓[^simp_unions_nested] | ✓ | ✓ | |
-| [dataclass-cli][dataclass-cli] | ✓ | | | | | | | | | |
-| [clout][clout] | ✓ | | | | ✓ | | | | | |
-| [hf_argparser][hf_argparser] | ✓ | | | | | | | ✓ | ✓ | |
-| [typer][typer] | | ✓ | | | | | ~[^typer_unions_nested] | ~[^typer_containers] | | |
-| [pyrallis][pyrallis] | ✓ | | | ✓ | ✓ | | | ✓ | | |
-| [yahp][yahp] | ✓ | | | ~[^yahp_docstrings] | ✓ | ✓ | ~[^yahp_unions_nested] | ✓ | | |
-| [omegaconf][omegaconf] | ✓ | | | | ✓ | | | ✓ | ✓ | |
-| **tyro** | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
+| | Dataclasses | Functions | Literals | Docstrings as helptext | Nested structs | Unions over primitives | Unions over structs | Lists, tuples | Dicts | Generics |
+| -------------------------------------------- | ----------- | --------- | -------------------- | ---------------------- | -------------- | ---------------------- | ------------------------- | -------------------- | ----- | -------- |
+| [argparse-dataclass][argparse-dataclass] | ✓ | | | | | | | | | |
+| [argparse-dataclasses][argparse-dataclasses] | ✓ | | | | | | | | | |
+| [datargs][datargs] | ✓ | | ✓[^datargs_literals] | | | | ✓[^datargs_unions_struct] | ✓ | | |
+| [tap][tap] | | | ✓ | ✓ | | ✓ | ~[^tap_unions_struct] | ✓ | | |
+| [simple-parsing][simple-parsing] | ✓ | | ✓[^simp_literals] | ✓ | ✓ | ✓ | ✓[^simp_unions_struct] | ✓ | ✓ | |
+| [dataclass-cli][dataclass-cli] | ✓ | | | | | | | | | |
+| [clout][clout] | ✓ | | | | ✓ | | | | | |
+| [hf_argparser][hf_argparser] | ✓ | | | | | | | ✓ | ✓ | |
+| [typer][typer] | | ✓ | | | | | ~[^typer_unions_struct] | ~[^typer_containers] | | |
+| [pyrallis][pyrallis] | ✓ | | | ✓ | ✓ | | | ✓ | | |
+| [yahp][yahp] | ✓ | | | ~[^yahp_docstrings] | ✓ | ✓ | ~[^yahp_unions_struct] | ✓ | | |
+| [omegaconf][omegaconf] | ✓ | | | | ✓ | | | ✓ | ✓ | |
+| [defopt][defopt] | | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | | |
+| **tyro** | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
@@ -53,12 +61,13 @@ More concretely, we can also compare specific features. A noncomprehensive set:
[typer]: https://typer.tiangolo.com/
[yahp]: https://github.com/mosaicml/yahp
[omegaconf]: https://omegaconf.readthedocs.io/en/2.1_branch/structured_config.html
+[defopt]: https://github.com/anntzer/defopt/
-[^datargs_unions_nested]: One allowed per class.
-[^tap_unions_nested]: Not supported, but API exists for creating subcommands that accomplish a similar goal.
-[^simp_unions_nested]: One allowed per class.
-[^yahp_unions_nested]: Not supported, but similar functionality available via ["registries"](https://docs.mosaicml.com/projects/yahp/en/stable/examples/registry.html).
-[^typer_unions_nested]: Not supported, but API exists for creating subcommands that accomplish a similar goal.
+[^datargs_unions_struct]: One allowed per class.
+[^tap_unions_struct]: Not supported, but API exists for creating subcommands that accomplish a similar goal.
+[^simp_unions_struct]: One allowed per class.
+[^yahp_unions_struct]: Not supported, but similar functionality available via ["registries"](https://docs.mosaicml.com/projects/yahp/en/stable/examples/registry.html).
+[^typer_unions_struct]: Not supported, but API exists for creating subcommands that accomplish a similar goal.
[^simp_literals]: Not supported for mixed (eg `Literal[5, "five"]`) or in container (eg `List[Literal[1, 2]]`) types.
[^datargs_literals]: Not supported for mixed types (eg `Literal[5, "five"]`).
[^typer_containers]: `typer` uses positional arguments for all required fields, which means that only one variable-length argument (such as `List[int]`) without a default is supported per argument parser.
diff --git a/docs/source/index.md b/docs/source/index.md
index 1144b609..dbcbae06 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -2,42 +2,67 @@
|build| |nbsp| |ruff| |nbsp| |mypy| |nbsp| |pyright| |nbsp| |coverage| |nbsp| |versions|
-:code:`tyro` is a tool for generating command-line interfaces and configuration
-objects in Python.
+:func:`tyro.cli()` is a tool for generating CLI
+interfaces.
-Our core API, :func:`tyro.cli()`,
+We can define configurable scripts using functions:
-- **Generates CLI interfaces** from a comprehensive set of Python type
- constructs.
-- **Populates helptext automatically** from defaults, annotations, and
- docstrings.
-- **Understands nesting** of `dataclasses`, `pydantic`, and `attrs` structures.
-- **Prioritizes static analysis** for type checking and autocompletion with
- tools like
- [Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance),
- [Pyright](https://github.com/microsoft/pyright), and
- [mypy](https://github.com/python/mypy).
+```python
+"""A command-line interface defined using a function signature.
-For advanced users, it also supports:
+Usage: python script_name.py --foo INT [--bar STR]
+"""
-- **Subcommands**, as well as choosing between and overriding values in
- configuration objects.
-- **Completion script generation** for `bash`, `zsh`, and `tcsh`.
-- **Fine-grained configuration** via [PEP
- 593](https://peps.python.org/pep-0593/) annotations (`tyro.conf.*`).
+import tyro
-To get started, we recommend browsing the examples to the left.
+def main(
+ foo: int,
+ bar: str = "default",
+) -> None:
+ ... # Main body of a script.
-### Why `tyro`?
+if __name__ == "__main__":
+ # Generate a CLI and call `main` with its two arguments: `foo` and `bar`.
+ tyro.cli(main)
+```
-1. **Strong typing.**
+Or instantiate config objects defined using tools like `dataclasses`, `pydantic`, and `attrs`:
+
+```python
+"""A command-line interface defined using a class signature.
+
+Usage: python script_name.py --foo INT [--bar STR]
+"""
+
+from dataclasses import dataclass
+import tyro
+
+@dataclass
+class Config:
+ foo: int
+ bar: str = "default"
+
+if __name__ == "__main__":
+ # Generate a CLI and instantiate `Config` with its two arguments: `foo` and `bar`.
+ config = tyro.cli(Config)
+
+ # Rest of script.
+ assert isinstance(config, Config) # Should pass.
+```
+
+Other features include helptext generation, nested structures, subcommands, and
+shell completion.
+
+#### Why `tyro`?
+
+1. **Types.**
Unlike tools dependent on dictionaries, YAML, or dynamic namespaces,
arguments populated by `tyro` benefit from IDE and language server-supported
operations — think tab completion, rename, jump-to-def, docstrings on hover —
as well as static checking tools like `pyright` and `mypy`.
-2. **Minimal overhead.**
+2. **Define things once.**
Standard Python type annotations, docstrings, and default values are parsed
to automatically generate command-line interfaces with informative helptext.
@@ -55,11 +80,6 @@ To get started, we recommend browsing the examples to the left.
distribute definitions, defaults, and documentation of configurable fields
across modules or source files.
-4. **Tab completion.**
-
- By extending [shtab](https://github.com/iterative/shtab), `tyro`
- automatically generates tab completion scripts for bash, zsh, and tcsh.
-
.. toctree::
@@ -112,6 +132,16 @@ To get started, we recommend browsing the examples to the left.
examples/04_additional/*
+.. toctree::
+ :caption: Custom Constructors
+ :hidden:
+ :maxdepth: 1
+ :titlesonly:
+ :glob:
+
+ examples/05_custom_constructors/*
+
+
.. toctree::
:caption: Notes
:hidden:
diff --git a/examples/04_additional/06_generics_py312.py b/examples/04_additional/06_generics_py312.py
index 2fd4b458..72f69571 100644
--- a/examples/04_additional/06_generics_py312.py
+++ b/examples/04_additional/06_generics_py312.py
@@ -1,7 +1,7 @@
# mypy: ignore-errors
#
# PEP 695 isn't yet supported in mypy. (April 4, 2024)
-"""Generic Types (Python 3.12+ syntax)
+"""Generic Types (Python 3.12+)
Example of parsing for generic dataclasses using syntax introduced in Python
3.12 (`PEP 695 `_).
diff --git a/examples/04_additional/17_aliases.py b/examples/04_additional/11_aliases.py
similarity index 68%
rename from examples/04_additional/17_aliases.py
rename to examples/04_additional/11_aliases.py
index 16c5b3b3..5ee33de9 100644
--- a/examples/04_additional/17_aliases.py
+++ b/examples/04_additional/11_aliases.py
@@ -3,13 +3,13 @@
:func:`tyro.conf.arg()` can be used to attach aliases to arguments.
Usage:
-`python ./12_aliases.py --help`
-`python ./12_aliases.py commit --help`
-`python ./12_aliases.py commit --message hello --all`
-`python ./12_aliases.py commit -m hello -a`
-`python ./12_aliases.py checkout --help`
-`python ./12_aliases.py checkout --branch main`
-`python ./12_aliases.py checkout -b main`
+`python ./11_aliases.py --help`
+`python ./11_aliases.py commit --help`
+`python ./11_aliases.py commit --message hello --all`
+`python ./11_aliases.py commit -m hello -a`
+`python ./11_aliases.py checkout --help`
+`python ./11_aliases.py checkout --branch main`
+`python ./11_aliases.py checkout -b main`
"""
from typing_extensions import Annotated
diff --git a/examples/04_additional/16_type_statement.py b/examples/04_additional/12_type_statement.py
similarity index 94%
rename from examples/04_additional/16_type_statement.py
rename to examples/04_additional/12_type_statement.py
index 9aa08cf9..fa814029 100644
--- a/examples/04_additional/16_type_statement.py
+++ b/examples/04_additional/12_type_statement.py
@@ -6,7 +6,7 @@
In Python 3.12, the :code:`type` statement is introduced to create type aliases.
Usage:
-`python ./16_type_statement.py --help`
+`python ./12_type_statement.py --help`
"""
import dataclasses
diff --git a/examples/04_additional/11_custom_constructors.py b/examples/05_custom_constructors/01_primitive_annotation.py
similarity index 51%
rename from examples/04_additional/11_custom_constructors.py
rename to examples/05_custom_constructors/01_primitive_annotation.py
index 4d8cff56..505e8b9d 100644
--- a/examples/04_additional/11_custom_constructors.py
+++ b/examples/05_custom_constructors/01_primitive_annotation.py
@@ -1,15 +1,16 @@
-"""Custom Constructors
+"""Custom Primitive
-For additional flexibility, :module:`tyro.constructors` exposes
-tyro's API for defining behavior for different types. This is the same
-API that tyro relies on for the built-in types.
+For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for
+defining behavior for different types. There are two categories of types:
+primitive types can be instantiated from a single commandline argument, while
+struct types are broken down into multiple.
+
+In this example, we attach a custom constructor via a runtime annotation.
Usage:
-`python ./10_custom_constructors.py --help`
-`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}'`
-`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}"`
-`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'`
-`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}" --dict2.json "{\"hello\": \"world\"}"`
+`python ./01_primitive_annotation.py --help`
+`python ./01_primitive_annotation.py --dict1 '{"hello": "world"}'`
+`python ./01_primitive_annotation.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'`
"""
import json
diff --git a/examples/04_additional/12_custom_constructors_registry.py b/examples/05_custom_constructors/02_primitive_registry.py
similarity index 60%
rename from examples/04_additional/12_custom_constructors_registry.py
rename to examples/05_custom_constructors/02_primitive_registry.py
index 88eb1382..767765f0 100644
--- a/examples/04_additional/12_custom_constructors_registry.py
+++ b/examples/05_custom_constructors/02_primitive_registry.py
@@ -1,15 +1,17 @@
-"""Custom Constructors (Registry)
+"""Custom Primitive (Registry)
+For additional flexibility, :mod:`tyro.constructors` exposes tyro's API for
+defining behavior for different types. There are two categories of types:
+primitive types can be instantiated from a single commandline argument, while
+struct types are broken down into multiple.
-For additional flexibility, :module:`tyro.constructors` exposes
-tyro's API for defining behavior for different types. This is the same
-API that tyro relies on for the built-in types.
+
+In this example, we attach a custom constructor by defining a rule that applies
+to all types that match ``dict[str, Any]``.
Usage:
-`python ./10_custom_constructors.py --help`
-`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}'`
-`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}"`
-`python ./10_custom_constructors.py --dict1.json '{"hello": "world"}' --dict2.json '{"hello": "world"}'`
-`python ./10_custom_constructors.py --dict1.json "{\"hello\": \"world\"}" --dict2.json "{\"hello\": \"world\"}"`
+`python ./02_primitive_registry.py --help`
+`python ./02_primitive_registry.py --dict1 '{"hello": "world"}'`
+`python ./02_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}'`
"""
import json
diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py
index a9f6489d..d0b4aef9 100644
--- a/src/tyro/_arguments.py
+++ b/src/tyro/_arguments.py
@@ -174,9 +174,9 @@ def add_argument(
)
complete_as_path = (
# Catch types like Path, List[Path], Tuple[Path, ...] etc.
- "Path" in str(self.field.type_or_callable)
+ "Path" in str(self.field.type_stripped)
# For string types, we require more evidence.
- or ("str" in str(self.field.type_or_callable) and name_suggests_path)
+ or ("str" in str(self.field.type_stripped) and name_suggests_path)
)
if complete_as_path:
arg.complete = shtab.DIRECTORY if name_suggests_dir else shtab.FILE # type: ignore
@@ -233,7 +233,7 @@ def _rule_handle_boolean_flags(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
) -> None:
- if arg.field.type_or_callable is not bool:
+ if arg.field.type_stripped is not bool:
return
if (
@@ -281,15 +281,12 @@ def _rule_apply_primitive_specs(
return
try:
- if arg.field.primitive_spec is not None:
- spec = arg.field.primitive_spec
- else:
- spec = ConstructorRegistry._get_active_registry().get_primitive_spec(
- PrimitiveTypeInfo.make(
- cast(type, arg.field.type_or_callable),
- arg.field.markers,
- )
+ spec = ConstructorRegistry._get_active_registry().get_primitive_spec(
+ PrimitiveTypeInfo.make(
+ cast(type, arg.field.type),
+ arg.field.markers,
)
+ )
except UnsupportedTypeAnnotationError as e:
if arg.field.default in _singleton.MISSING_SINGLETONS:
field_name = _strings.make_field_name(
@@ -335,11 +332,11 @@ def _rule_apply_primitive_specs(
def append_instantiator(x: list[list[str]]) -> Any:
"""Handle UseAppendAction effects."""
# We'll assume that the type is annotated as Dict[...], Tuple[...], List[...], etc.
- container_type = get_origin(arg.field.type_or_callable)
+ container_type = get_origin(arg.field.type_stripped)
if container_type is None:
# Raw annotation, like `UseAppendAction[list]`. It's unlikely
# that a user would use this but we can handle it.
- container_type = arg.field.type_or_callable
+ container_type = arg.field.type_stripped
# Instantiate initial output.
out = (
@@ -407,7 +404,7 @@ def _rule_counters(
"""Handle counters, like -vvv for level-3 verbosity."""
if (
_markers.UseCounterAction in arg.field.markers
- and arg.field.type_or_callable is int
+ and arg.field.type_stripped is int
and not arg.field.is_positional()
):
lowered.metavar = None
@@ -459,7 +456,7 @@ def _rule_generate_helptext(
if arg.field.argconf.constructor_factory is not None:
default_label = (
str(default)
- if arg.field.type_or_callable is not json.loads
+ if arg.field.type_stripped is not json.loads
else json.dumps(arg.field.default)
)
elif type(default) in (tuple, list, set):
diff --git a/src/tyro/_calling.py b/src/tyro/_calling.py
index b714c20e..53c091a1 100644
--- a/src/tyro/_calling.py
+++ b/src/tyro/_calling.py
@@ -70,7 +70,7 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
)
# Resolve field type.
- field_type = field.type_or_callable
+ field_type = field.type_stripped
if prefixed_field_name in arg_from_prefixed_field_name:
assert prefixed_field_name not in consumed_keywords
diff --git a/src/tyro/_cli.py b/src/tyro/_cli.py
index d24ec32e..1b741f5a 100644
--- a/src/tyro/_cli.py
+++ b/src/tyro/_cli.py
@@ -408,7 +408,6 @@ def _cli_impl(
default_instance=default_instance_internal, # Overrides for default values.
intern_prefix="", # Used for recursive calls.
extern_prefix="", # Used for recursive calls.
- subcommand_prefix="", # Used for recursive calls.
)
# Generate parser!
diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py
index b1f50680..2cc973c9 100644
--- a/src/tyro/_fields.py
+++ b/src/tyro/_fields.py
@@ -12,19 +12,22 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import docstring_parser
-from typing_extensions import Annotated
+from typing_extensions import Annotated, get_args, get_origin
from . import _docstrings, _resolver, _strings, _unsafe_cache
from ._singleton import DEFAULT_SENTINEL_SINGLETONS, MISSING_SINGLETONS
from ._typing import TypeForm
from .conf import _confstruct, _markers
from .constructors._primitive_spec import (
- PrimitiveConstructorSpec,
PrimitiveTypeInfo,
UnsupportedTypeAnnotationError,
)
from .constructors._registry import ConstructorRegistry
-from .constructors._struct_spec import StructTypeInfo, UnsupportedStructTypeMessage
+from .constructors._struct_spec import (
+ StructFieldSpec,
+ StructTypeInfo,
+ UnsupportedStructTypeMessage,
+)
global_context_markers: List[Tuple[_markers.Marker, ...]] = []
@@ -33,7 +36,9 @@
class FieldDefinition:
intern_name: str
extern_name: str
- type_or_callable: Union[TypeForm[Any], Callable]
+ type: TypeForm[Any] | Callable
+ """Full type, including runtime annotations."""
+ type_stripped: TypeForm[Any] | Callable
default: Any
# We need to record whether defaults are from default instances to
# determine if they should override the default in
@@ -44,7 +49,6 @@ class FieldDefinition:
custom_constructor: bool
argconf: _confstruct._ArgConfig
- primitive_spec: PrimitiveConstructorSpec | None
# Override the name in our kwargs. Useful whenever the user-facing argument name
# doesn't match the keyword expected by our callable.
@@ -67,6 +71,17 @@ def marker_context(markers: Tuple[_markers.Marker, ...]):
yield
global_context_markers.pop()
+ @staticmethod
+ def from_field_spec(field_spec: StructFieldSpec) -> FieldDefinition:
+ return FieldDefinition.make(
+ name=field_spec.name,
+ typ=field_spec.type,
+ default=field_spec.default,
+ is_default_from_default_instance=field_spec.is_default_overridden,
+ helptext=field_spec.helptext,
+ call_argname_override=field_spec._call_argname,
+ )
+
@staticmethod
def make(
name: str,
@@ -114,18 +129,31 @@ def make(
if argconf.help is not None:
helptext = argconf.help
- _, primitive_specs = _resolver.unwrap_annotated(typ, PrimitiveConstructorSpec)
- if len(primitive_specs) > 0:
- primitive_spec = primitive_specs[0]
- else:
- primitive_spec = None
-
- typ, markers = _resolver.unwrap_annotated(typ, _markers._Marker)
+ type_stripped, markers = _resolver.unwrap_annotated(typ, _markers._Marker)
# Include markers set via context manager.
for context_markers in global_context_markers:
markers += context_markers
+ out = FieldDefinition(
+ intern_name=name,
+ extern_name=name if argconf.name is None else argconf.name,
+ type=typ,
+ type_stripped=type_stripped,
+ default=default,
+ is_default_from_default_instance=is_default_from_default_instance,
+ helptext=helptext,
+ markers=set(markers),
+ custom_constructor=argconf.constructor_factory is not None,
+ argconf=argconf,
+ call_argname=(
+ call_argname_override if call_argname_override is not None else name
+ ),
+ )
+
+ if argconf.constructor_factory is not None:
+ out = out.with_new_type_stripped(argconf.constructor_factory())
+
# Check that the default value matches the final resolved type.
# There's some similar Union-specific logic for this in narrow_union_type(). We
# may be able to consolidate this.
@@ -133,8 +161,8 @@ def make(
# Be relatively conservative: isinstance() can be checked on non-type
# types (like unions in Python >=3.10), but we'll only consider single types
# for now.
- type(typ) is type
- and not isinstance(default, typ) # type: ignore
+ type(out.type_stripped) is type
+ and not isinstance(default, out.type_stripped) # type: ignore
# If a custom constructor is set, static_type may not be
# matched to the annotated type.
and argconf.constructor_factory is None
@@ -150,29 +178,25 @@ def make(
f"but the default value {default} has type {type(default)}. "
f"We'll try to handle this gracefully, but it may cause unexpected behavior."
)
- typ = Union[typ, type(default)] # type: ignore
+ out = out.with_new_type_stripped(Union[out.type_stripped, type(default)]) # type: ignore
- out = FieldDefinition(
- intern_name=name,
- extern_name=name if argconf.name is None else argconf.name,
- type_or_callable=(
- typ
- if argconf.constructor_factory is None
- else argconf.constructor_factory()
- ),
- default=default,
- is_default_from_default_instance=is_default_from_default_instance,
- helptext=helptext,
- markers=set(markers),
- custom_constructor=argconf.constructor_factory is not None,
- argconf=argconf,
- call_argname=(
- call_argname_override if call_argname_override is not None else name
- ),
- primitive_spec=primitive_spec,
- )
return out
+ def with_new_type_stripped(
+ self, new_type_stripped: TypeForm[Any] | Callable
+ ) -> FieldDefinition:
+ if get_origin(self.type) is Annotated:
+ new_type = Annotated.__class_getitem__( # type: ignore
+ (new_type_stripped, *get_args(self.type)[1:])
+ )
+ else:
+ new_type = new_type_stripped
+ return dataclasses.replace(
+ self,
+ type=new_type,
+ type_stripped=new_type_stripped,
+ )
+
def is_positional(self) -> bool:
"""Returns True if the argument should be positional in the commandline."""
return (
@@ -240,17 +264,7 @@ def field_list_from_type_or_callable(
with FieldDefinition.marker_context(type_info.markers):
if spec is not None:
- return f, [
- FieldDefinition.make(
- f.name,
- f.type,
- f.default,
- f.is_default_overridden,
- f.helptext,
- call_argname_override=f._call_argname,
- )
- for f in spec.fields
- ]
+ return f, [FieldDefinition.from_field_spec(f) for f in spec.fields]
try:
registry.get_primitive_spec(PrimitiveTypeInfo.make(f, set()))
diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py
index 068b1ac8..06c7cb79 100644
--- a/src/tyro/_parsers.py
+++ b/src/tyro/_parsers.py
@@ -19,6 +19,7 @@
from typing_extensions import Annotated, get_args, get_origin
+from tyro.constructors._registry import ConstructorRegistry
from tyro.constructors._struct_spec import UnsupportedStructTypeMessage
from . import _argparse as argparse
@@ -34,7 +35,11 @@
)
from ._typing import TypeForm
from .conf import _confstruct, _markers
-from .constructors._primitive_spec import UnsupportedTypeAnnotationError
+from .constructors._primitive_spec import (
+ PrimitiveConstructorSpec,
+ PrimitiveTypeInfo,
+ UnsupportedTypeAnnotationError,
+)
T = TypeVar("T")
@@ -315,8 +320,20 @@ def handle_field(
]:
"""Determine what to do with a single field definition."""
+ registry = ConstructorRegistry._get_active_registry()
+
+ # Force primitive if (1) the field is annotated with a primitive constructor spec, or (2) if
+ force_primitive = len(
+ _resolver.unwrap_annotated(field.type, PrimitiveConstructorSpec)[1]
+ ) > 0 or (
+ len(registry._custom_primitive_rules) > 0
+ and registry.get_primitive_spec(
+ PrimitiveTypeInfo.make(field.type, field.markers), rule_mode="custom"
+ )
+ is not None
+ )
if (
- field.primitive_spec is None
+ not force_primitive
and _markers.Fixed not in field.markers
and _markers.Suppress not in field.markers
):
@@ -333,21 +350,22 @@ def handle_field(
and _markers.AvoidSubcommands in field.markers
):
# Don't make a subparser.
- field = dataclasses.replace(field, type_or_callable=type(field.default))
+ field = field.with_new_type_stripped(type(field.default))
else:
return subparsers_attempt
# (2) Handle nested callables.
- if _fields.is_struct_type(field.type_or_callable, field.default):
- field = dataclasses.replace(
- field,
- type_or_callable=_resolver.narrow_subtypes(
- field.type_or_callable,
+ if force_primitive == "struct" or _fields.is_struct_type(
+ field.type_stripped, field.default
+ ):
+ field = field.with_new_type_stripped(
+ _resolver.narrow_subtypes(
+ field.type_stripped,
field.default,
),
)
return ParserSpecification.from_callable_or_type(
- field.type_or_callable,
+ field.type_stripped,
markers=field.markers,
description=None,
parent_classes=parent_classes,
@@ -393,7 +411,7 @@ def from_field(
extern_prefix: str,
) -> Optional[SubparsersSpecification]:
# Union of classes should create subparsers.
- typ = _resolver.unwrap_annotated(field.type_or_callable)
+ typ = _resolver.unwrap_annotated(field.type_stripped)
if get_origin(typ) not in (Union, _resolver.UnionType):
return None
diff --git a/src/tyro/_subcommand_matching.py b/src/tyro/_subcommand_matching.py
index 82e8d20b..ed7bed7e 100644
--- a/src/tyro/_subcommand_matching.py
+++ b/src/tyro/_subcommand_matching.py
@@ -90,7 +90,7 @@ def make(
return _TypeTree(
typ_unwrap,
{
- field.intern_name: _TypeTree.make(field.type_or_callable, field.default)
+ field.intern_name: _TypeTree.make(field.type_stripped, field.default)
for field in field_list
},
)
diff --git a/src/tyro/conf/_confstruct.py b/src/tyro/conf/_confstruct.py
index 0104c4d0..758182ab 100644
--- a/src/tyro/conf/_confstruct.py
+++ b/src/tyro/conf/_confstruct.py
@@ -93,10 +93,10 @@ def subcommand(
is in a nested structure.
constructor: A constructor type or function. This will be used in
place of the argument's type for parsing arguments. For more
- configurability, see :module:`tyro.constructors`.
+ configurability, see :mod:`tyro.constructors`.
constructor_factory: A function that returns a constructor type. This
will be used in place of the argument's type for parsing arguments.
- For more configurability, see :module:`tyro.constructors`.
+ For more configurability, see :mod:`tyro.constructors`.
"""
assert not (
constructor is not None and constructor_factory is not None
@@ -191,10 +191,10 @@ def arg(
it is in a nested structure. Arguments are prefixed by default.
constructor: A constructor type or function. This will be used in
place of the argument's type for parsing arguments. For more
- configurability, see :module:`tyro.constructors`.
+ configurability, see :mod:`tyro.constructors`.
constructor_factory: A function that returns a constructor type. This
will be used in place of the argument's type for parsing arguments.
- For more configurability, see :module:`tyro.constructors`.
+ For more configurability, see :mod:`tyro.constructors`.
Returns:
Object to attach via `typing.Annotated[]`.
diff --git a/src/tyro/constructors/__init__.py b/src/tyro/constructors/__init__.py
index d93bfe0a..e54a8ee5 100644
--- a/src/tyro/constructors/__init__.py
+++ b/src/tyro/constructors/__init__.py
@@ -1,3 +1,11 @@
+"""The :mod:`tyro.constructors` submodule exposes tyro's API for defining
+behavior for different types.
+
+.. warning::
+ This submodule exposes advanced functionality, and is not needed for the
+ majority of users.
+"""
+
from ._primitive_spec import PrimitiveConstructorSpec as PrimitiveConstructorSpec
from ._primitive_spec import PrimitiveTypeInfo as PrimitiveTypeInfo
from ._primitive_spec import (
diff --git a/src/tyro/constructors/_primitive_spec.py b/src/tyro/constructors/_primitive_spec.py
index 8c113d6e..56a1d4d7 100644
--- a/src/tyro/constructors/_primitive_spec.py
+++ b/src/tyro/constructors/_primitive_spec.py
@@ -62,12 +62,19 @@ class PrimitiveTypeInfo:
"""The output of get_origin() on the static type."""
markers: set[_markers.Marker]
"""Set of tyro markers used to configure this field."""
+ _primitive_spec: PrimitiveConstructorSpec | None
+ """Primitive constructor spec that was scraped from runtime annotations."""
@staticmethod
def make(
raw_annotation: TypeForm | Callable,
parent_markers: set[_markers.Marker],
) -> PrimitiveTypeInfo:
+ _, primitive_specs = _resolver.unwrap_annotated(
+ raw_annotation, search_type=PrimitiveConstructorSpec
+ )
+ primitive_spec = primitive_specs[0] if len(primitive_specs) > 0 else None
+
typ, extra_markers = _resolver.unwrap_annotated(
raw_annotation, search_type=_markers._Marker
)
@@ -75,6 +82,7 @@ def make(
type=cast(TypeForm, typ),
type_origin=get_origin(typ),
markers=parent_markers | set(extra_markers),
+ _primitive_spec=primitive_spec,
)
@@ -84,11 +92,11 @@ class PrimitiveConstructorSpec(Generic[T]):
There are two ways to use this class:
- First, we can include it in a type signature via `typing.Annotated`.
+ First, we can include it in a type signature via :class:`typing.Annotated`.
This is the simplest for making local modifications to parsing behavior for
individual fields.
- Alternatively, it can be returned by a rule in a `PrimitiveConstructorRegistry`.
+ Alternatively, it can be returned by a rule in a :class:`ConstructorRegistry`.
"""
nargs: int | Literal["*"]
@@ -123,7 +131,7 @@ def apply_default_primitive_rules(registry: ConstructorRegistry) -> None:
from ._registry import ConstructorRegistry
- @registry.primitive_rule
+ @registry._default_primitive_rule
def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type is not Any:
return None
@@ -136,7 +144,7 @@ def any_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
# not to break it.
vanilla_types = (int, str, float, complex, bytes, bytearray, json.loads)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type not in vanilla_types:
return None
@@ -159,7 +167,7 @@ def basics_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None
if "torch" in sys.modules.keys():
import torch
- @registry.primitive_rule
+ @registry._default_primitive_rule
def torch_device_rule(
type_info: PrimitiveTypeInfo,
) -> PrimitiveConstructorSpec | None:
@@ -173,7 +181,7 @@ def torch_device_rule(
str_from_instance=lambda instance: [str(instance)],
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def bool_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type is not bool:
return None
@@ -186,7 +194,7 @@ def bool_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
str_from_instance=lambda instance: ["True" if instance else "False"],
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def nonetype_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type is not type(None):
return None
@@ -199,7 +207,7 @@ def nonetype_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | No
str_from_instance=lambda instance: ["None"],
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def path_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if not (
type_info.type in (os.PathLike, pathlib.Path)
@@ -217,7 +225,7 @@ def path_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
str_from_instance=lambda instance: [str(instance)],
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def enum_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if not (
inspect.isclass(type_info.type) and issubclass(type_info.type, enum.Enum)
@@ -250,7 +258,7 @@ def enum_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
choices=choices,
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def datetime_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type not in (datetime.datetime, datetime.date, datetime.time):
return None
@@ -271,7 +279,7 @@ def datetime_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | No
str_from_instance=lambda instance: [instance.isoformat()],
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def vague_container_rule(
type_info: PrimitiveTypeInfo,
) -> PrimitiveConstructorSpec | None:
@@ -306,7 +314,7 @@ def vague_container_rule(
)
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def sequence_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type_origin not in (
collections.abc.Sequence,
@@ -391,7 +399,7 @@ def str_from_instance(instance: Sequence) -> list[str]:
choices=inner_spec.choices,
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def tuple_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type_origin is not tuple:
return None
@@ -456,7 +464,7 @@ def str_from_instance(instance: tuple) -> list[str]:
and all(spec.is_instance(member) for member, spec in zip(x, inner_specs)),
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def dict_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type_origin not in (dict, collections.abc.Mapping):
return None
@@ -549,7 +557,7 @@ def str_from_instance(instance: dict) -> list[str]:
str_from_instance=str_from_instance,
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def literal_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type_origin not in (Literal, LiteralAlternate):
return None
@@ -575,7 +583,7 @@ def literal_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | Non
choices=str_choices,
)
- @registry.primitive_rule
+ @registry._default_primitive_rule
def union_rule(type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | None:
if type_info.type_origin not in (Union, _resolver.UnionType):
return None
diff --git a/src/tyro/constructors/_registry.py b/src/tyro/constructors/_registry.py
index 654379fc..781d406a 100644
--- a/src/tyro/constructors/_registry.py
+++ b/src/tyro/constructors/_registry.py
@@ -2,6 +2,8 @@
from typing import Any, Callable, ClassVar, Union
+from typing_extensions import Literal
+
from ._primitive_spec import (
PrimitiveConstructorSpec,
PrimitiveTypeInfo,
@@ -58,7 +60,8 @@ class ConstructorRegistry:
_old_registry: ConstructorRegistry | None = None
def __init__(self) -> None:
- self._primitive_rules: list[PrimitiveSpecRule] = []
+ self._default_primitive_rules: list[PrimitiveSpecRule] = []
+ self._custom_primitive_rules: list[PrimitiveSpecRule] = []
self._struct_rules: list[StructSpecRule] = []
# Apply the default primitive-handling rules.
@@ -67,9 +70,17 @@ def __init__(self) -> None:
def primitive_rule(self, rule: PrimitiveSpecRule) -> PrimitiveSpecRule:
"""Define a rule for constructing a primitive type from a string. The
- most recently added rule will be applied first."""
+ most recently added rule will be applied first.
+
+ Custom primitive rules will take precedence over both default primitive
+ rules and struct rules
+ """
- self._primitive_rules.append(rule)
+ self._custom_primitive_rules.append(rule)
+ return rule
+
+ def _default_primitive_rule(self, rule: PrimitiveSpecRule) -> PrimitiveSpecRule:
+ self._default_primitive_rules.append(rule)
return rule
def struct_rule(self, rule: StructSpecRule) -> StructSpecRule:
@@ -80,13 +91,25 @@ def struct_rule(self, rule: StructSpecRule) -> StructSpecRule:
return rule
def get_primitive_spec(
- self, type_info: PrimitiveTypeInfo
+ self,
+ type_info: PrimitiveTypeInfo,
+ rule_mode: Literal["default", "custom", "all"] = "all",
) -> PrimitiveConstructorSpec:
"""Get a constructor specification for a given type."""
- for spec_factory in self._primitive_rules[::-1]:
- maybe_spec = spec_factory(type_info)
- if maybe_spec is not None:
- return maybe_spec
+
+ if type_info._primitive_spec is not None:
+ return type_info._primitive_spec
+
+ if rule_mode in ("custom", "all"):
+ for spec_factory in self._custom_primitive_rules[::-1]:
+ maybe_spec = spec_factory(type_info)
+ if maybe_spec is not None:
+ return maybe_spec
+ if rule_mode in ("default", "all"):
+ for spec_factory in self._default_primitive_rules[::-1]:
+ maybe_spec = spec_factory(type_info)
+ if maybe_spec is not None:
+ return maybe_spec
raise UnsupportedTypeAnnotationError(
f"Unsupported type annotation: {type_info.type}"
@@ -97,6 +120,7 @@ def get_struct_spec(
) -> StructConstructorSpec | None:
"""Get a constructor specification for a given type. Returns `None` if
unsuccessful."""
+
for spec_factory in self._struct_rules[::-1]:
maybe_spec = spec_factory(type_info)
if maybe_spec is not None:
diff --git a/src/tyro/constructors/_struct_spec.py b/src/tyro/constructors/_struct_spec.py
index 06f0776f..8f54c21c 100644
--- a/src/tyro/constructors/_struct_spec.py
+++ b/src/tyro/constructors/_struct_spec.py
@@ -43,18 +43,40 @@ class StructFieldSpec:
"""Behavior specification for a single field in our callable."""
name: str
+ """The name of the field. This will be used as a keyword argument for the
+ struct's associated `instantiate(**kwargs)` function."""
type: TypeForm
+ """The type of the field. Can be either a primitive or a nested struct type."""
default: Any
- is_default_overridden: bool
- helptext: str | None
+ """The default value of the field."""
+ is_default_overridden: bool = False
+ """Whether the default value was overridden by the default instance. Should
+ be set to False if the default value was assigned by the field itself."""
+ helptext: str | None = None
+ """Helpjext for the field."""
# TODO: it's theoretically possible to override the argname with `None`.
_call_argname: Any = None
+ """Private: the name of the argument to pass to the callable. This is used
+ for dictionary types."""
@dataclasses.dataclass(frozen=True)
class StructConstructorSpec:
+ """Specification for a struct type, which is broken down into multiple
+ fields.
+
+ Each struct type is instantiated by calling an `instantiate(**kwargs)`
+ function with keyword a set of keyword arguments.
+
+ Unlike `PrimitiveConstructorSpec`, there is only one way to use this class.
+ It must be returned by a rule in `ConstructorRegistry`.
+ """
+
instantiate: Callable[..., Any]
+ """Function to call to instantiate the struct."""
fields: tuple[StructFieldSpec, ...]
+ """Fields used to construct the callable. Each field is used as a keyword
+ argument for the `instantiate(**kwargs)` function."""
@dataclasses.dataclass(frozen=True)
@@ -81,6 +103,7 @@ def make(f: TypeForm | Callable, default: Any) -> StructTypeInfo:
f = typevar_context.origin_type
f = _resolver.narrow_subtypes(f, default)
f = _resolver.narrow_collection_types(f, default)
+
return StructTypeInfo(
cast(TypeForm, f), parent_markers, default, typevar_context
)
diff --git a/src/tyro/extras/__init__.py b/src/tyro/extras/__init__.py
index 5341bc3d..e547a39e 100644
--- a/src/tyro/extras/__init__.py
+++ b/src/tyro/extras/__init__.py
@@ -1,6 +1,7 @@
"""The :mod:`tyro.extras` submodule contains helpers that complement :func:`tyro.cli()`.
-Compared to the core interface, APIs here are more likely to be changed or deprecated.
+.. warning::
+ Compared to the core interface, APIs here are more likely to be changed or deprecated.
"""
from .._argparse_formatter import set_accent_color as set_accent_color
diff --git a/tests/test_custom_primitive.py b/tests/test_custom_constructors.py
similarity index 66%
rename from tests/test_custom_primitive.py
rename to tests/test_custom_constructors.py
index a2758999..4c2a1cbe 100644
--- a/tests/test_custom_primitive.py
+++ b/tests/test_custom_constructors.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import json
-from typing import Any, Dict
+from typing import Any, Dict, Union
from typing_extensions import Annotated, get_args
@@ -36,6 +36,13 @@ def main(x: Dict[str, Any]) -> Dict[str, Any]:
with primitive_registry:
assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1}
+ def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]:
+ return x
+
+ with primitive_registry:
+ assert tyro.cli(main_with_default, args=[]) == {"hello": 5}
+ assert tyro.cli(main_with_default, args=["--x", '{"a": 1}']) == {"a": 1}
+
def test_custom_primitive_annotated():
"""Test that we can use typing.Annotated to specify custom constructors."""
@@ -44,3 +51,15 @@ def main(x: Annotated[Dict[str, Any], json_constructor_spec]) -> Dict[str, Any]:
return x
assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1}
+
+
+def test_custom_primitive_union():
+ """Test that we can use typing.Annotated to specify custom constructors."""
+
+ def main(
+ x: Union[int, Annotated[Dict[str, Any], json_constructor_spec]],
+ ) -> Union[int, Dict[str, Any]]:
+ return x
+
+ assert tyro.cli(main, args=["--x", "3"]) == 3
+ assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1}
diff --git a/tests/test_errors.py b/tests/test_errors.py
index 1e4c9fd5..09ea7b53 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -572,7 +572,7 @@ def main2() -> None:
def test_wrong_annotation() -> None:
@dataclasses.dataclass
class Args:
- x: dict = None # type: ignore
+ x: Union[dict, int] = None # type: ignore
with pytest.warns(UserWarning):
assert tyro.cli(Args, args=[]).x is None
diff --git a/tests/test_py311_generated/ok.py b/tests/test_py311_generated/ok.py
deleted file mode 100644
index 73e5b33b..00000000
--- a/tests/test_py311_generated/ok.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from dataclasses import dataclass
-from typing import Literal
-
-import tyro
-
-
-@dataclass(frozen=True)
-class Container[T]:
- a: T
-
-
-tyro.cli(Container[Container[bool] | Container[Literal["1", "2"]]])
diff --git a/tests/test_py311_generated/test_custom_primitive_generated.py b/tests/test_py311_generated/test_custom_constructors_generated.py
similarity index 68%
rename from tests/test_py311_generated/test_custom_primitive_generated.py
rename to tests/test_py311_generated/test_custom_constructors_generated.py
index 4244759a..6e0cd35e 100644
--- a/tests/test_py311_generated/test_custom_primitive_generated.py
+++ b/tests/test_py311_generated/test_custom_constructors_generated.py
@@ -34,6 +34,13 @@ def main(x: Dict[str, Any]) -> Dict[str, Any]:
with primitive_registry:
assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1}
+ def main_with_default(x: Dict[str, Any] = {"hello": 5}) -> Dict[str, Any]:
+ return x
+
+ with primitive_registry:
+ assert tyro.cli(main_with_default, args=[]) == {"hello": 5}
+ assert tyro.cli(main_with_default, args=["--x", '{"a": 1}']) == {"a": 1}
+
def test_custom_primitive_annotated():
"""Test that we can use typing.Annotated to specify custom constructors."""
@@ -42,3 +49,15 @@ def main(x: Annotated[Dict[str, Any], json_constructor_spec]) -> Dict[str, Any]:
return x
assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1}
+
+
+def test_custom_primitive_union():
+ """Test that we can use typing.Annotated to specify custom constructors."""
+
+ def main(
+ x: int | Annotated[Dict[str, Any], json_constructor_spec],
+ ) -> int | Dict[str, Any]:
+ return x
+
+ assert tyro.cli(main, args=["--x", "3"]) == 3
+ assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1}
diff --git a/tests/test_py311_generated/test_errors_generated.py b/tests/test_py311_generated/test_errors_generated.py
index 81fdd150..7cbe5a54 100644
--- a/tests/test_py311_generated/test_errors_generated.py
+++ b/tests/test_py311_generated/test_errors_generated.py
@@ -583,7 +583,7 @@ def main2() -> None:
def test_wrong_annotation() -> None:
@dataclasses.dataclass
class Args:
- x: dict = None # type: ignore
+ x: dict | int = None # type: ignore
with pytest.warns(UserWarning):
assert tyro.cli(Args, args=[]).x is None