From c5ca29b1497a7a94040d38f7fc95a0d0a9bf7998 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 29 Oct 2024 20:02:29 -0700 Subject: [PATCH] Less busy main API --- src/tyro/_arguments.py | 4 +- src/tyro/_cli.py | 37 ++++++----------- src/tyro/constructors/_primitive_spec.py | 51 +++++++++++------------- tests/test_custom_primitive.py | 13 +++--- 4 files changed, 43 insertions(+), 62 deletions(-) diff --git a/src/tyro/_arguments.py b/src/tyro/_arguments.py index 83a9cd7a..44e36dcd 100644 --- a/src/tyro/_arguments.py +++ b/src/tyro/_arguments.py @@ -27,9 +27,9 @@ from . import _fields, _strings from .conf import _markers from .constructors._primitive_spec import ( + PrimitiveConstructorRegistry, PrimitiveTypeInfo, UnsupportedTypeAnnotationError, - get_current_primitive_registry, ) if TYPE_CHECKING: @@ -284,7 +284,7 @@ def _rule_apply_primitive_specs( if arg.field.primitive_spec is not None: spec = arg.field.primitive_spec else: - registry = get_current_primitive_registry() + registry = PrimitiveConstructorRegistry._get_active_registry() spec = registry.get_spec( PrimitiveTypeInfo.make( cast(type, arg.field.type_or_callable), diff --git a/src/tyro/_cli.py b/src/tyro/_cli.py index ab2a5a0b..6703d65b 100644 --- a/src/tyro/_cli.py +++ b/src/tyro/_cli.py @@ -23,10 +23,6 @@ conf, ) from ._typing import TypeForm -from .constructors._primitive_spec import ( - PrimitiveConstructorRegistry, - use_primitive_registry, -) OutT = TypeVar("OutT") @@ -49,7 +45,6 @@ def cli( use_underscores: bool = False, console_outputs: bool = True, config: None | Sequence[conf._markers.Marker] = None, - primitive_constructor_registry: PrimitiveConstructorRegistry | None = None, ) -> OutT: ... @@ -65,7 +60,6 @@ def cli( use_underscores: bool = False, console_outputs: bool = True, config: None | Sequence[conf._markers.Marker] = None, - primitive_constructor_registry: PrimitiveConstructorRegistry | None = None, ) -> tuple[OutT, list[str]]: ... @@ -84,7 +78,6 @@ def cli( use_underscores: bool = False, console_outputs: bool = True, config: None | Sequence[conf._markers.Marker] = None, - primitive_constructor_registry: PrimitiveConstructorRegistry | None = None, ) -> OutT: ... @@ -103,7 +96,6 @@ def cli( use_underscores: bool = False, console_outputs: bool = True, config: None | Sequence[conf._markers.Marker] = None, - primitive_constructor_registry: PrimitiveConstructorRegistry | None = None, ) -> tuple[OutT, list[str]]: ... @@ -118,7 +110,6 @@ def cli( use_underscores: bool = False, console_outputs: bool = True, config: None | Sequence[conf._markers.Marker] = None, - primitive_constructor_registry: PrimitiveConstructorRegistry | None = None, **deprecated_kwargs, ) -> OutT | tuple[OutT, list[str]]: """Call or instantiate `f`, with inputs populated from an automatically generated @@ -189,9 +180,6 @@ def cli( alternative to using them locally in annotations (`tyro.conf.FlagConversionOff[bool]`), we can also pass in a sequence of them here to apply globally. - primitive_constructor_registry: A custom registry for primitive constructors. - Not typically needed, but can be used to extend the set of primitive types - that are supported by `tyro`. Returns: The output of `f(...)` or an instance `f`. If `f` is a class, the two are @@ -207,19 +195,18 @@ def cli( f = conf.configure(*config)(f) with _strings.delimeter_context("_" if use_underscores else "-"): - with use_primitive_registry(primitive_constructor_registry): - output = _cli_impl( - f, - prog=prog, - description=description, - args=args, - default=default, - return_parser=False, - return_unknown_args=return_unknown_args, - use_underscores=use_underscores, - console_outputs=console_outputs, - **deprecated_kwargs, - ) + output = _cli_impl( + f, + prog=prog, + description=description, + args=args, + default=default, + return_parser=False, + return_unknown_args=return_unknown_args, + use_underscores=use_underscores, + console_outputs=console_outputs, + **deprecated_kwargs, + ) # Prevent unnecessary memory usage. _unsafe_cache.clear_cache() diff --git a/src/tyro/constructors/_primitive_spec.py b/src/tyro/constructors/_primitive_spec.py index 452e6cbe..12d2ef37 100644 --- a/src/tyro/constructors/_primitive_spec.py +++ b/src/tyro/constructors/_primitive_spec.py @@ -2,7 +2,6 @@ import collections import collections.abc -import contextlib import dataclasses import datetime import enum @@ -14,8 +13,8 @@ from typing import ( Any, Callable, + ClassVar, Dict, - Generator, Generic, List, Sequence, @@ -48,31 +47,6 @@ class UnsupportedTypeAnnotationError(Exception): T = TypeVar("T") -current_registry: PrimitiveConstructorRegistry | None = None - - -def get_current_primitive_registry() -> PrimitiveConstructorRegistry: - """For internal use: get the current primitive registry.""" - global current_registry - if current_registry is None: - current_registry = PrimitiveConstructorRegistry() - return current_registry - - -@contextlib.contextmanager -def use_primitive_registry( - registry: PrimitiveConstructorRegistry | None, -) -> Generator[None, None, None]: - """For internal use: temporarily use a different primitive registry.""" - global current_registry - if registry is not None: - old_registry = current_registry - current_registry = registry - yield - current_registry = old_registry - else: - yield - @dataclasses.dataclass(frozen=True) class PrimitiveTypeInfo: @@ -140,13 +114,17 @@ class PrimitiveConstructorSpec(Generic[T]): SpecFactory = Callable[[PrimitiveTypeInfo], PrimitiveConstructorSpec] +current_registry: PrimitiveConstructorRegistry | None = None + class PrimitiveConstructorRegistry: """Registry for rules that define how primitive types that can be constructed from a single command-line argument.""" + _active_registry: ClassVar[PrimitiveConstructorRegistry | None] = None + _old_registry: PrimitiveConstructorRegistry | None = None + def __init__(self) -> None: - self._old_registry: PrimitiveConstructorRegistry | None = None self._rules: list[ tuple[ # Matching function. @@ -181,6 +159,23 @@ def get_spec(self, type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec: f"Unsupported type annotation: {type_info.type}" ) + @classmethod + def _get_active_registry(cls) -> PrimitiveConstructorRegistry: + """Get the active registry. Can be changed by using a + PrimitiveConstructorRegistry object as a context.""" + if cls._active_registry is None: + cls._active_registry = PrimitiveConstructorRegistry() + return cls._active_registry + + def __enter__(self) -> None: + cls = self.__class__ + self._old_registry = cls._active_registry + cls._active_registry = self + + def __exit__(self, *args: Any) -> None: + cls = self.__class__ + cls._active_registry = self._old_registry + def _apply_default_rules(registry: PrimitiveConstructorRegistry) -> None: """Apply default rules to the registry.""" diff --git a/tests/test_custom_primitive.py b/tests/test_custom_primitive.py index e990f2af..c9c9d1f6 100644 --- a/tests/test_custom_primitive.py +++ b/tests/test_custom_primitive.py @@ -1,7 +1,7 @@ import json -from typing import Any, Dict, get_args +from typing import Any, Dict -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args import tyro @@ -16,9 +16,9 @@ def test_custom_primitive_registry(): """Test that we can use a custom primitive registry to parse a custom type.""" - registry = tyro.constructors.PrimitiveConstructorRegistry() + primitive_registry = tyro.constructors.PrimitiveConstructorRegistry() - @registry.define_rule( + @primitive_registry.define_rule( matcher_fn=lambda type_info: type_info.type_origin is dict and get_args(type_info.type) == (str, Any) ) @@ -31,9 +31,8 @@ def json_dict_spec( def main(x: Dict[str, Any]) -> Dict[str, Any]: return x - assert tyro.cli( - main, args=["--x", '{"a": 1}'], primitive_constructor_registry=registry - ) == {"a": 1} + with primitive_registry: + assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1} def test_custom_primitive_annotated():