Skip to content

Commit

Permalink
Support typing.Self (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi authored Jan 21, 2024
1 parent 7049ef0 commit d660da2
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def from_callable_or_type(
# superclass.
if f in parent_classes and f is not dict:
raise _instantiators.UnsupportedTypeAnnotationError(
f"Found a cyclic dataclass dependency with type {f}."
f"Found a cyclic dependency with type {f}."
)

# TODO: we are abusing the (minor) distinctions between types, classes, and
Expand Down
14 changes: 12 additions & 2 deletions src/tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import collections.abc
import copy
import dataclasses
import inspect
import sys
import types
import warnings
Expand All @@ -20,7 +21,7 @@
cast,
)

from typing_extensions import Annotated, get_args, get_origin, get_type_hints
from typing_extensions import Annotated, Self, get_args, get_origin, get_type_hints

from . import _fields, _unsafe_cache
from ._typing import TypeForm
Expand Down Expand Up @@ -61,8 +62,17 @@ def resolve_generic_types(

# We'll ignore NewType when getting the origin + args for generics.
origin_cls = get_origin(unwrap_newtype(cls)[0])
type_from_typevar: Dict[TypeVar, TypeForm[Any]] = {}

# Support typing.Self.
# We'll do this by pretending that `Self` is a TypeVar...
if hasattr(cls, "__self__"):
self_type = getattr(cls, "__self__")
if inspect.isclass(self_type):
type_from_typevar[cast(TypeVar, Self)] = self_type
else:
type_from_typevar[cast(TypeVar, Self)] = self_type.__class__

type_from_typevar = {}
if (
# Apply some heuristics for generic types. Should revisit this.
origin_cls is not None
Expand Down
32 changes: 32 additions & 0 deletions tests/test_py311_generated/test_base_configs_nested_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,35 @@ def main(cfg: BaseConfig) -> BaseConfig:
),
DataConfig(2),
)


def test_pernicious_override():
"""From: https://github.com/nerfstudio-project/nerfstudio/issues/2789
Situation where we:
- have a default value in the config class
- override that default value with a subcommand annotation
- override it again with a default instance
"""
assert (
tyro.cli(
BaseConfig,
default=BaseConfig(
"test",
"test",
ExperimentConfig(
dataset="mnist",
optimizer=AdamOptimizer(),
batch_size=2048,
num_layers=4,
units=64,
train_steps=30_000,
seed=0,
activation=nn.ReLU,
),
DataConfig(0),
),
args="small small-data".split(" "),
).data_config.test
== 0
)
53 changes: 52 additions & 1 deletion tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,57 @@ class Parent:
) == Parent(Nested1(Nested2(B(7))))


def test_subparser_in_nested_with_metadata_suppressed() -> None:
@dataclasses.dataclass(frozen=True)
class A:
a: tyro.conf.Suppress[int]

@dataclasses.dataclass
class B:
b: int
a: A = A(5)

@dataclasses.dataclass
class Nested2:
subcommand: Annotated[
A, tyro.conf.subcommand("command-a", default=A(7))
] | Annotated[B, tyro.conf.subcommand("command-b", default=B(9))]

@dataclasses.dataclass
class Nested1:
nested2: Nested2

@dataclasses.dataclass
class Parent:
nested1: Nested1

assert tyro.cli(
Parent,
args="nested1.nested2.subcommand:command-a".split(" "),
) == Parent(Nested1(Nested2(A(7))))
assert tyro.cli(
Parent,
args=(
"nested1.nested2.subcommand:command-a --nested1.nested2.subcommand.a 3".split(
" "
)
),
) == Parent(Nested1(Nested2(A(3))))

assert tyro.cli(
Parent,
args="nested1.nested2.subcommand:command-b".split(" "),
) == Parent(Nested1(Nested2(B(9))))
assert tyro.cli(
Parent,
args=(
"nested1.nested2.subcommand:command-b --nested1.nested2.subcommand.b 7".split(
" "
)
),
) == Parent(Nested1(Nested2(B(7))))


def test_subparser_in_nested_with_metadata_generic() -> None:
@dataclasses.dataclass(frozen=True)
class A:
Expand Down Expand Up @@ -1264,7 +1315,7 @@ def instantiate_dataclasses(
classes: Tuple[Type[T], ...], args: List[str]
) -> Tuple[T, ...]:
return tyro.cli(
tyro.conf.OmitArgPrefixes[
tyro.conf.OmitArgPrefixes[ # type: ignore
# Convert (type1, type2) into Tuple[type1, type2]
Tuple.__getitem__( # type: ignore
tuple(Annotated[c, tyro.conf.arg(name=c.__name__)] for c in classes)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_py311_generated/test_helptext_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,31 @@ class Helptext:
assert "Documentation 3 (default: 3)" in helptext


def test_helptext_sphinx_autodoc_style() -> None:
@dataclasses.dataclass
class Helptext:
"""This docstring should be printed as a description."""

x: int #: Documentation 1

#:Documentation 2
y: Annotated[int, "ignored"]
z: int = 3

helptext = get_helptext(Helptext)
assert cast(str, helptext) in helptext
assert "x INT" in helptext
assert "y INT" in helptext
assert "z INT" in helptext
assert "Documentation 1 (required)" in helptext
assert ": Documentation 1" not in helptext
assert "Documentation 2 (required)" in helptext
assert ":Documentation 2" not in helptext

# :Documentation 2 should not be applied to `z`.
assert helptext.count("Documentation 2") == 1


def test_helptext_from_class_docstring() -> None:
@dataclasses.dataclass
class Helptext2:
Expand Down
92 changes: 92 additions & 0 deletions tests/test_py311_generated/test_self_type_generated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

from typing import Self

import pytest

import tyro


class TestClass:
def __init__(self, a: int, b: int) -> None:
self.a = a
self.b = b

def method1(self, x: Self) -> None:
self.effect = x

@classmethod
def method2(cls, x: Self) -> TestClass:
return x

# Self is not valid in static methods.
# https://peps.python.org/pep-0673/#valid-locations-for-self
#
# @staticmethod
# def method3(x: Self) -> TestClass:
# return x


class TestSubclass(TestClass):
...


def test_method() -> None:
x = TestClass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method1, args=[])

assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None
assert x.effect.a == 3 and x.effect.b == 3
assert isinstance(x, TestClass)


def test_classmethod() -> None:
x = TestClass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method2, args=[])
with pytest.raises(SystemExit):
tyro.cli(TestClass.method2, args=[])

y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)

y = tyro.cli(TestClass.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)


def test_subclass_method() -> None:
x = TestSubclass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method1, args=[])

assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None
assert x.effect.a == 3 and x.effect.b == 3
assert isinstance(x, TestSubclass)

y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)


def test_subclass_classmethod() -> None:
x = TestSubclass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method2, args=[])
with pytest.raises(SystemExit):
tyro.cli(TestSubclass.method2, args=[])

y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)

y = tyro.cli(TestSubclass.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)
91 changes: 91 additions & 0 deletions tests/test_self_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

import pytest
from typing_extensions import Self

import tyro


class TestClass:
def __init__(self, a: int, b: int) -> None:
self.a = a
self.b = b

def method1(self, x: Self) -> None:
self.effect = x

@classmethod
def method2(cls, x: Self) -> TestClass:
return x

# Self is not valid in static methods.
# https://peps.python.org/pep-0673/#valid-locations-for-self
#
# @staticmethod
# def method3(x: Self) -> TestClass:
# return x


class TestSubclass(TestClass):
...


def test_method() -> None:
x = TestClass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method1, args=[])

assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None
assert x.effect.a == 3 and x.effect.b == 3
assert isinstance(x, TestClass)


def test_classmethod() -> None:
x = TestClass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method2, args=[])
with pytest.raises(SystemExit):
tyro.cli(TestClass.method2, args=[])

y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)

y = tyro.cli(TestClass.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)


def test_subclass_method() -> None:
x = TestSubclass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method1, args=[])

assert tyro.cli(x.method1, args="--x.a 3 --x.b 3".split(" ")) is None
assert x.effect.a == 3 and x.effect.b == 3
assert isinstance(x, TestSubclass)

y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)


def test_subclass_classmethod() -> None:
x = TestSubclass(0, 0)
with pytest.raises(SystemExit):
tyro.cli(x.method2, args=[])
with pytest.raises(SystemExit):
tyro.cli(TestSubclass.method2, args=[])

y = tyro.cli(x.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)

y = tyro.cli(TestSubclass.method2, args="--x.a 3 --x.b 3".split(" "))
assert y.a == 3
assert y.b == 3
assert isinstance(y, TestClass)

0 comments on commit d660da2

Please sign in to comment.