From b43c668c5b811baf73c03a53c41140bf18dde97f Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 4 Apr 2024 01:08:39 -0700 Subject: [PATCH] Implement counter actions (#130) * Implement counter actions * ruff * Add counter test, sync docs * Special-case for Python 3.7 * Formatting --- .../examples/04_additional/12_counters.rst | 69 +++++++++++++++++++ examples/04_additional/12_counters.py | 33 +++++++++ src/tyro/_arguments.py | 30 +++++++- src/tyro/conf/__init__.py | 1 + src/tyro/conf/_markers.py | 4 ++ tests/test_conf.py | 24 +++++++ 6 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 docs/source/examples/04_additional/12_counters.rst create mode 100644 examples/04_additional/12_counters.py diff --git a/docs/source/examples/04_additional/12_counters.rst b/docs/source/examples/04_additional/12_counters.rst new file mode 100644 index 00000000..d0bedf57 --- /dev/null +++ b/docs/source/examples/04_additional/12_counters.rst @@ -0,0 +1,69 @@ +.. Comment: this file is automatically generated by `update_example_docs.py`. + It should not be modified manually. + +Counters +========================================== + + +Repeatable 'counter' arguments can be specified via :data:`tyro.conf.UseCounterAction`. + + + +.. code-block:: python + :linenos: + + + from typing_extensions import Annotated + + import tyro + from tyro.conf import UseCounterAction + + + def main( + verbosity: UseCounterAction[int], + aliased_verbosity: Annotated[UseCounterAction[int], tyro.conf.arg(aliases=["-v"])], + ) -> None: + """Example showing how to use counter actions. + + Args: + verbosity: Verbosity level. + aliased_verbosity: Same as above, but can also be specified with -v, -vv, -vvv, etc. + """ + print("Verbosity level:", verbosity) + print("Verbosity level (aliased):", aliased_verbosity) + + + if __name__ == "__main__": + tyro.cli(main) + +------------ + +.. raw:: html + + python 04_additional/12_counters.py --help + +.. program-output:: python ../../examples/04_additional/12_counters.py --help + +------------ + +.. raw:: html + + python 04_additional/12_counters.py --verbosity + +.. program-output:: python ../../examples/04_additional/12_counters.py --verbosity + +------------ + +.. raw:: html + + python 04_additional/12_counters.py --verbosity --verbosity + +.. program-output:: python ../../examples/04_additional/12_counters.py --verbosity --verbosity + +------------ + +.. raw:: html + + python 04_additional/12_counters.py -vvv + +.. program-output:: python ../../examples/04_additional/12_counters.py -vvv diff --git a/examples/04_additional/12_counters.py b/examples/04_additional/12_counters.py new file mode 100644 index 00000000..2a6dd2bb --- /dev/null +++ b/examples/04_additional/12_counters.py @@ -0,0 +1,33 @@ +"""Counters + +Repeatable 'counter' arguments can be specified via :data:`tyro.conf.UseCounterAction`. + +Usage: +`python ./12_counters.py --help` +`python ./12_counters.py --verbosity` +`python ./12_counters.py --verbosity --verbosity` +`python ./12_counters.py -vvv` +""" + +from typing_extensions import Annotated + +import tyro +from tyro.conf import UseCounterAction + + +def main( + verbosity: UseCounterAction[int], + aliased_verbosity: Annotated[UseCounterAction[int], tyro.conf.arg(aliases=["-v"])], +) -> None: + """Example showing how to use counter actions. + + Args: + verbosity: Verbosity level. + aliased_verbosity: Same as above, but can also be specified with -v, -vv, -vvv, etc. + """ + print("Verbosity level:", verbosity) + print("Verbosity level (aliased):", aliased_verbosity) + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py index de7c94b5..5c90a0ac 100644 --- a/src/tyro/_arguments.py +++ b/src/tyro/_arguments.py @@ -132,9 +132,9 @@ def add_argument( # directly be used. This helps reduce the likelihood of issues with converting # the field default to a string format, then back to the desired type. action = kwargs.get("action", None) - if action != "append": + if action not in {"append", "count"}: kwargs["default"] = _fields.MISSING_NONPROP - elif action == BooleanOptionalAction: + elif action in {BooleanOptionalAction, "count"}: pass else: kwargs["default"] = [] @@ -193,6 +193,7 @@ def lowered(self) -> LoweredArgumentDefinition: _rule_handle_boolean_flags, _rule_recursive_instantiator_from_type, _rule_convert_defaults_to_strings, + _rule_counters, _rule_generate_helptext, _rule_set_name_or_flag_and_dest, _rule_positional_special_handling, @@ -405,6 +406,28 @@ def _rich_tag_if_enabled(x: str, tag: str) -> str: return x if not USE_RICH else f"[{tag}]{x}[/{tag}]" +def _rule_counters( + arg: ArgumentDefinition, + lowered: LoweredArgumentDefinition, +) -> LoweredArgumentDefinition: + """Handle counters, like -vvv for level-3 verbosity.""" + if ( + _markers.UseCounterAction in arg.field.markers + and arg.field.type_or_callable is int + and not arg.field.is_positional() + ): + return dataclasses.replace( + lowered, + metavar=None, + nargs=None, + action="count", + default=0, + required=False, + instantiator=lambda x: x, # argparse will directly give us an int! + ) + return lowered + + def _rule_generate_helptext( arg: ArgumentDefinition, lowered: LoweredArgumentDefinition, @@ -465,6 +488,9 @@ def _rule_generate_helptext( # Intentionally not quoted via shlex, since this can't actually be passed # in via the commandline. default_text = f"(fixed to: {default_label})" + elif lowered.action == "count": + # Repeatable argument. + default_text = "(repeatable)" elif lowered.action == "append" and ( default in _fields.MISSING_SINGLETONS or len(cast(tuple, default)) == 0 ): diff --git a/src/tyro/conf/__init__.py b/src/tyro/conf/__init__.py index ef969852..f5840777 100644 --- a/src/tyro/conf/__init__.py +++ b/src/tyro/conf/__init__.py @@ -21,4 +21,5 @@ from ._markers import Suppress as Suppress from ._markers import SuppressFixed as SuppressFixed from ._markers import UseAppendAction as UseAppendAction +from ._markers import UseCounterAction as UseCounterAction from ._markers import configure as configure diff --git a/src/tyro/conf/_markers.py b/src/tyro/conf/_markers.py index 6d04a38e..babfc810 100644 --- a/src/tyro/conf/_markers.py +++ b/src/tyro/conf/_markers.py @@ -123,6 +123,10 @@ `Tuple[T, ...]`, etc), including dictionaries without default values. """ +UseCounterAction = Annotated[T, None] +"""Use "counter" actions for integer arguments. Example usage: `verbose: UseCounterAction[int]`.""" + + CallableType = TypeVar("CallableType", bound=Callable) # Dynamically generate marker singletons. diff --git a/tests/test_conf.py b/tests/test_conf.py index 5321b055..764b01d6 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -4,6 +4,7 @@ import io import json as json_ import shlex +import sys from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union import pytest @@ -1360,3 +1361,26 @@ class DatasetConfig: helptext = target.getvalue() assert "OptimizerConfig options" in helptext assert "DatasetConfig options" in helptext + + +def test_counter_action() -> None: + def main( + verbosity: tyro.conf.UseCounterAction[int], + aliased_verbosity: Annotated[ + tyro.conf.UseCounterAction[int], tyro.conf.arg(aliases=["-v"]) + ], + ) -> Tuple[int, int]: + """Example showing how to use counter actions. + Args: + verbosity: Verbosity level. + aliased_verbosity: Same as above, but can also be specified with -v, -vv, -vvv, etc. + """ + return verbosity, aliased_verbosity + + assert tyro.cli(main, args=[]) == (0, 0) + assert tyro.cli(main, args="--verbosity --verbosity".split(" ")) == (2, 0) + assert tyro.cli(main, args="--verbosity --verbosity -v".split(" ")) == (2, 1) + if sys.version_info >= (3, 8): + # Doesn't work in Python 3.7 because of argparse limitations. + assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2) + assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3)