Skip to content

Commit

Permalink
Add lazy dependency annotation resolution, resolves #8
Browse files Browse the repository at this point in the history
  • Loading branch information
ZechCodes committed Nov 1, 2023
1 parent be95663 commit 19a4af4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
31 changes: 26 additions & 5 deletions bevy/injectors/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,28 @@
_K = TypeVar("_K")


class LazyAnnotationResolver:
def __init__(self, annotation, cls):
self.annotation = annotation
self.cls = cls
self._resolved_type = Null() if isinstance(annotation, str) else Value(annotation)

@property
def resolved_type(self) -> Option[Type]:
match self._resolved_type:
case Value() as resolved_type:
return resolved_type

case Null():
resolved_type = self._resolve()
self._resolved_type = Value(resolved_type)
return self._resolved_type

def _resolve(self):
ns = _get_class_namespace(self.cls)
return ns[self.annotation]


class Dependency(Generic[_K]):
"""This class can be used to indicate fields that need to be injected. It also acts as a descriptor that can
discover the key that needs to be injected and handle injecting the corresponding instance for that key when
Expand All @@ -19,7 +41,7 @@ class Dependency(Generic[_K]):
"""

def __init__(self):
self._key: Option[_K] = Null()
self._key_resolver: LazyAnnotationResolver | None = LazyAnnotationResolver(None, None)

def __get__(self, instance: object, owner: Type):
if instance is None:
Expand All @@ -31,12 +53,11 @@ def __get__(self, instance: object, owner: Type):
return self._inject_dependency()

def __set_name__(self, owner: Type, name: str):
ns = _get_class_namespace(owner)
annotations = get_annotations(owner, globals=ns, eval_str=True)
self._key = Value(annotations[name])
annotations = get_annotations(owner)
self._key_resolver = LazyAnnotationResolver(annotations[name], owner)

def _inject_dependency(self):
match self._key:
match self._key_resolver.resolved_type:
case Value(key):
repo = get_repository()
return repo.get(key)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_bevy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys
from dataclasses import dataclass
from typing import Annotated

from pytest import fixture
Expand Down Expand Up @@ -192,3 +194,29 @@ def test_context_forking_inheritance(repository):

fork = repository.fork_context()
assert fork.get(int) is repository.get(int)


def test_dataclass_dependency_injection():
class Dep:
...

@dataclass
class Test:
dep: Dep = dependency()

repo = Repository.factory()
Repository.set_repository(repo)

dep = Dep()
repo.set(Dep, dep)

inst = Test()
assert inst.dep is dep


def test_forward_references():
class Test:
dep: "Dep" = dependency()

class Dep:
...

0 comments on commit 19a4af4

Please sign in to comment.