Skip to content

Commit

Permalink
Implement counter actions (#130)
Browse files Browse the repository at this point in the history
* Implement counter actions

* ruff

* Add counter test, sync docs

* Special-case for Python 3.7

* Formatting
  • Loading branch information
brentyi authored Apr 4, 2024
1 parent 42d6944 commit b43c668
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 2 deletions.
69 changes: 69 additions & 0 deletions docs/source/examples/04_additional/12_counters.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
.. Comment: this file is automatically generated by `update_example_docs.py`.
It should not be modified manually.
Counters
==========================================


Repeatable 'counter' arguments can be specified via :data:`tyro.conf.UseCounterAction`.



.. code-block:: python
:linenos:
from typing_extensions import Annotated
import tyro
from tyro.conf import UseCounterAction
def main(
verbosity: UseCounterAction[int],
aliased_verbosity: Annotated[UseCounterAction[int], tyro.conf.arg(aliases=["-v"])],
) -> None:
"""Example showing how to use counter actions.
Args:
verbosity: Verbosity level.
aliased_verbosity: Same as above, but can also be specified with -v, -vv, -vvv, etc.
"""
print("Verbosity level:", verbosity)
print("Verbosity level (aliased):", aliased_verbosity)
if __name__ == "__main__":
tyro.cli(main)
------------

.. raw:: html

<kbd>python 04_additional/12_counters.py --help</kbd>

.. program-output:: python ../../examples/04_additional/12_counters.py --help

------------

.. raw:: html

<kbd>python 04_additional/12_counters.py --verbosity</kbd>

.. program-output:: python ../../examples/04_additional/12_counters.py --verbosity

------------

.. raw:: html

<kbd>python 04_additional/12_counters.py --verbosity --verbosity</kbd>

.. program-output:: python ../../examples/04_additional/12_counters.py --verbosity --verbosity

------------

.. raw:: html

<kbd>python 04_additional/12_counters.py -vvv</kbd>

.. program-output:: python ../../examples/04_additional/12_counters.py -vvv
33 changes: 33 additions & 0 deletions examples/04_additional/12_counters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Counters
Repeatable 'counter' arguments can be specified via :data:`tyro.conf.UseCounterAction`.
Usage:
`python ./12_counters.py --help`
`python ./12_counters.py --verbosity`
`python ./12_counters.py --verbosity --verbosity`
`python ./12_counters.py -vvv`
"""

from typing_extensions import Annotated

import tyro
from tyro.conf import UseCounterAction


def main(
verbosity: UseCounterAction[int],
aliased_verbosity: Annotated[UseCounterAction[int], tyro.conf.arg(aliases=["-v"])],
) -> None:
"""Example showing how to use counter actions.
Args:
verbosity: Verbosity level.
aliased_verbosity: Same as above, but can also be specified with -v, -vv, -vvv, etc.
"""
print("Verbosity level:", verbosity)
print("Verbosity level (aliased):", aliased_verbosity)


if __name__ == "__main__":
tyro.cli(main)
30 changes: 28 additions & 2 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ def add_argument(
# directly be used. This helps reduce the likelihood of issues with converting
# the field default to a string format, then back to the desired type.
action = kwargs.get("action", None)
if action != "append":
if action not in {"append", "count"}:
kwargs["default"] = _fields.MISSING_NONPROP
elif action == BooleanOptionalAction:
elif action in {BooleanOptionalAction, "count"}:
pass
else:
kwargs["default"] = []
Expand Down Expand Up @@ -193,6 +193,7 @@ def lowered(self) -> LoweredArgumentDefinition:
_rule_handle_boolean_flags,
_rule_recursive_instantiator_from_type,
_rule_convert_defaults_to_strings,
_rule_counters,
_rule_generate_helptext,
_rule_set_name_or_flag_and_dest,
_rule_positional_special_handling,
Expand Down Expand Up @@ -405,6 +406,28 @@ def _rich_tag_if_enabled(x: str, tag: str) -> str:
return x if not USE_RICH else f"[{tag}]{x}[/{tag}]"


def _rule_counters(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
) -> LoweredArgumentDefinition:
"""Handle counters, like -vvv for level-3 verbosity."""
if (
_markers.UseCounterAction in arg.field.markers
and arg.field.type_or_callable is int
and not arg.field.is_positional()
):
return dataclasses.replace(
lowered,
metavar=None,
nargs=None,
action="count",
default=0,
required=False,
instantiator=lambda x: x, # argparse will directly give us an int!
)
return lowered


def _rule_generate_helptext(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
Expand Down Expand Up @@ -465,6 +488,9 @@ def _rule_generate_helptext(
# Intentionally not quoted via shlex, since this can't actually be passed
# in via the commandline.
default_text = f"(fixed to: {default_label})"
elif lowered.action == "count":
# Repeatable argument.
default_text = "(repeatable)"
elif lowered.action == "append" and (
default in _fields.MISSING_SINGLETONS or len(cast(tuple, default)) == 0
):
Expand Down
1 change: 1 addition & 0 deletions src/tyro/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
from ._markers import Suppress as Suppress
from ._markers import SuppressFixed as SuppressFixed
from ._markers import UseAppendAction as UseAppendAction
from ._markers import UseCounterAction as UseCounterAction
from ._markers import configure as configure
4 changes: 4 additions & 0 deletions src/tyro/conf/_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@
`Tuple[T, ...]`, etc), including dictionaries without default values.
"""

UseCounterAction = Annotated[T, None]
"""Use "counter" actions for integer arguments. Example usage: `verbose: UseCounterAction[int]`."""


CallableType = TypeVar("CallableType", bound=Callable)

# Dynamically generate marker singletons.
Expand Down
24 changes: 24 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io
import json as json_
import shlex
import sys
from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union

import pytest
Expand Down Expand Up @@ -1360,3 +1361,26 @@ class DatasetConfig:
helptext = target.getvalue()
assert "OptimizerConfig options" in helptext
assert "DatasetConfig options" in helptext


def test_counter_action() -> None:
def main(
verbosity: tyro.conf.UseCounterAction[int],
aliased_verbosity: Annotated[
tyro.conf.UseCounterAction[int], tyro.conf.arg(aliases=["-v"])
],
) -> Tuple[int, int]:
"""Example showing how to use counter actions.
Args:
verbosity: Verbosity level.
aliased_verbosity: Same as above, but can also be specified with -v, -vv, -vvv, etc.
"""
return verbosity, aliased_verbosity

assert tyro.cli(main, args=[]) == (0, 0)
assert tyro.cli(main, args="--verbosity --verbosity".split(" ")) == (2, 0)
assert tyro.cli(main, args="--verbosity --verbosity -v".split(" ")) == (2, 1)
if sys.version_info >= (3, 8):
# Doesn't work in Python 3.7 because of argparse limitations.
assert tyro.cli(main, args="--verbosity --verbosity -vv".split(" ")) == (2, 2)
assert tyro.cli(main, args="--verbosity --verbosity -vvv".split(" ")) == (2, 3)

0 comments on commit b43c668

Please sign in to comment.