From 19a4af40ddac36530ba9865f7e42d11fcd60fc7e Mon Sep 17 00:00:00 2001 From: Zech Zimmerman Date: Wed, 1 Nov 2023 16:46:11 -0400 Subject: [PATCH] Add lazy dependency annotation resolution, resolves #8 --- bevy/injectors/classes.py | 31 ++++++++++++++++++++++++++----- tests/test_bevy.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/bevy/injectors/classes.py b/bevy/injectors/classes.py index 974aa13..d5528e1 100644 --- a/bevy/injectors/classes.py +++ b/bevy/injectors/classes.py @@ -9,6 +9,28 @@ _K = TypeVar("_K") +class LazyAnnotationResolver: + def __init__(self, annotation, cls): + self.annotation = annotation + self.cls = cls + self._resolved_type = Null() if isinstance(annotation, str) else Value(annotation) + + @property + def resolved_type(self) -> Option[Type]: + match self._resolved_type: + case Value() as resolved_type: + return resolved_type + + case Null(): + resolved_type = self._resolve() + self._resolved_type = Value(resolved_type) + return self._resolved_type + + def _resolve(self): + ns = _get_class_namespace(self.cls) + return ns[self.annotation] + + class Dependency(Generic[_K]): """This class can be used to indicate fields that need to be injected. It also acts as a descriptor that can discover the key that needs to be injected and handle injecting the corresponding instance for that key when @@ -19,7 +41,7 @@ class Dependency(Generic[_K]): """ def __init__(self): - self._key: Option[_K] = Null() + self._key_resolver: LazyAnnotationResolver | None = LazyAnnotationResolver(None, None) def __get__(self, instance: object, owner: Type): if instance is None: @@ -31,12 +53,11 @@ def __get__(self, instance: object, owner: Type): return self._inject_dependency() def __set_name__(self, owner: Type, name: str): - ns = _get_class_namespace(owner) - annotations = get_annotations(owner, globals=ns, eval_str=True) - self._key = Value(annotations[name]) + annotations = get_annotations(owner) + self._key_resolver = LazyAnnotationResolver(annotations[name], owner) def _inject_dependency(self): - match self._key: + match self._key_resolver.resolved_type: case Value(key): repo = get_repository() return repo.get(key) diff --git a/tests/test_bevy.py b/tests/test_bevy.py index 0cb6557..9c9f745 100644 --- a/tests/test_bevy.py +++ b/tests/test_bevy.py @@ -1,3 +1,5 @@ +import sys +from dataclasses import dataclass from typing import Annotated from pytest import fixture @@ -192,3 +194,29 @@ def test_context_forking_inheritance(repository): fork = repository.fork_context() assert fork.get(int) is repository.get(int) + + +def test_dataclass_dependency_injection(): + class Dep: + ... + + @dataclass + class Test: + dep: Dep = dependency() + + repo = Repository.factory() + Repository.set_repository(repo) + + dep = Dep() + repo.set(Dep, dep) + + inst = Test() + assert inst.dep is dep + + +def test_forward_references(): + class Test: + dep: "Dep" = dependency() + + class Dep: + ...