Skip to content

Commit

Permalink
Less busy main API
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 30, 2024
1 parent eaf12af commit c5ca29b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 62 deletions.
4 changes: 2 additions & 2 deletions src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from . import _fields, _strings
from .conf import _markers
from .constructors._primitive_spec import (
PrimitiveConstructorRegistry,
PrimitiveTypeInfo,
UnsupportedTypeAnnotationError,
get_current_primitive_registry,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -284,7 +284,7 @@ def _rule_apply_primitive_specs(
if arg.field.primitive_spec is not None:
spec = arg.field.primitive_spec
else:
registry = get_current_primitive_registry()
registry = PrimitiveConstructorRegistry._get_active_registry()
spec = registry.get_spec(
PrimitiveTypeInfo.make(
cast(type, arg.field.type_or_callable),
Expand Down
37 changes: 12 additions & 25 deletions src/tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
conf,
)
from ._typing import TypeForm
from .constructors._primitive_spec import (
PrimitiveConstructorRegistry,
use_primitive_registry,
)

OutT = TypeVar("OutT")

Expand All @@ -49,7 +45,6 @@ def cli(
use_underscores: bool = False,
console_outputs: bool = True,
config: None | Sequence[conf._markers.Marker] = None,
primitive_constructor_registry: PrimitiveConstructorRegistry | None = None,
) -> OutT: ...


Expand All @@ -65,7 +60,6 @@ def cli(
use_underscores: bool = False,
console_outputs: bool = True,
config: None | Sequence[conf._markers.Marker] = None,
primitive_constructor_registry: PrimitiveConstructorRegistry | None = None,
) -> tuple[OutT, list[str]]: ...


Expand All @@ -84,7 +78,6 @@ def cli(
use_underscores: bool = False,
console_outputs: bool = True,
config: None | Sequence[conf._markers.Marker] = None,
primitive_constructor_registry: PrimitiveConstructorRegistry | None = None,
) -> OutT: ...


Expand All @@ -103,7 +96,6 @@ def cli(
use_underscores: bool = False,
console_outputs: bool = True,
config: None | Sequence[conf._markers.Marker] = None,
primitive_constructor_registry: PrimitiveConstructorRegistry | None = None,
) -> tuple[OutT, list[str]]: ...


Expand All @@ -118,7 +110,6 @@ def cli(
use_underscores: bool = False,
console_outputs: bool = True,
config: None | Sequence[conf._markers.Marker] = None,
primitive_constructor_registry: PrimitiveConstructorRegistry | None = None,
**deprecated_kwargs,
) -> OutT | tuple[OutT, list[str]]:
"""Call or instantiate `f`, with inputs populated from an automatically generated
Expand Down Expand Up @@ -189,9 +180,6 @@ def cli(
alternative to using them locally in annotations
(`tyro.conf.FlagConversionOff[bool]`), we can also pass in a sequence of
them here to apply globally.
primitive_constructor_registry: A custom registry for primitive constructors.
Not typically needed, but can be used to extend the set of primitive types
that are supported by `tyro`.
Returns:
The output of `f(...)` or an instance `f`. If `f` is a class, the two are
Expand All @@ -207,19 +195,18 @@ def cli(
f = conf.configure(*config)(f)

with _strings.delimeter_context("_" if use_underscores else "-"):
with use_primitive_registry(primitive_constructor_registry):
output = _cli_impl(
f,
prog=prog,
description=description,
args=args,
default=default,
return_parser=False,
return_unknown_args=return_unknown_args,
use_underscores=use_underscores,
console_outputs=console_outputs,
**deprecated_kwargs,
)
output = _cli_impl(
f,
prog=prog,
description=description,
args=args,
default=default,
return_parser=False,
return_unknown_args=return_unknown_args,
use_underscores=use_underscores,
console_outputs=console_outputs,
**deprecated_kwargs,
)

# Prevent unnecessary memory usage.
_unsafe_cache.clear_cache()
Expand Down
51 changes: 23 additions & 28 deletions src/tyro/constructors/_primitive_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import collections
import collections.abc
import contextlib
import dataclasses
import datetime
import enum
Expand All @@ -14,8 +13,8 @@
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generator,
Generic,
List,
Sequence,
Expand Down Expand Up @@ -48,31 +47,6 @@ class UnsupportedTypeAnnotationError(Exception):

T = TypeVar("T")

current_registry: PrimitiveConstructorRegistry | None = None


def get_current_primitive_registry() -> PrimitiveConstructorRegistry:
"""For internal use: get the current primitive registry."""
global current_registry
if current_registry is None:
current_registry = PrimitiveConstructorRegistry()
return current_registry


@contextlib.contextmanager
def use_primitive_registry(
registry: PrimitiveConstructorRegistry | None,
) -> Generator[None, None, None]:
"""For internal use: temporarily use a different primitive registry."""
global current_registry
if registry is not None:
old_registry = current_registry
current_registry = registry
yield
current_registry = old_registry
else:
yield


@dataclasses.dataclass(frozen=True)
class PrimitiveTypeInfo:
Expand Down Expand Up @@ -140,13 +114,17 @@ class PrimitiveConstructorSpec(Generic[T]):

SpecFactory = Callable[[PrimitiveTypeInfo], PrimitiveConstructorSpec]

current_registry: PrimitiveConstructorRegistry | None = None


class PrimitiveConstructorRegistry:
"""Registry for rules that define how primitive types that can be
constructed from a single command-line argument."""

_active_registry: ClassVar[PrimitiveConstructorRegistry | None] = None
_old_registry: PrimitiveConstructorRegistry | None = None

def __init__(self) -> None:
self._old_registry: PrimitiveConstructorRegistry | None = None
self._rules: list[
tuple[
# Matching function.
Expand Down Expand Up @@ -181,6 +159,23 @@ def get_spec(self, type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec:
f"Unsupported type annotation: {type_info.type}"
)

@classmethod
def _get_active_registry(cls) -> PrimitiveConstructorRegistry:
"""Get the active registry. Can be changed by using a
PrimitiveConstructorRegistry object as a context."""
if cls._active_registry is None:
cls._active_registry = PrimitiveConstructorRegistry()
return cls._active_registry

def __enter__(self) -> None:
cls = self.__class__
self._old_registry = cls._active_registry
cls._active_registry = self

def __exit__(self, *args: Any) -> None:
cls = self.__class__
cls._active_registry = self._old_registry


def _apply_default_rules(registry: PrimitiveConstructorRegistry) -> None:
"""Apply default rules to the registry."""
Expand Down
13 changes: 6 additions & 7 deletions tests/test_custom_primitive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, Dict, get_args
from typing import Any, Dict

from typing_extensions import Annotated
from typing_extensions import Annotated, get_args

import tyro

Expand All @@ -16,9 +16,9 @@

def test_custom_primitive_registry():
"""Test that we can use a custom primitive registry to parse a custom type."""
registry = tyro.constructors.PrimitiveConstructorRegistry()
primitive_registry = tyro.constructors.PrimitiveConstructorRegistry()

@registry.define_rule(
@primitive_registry.define_rule(
matcher_fn=lambda type_info: type_info.type_origin is dict
and get_args(type_info.type) == (str, Any)
)
Expand All @@ -31,9 +31,8 @@ def json_dict_spec(
def main(x: Dict[str, Any]) -> Dict[str, Any]:
return x

assert tyro.cli(
main, args=["--x", '{"a": 1}'], primitive_constructor_registry=registry
) == {"a": 1}
with primitive_registry:
assert tyro.cli(main, args=["--x", '{"a": 1}']) == {"a": 1}


def test_custom_primitive_annotated():
Expand Down

0 comments on commit c5ca29b

Please sign in to comment.