Skip to content

Commit

Permalink
cache nodes, limit recursion depth
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Oct 25, 2024
1 parent 48152f1 commit 2a7c085
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 35 deletions.
5 changes: 4 additions & 1 deletion src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
):
self.registry = registry
self.child_registries = child_registries
self._context = {DependencyKey(type(self), DEFAULT_COMPONENT): self}
self._context = {CONTAINER_KEY: self}
if context:
for key, value in context.items():
if not isinstance(key, DependencyKey):
Expand Down Expand Up @@ -252,3 +252,6 @@ def make_container(
close_parent=True,
)
return container


CONTAINER_KEY = DependencyKey(Container, DEFAULT_COMPONENT)
42 changes: 27 additions & 15 deletions src/dishka/graph_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .text_rendering import get_name


MAX_DEPTH = 5 # max code depth, otherwise we get too big file


class Node(FactoryData):
__slots__ = (
"dependencies",
Expand Down Expand Up @@ -106,7 +109,6 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str:
FactoryType.VALUE: VALUE,
FactoryType.CONTEXT: CONTEXT,
FactoryType.ALIAS: ALIAS,
None: GO_PARENT,
}
FUNC_TEMPLATE = """
{async_}def {func_name}(getter, exits, context):
Expand All @@ -116,19 +118,19 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str:
"""

IF_TEMPLATE = """
if {var} := cache_getter({key}, None):
pass # cache found
else:
if ({var} := cache_getter({key}, ...)) is ...:
{deps}
{body}
{cache}
"""
CACHE = "context[{key}] = {var}"


builtins = {getattr(__builtins__, name): name for name in dir(__builtins__)}
def make_name(obj: Any, ns: dict[Any, str]) -> str:
if obj in builtins:
return builtins[obj]
if isinstance(obj, DependencyKey):
key = get_name(obj.type_hint, include_module=False) + obj.component
key = get_name(obj.type_hint, include_module=False) +"_"+ obj.component
else:
key = get_name(obj, include_module=False)
key = re.sub(r"\W", "_", key)
Expand All @@ -153,24 +155,33 @@ def make_var(node: Node, ns: dict[Any, str]):


def make_if(
node: Node, node_var: str, ns: dict[Any, str], is_async: bool,
node: Node, node_var: str, ns: dict[Any, str],
is_async: bool,
depth: int,
) -> str:
node_key = ns[node.provides]
node_source = ns[node.source]
if depth > MAX_DEPTH or node.type is None:
if is_async:
return GO_PARENT.format(
var=node_var,
key=node_real_key,
)
else:
return GO_PARENT.format(
var=node_var,
key=node_key,
)

deps = "".join(
make_if(dep, make_var(dep, ns), ns, is_async)
make_if(dep, make_var(dep, ns), ns, is_async, depth+1)
for dep in node.dependencies
)
deps += "".join(
make_if(dep, make_var(dep, ns), ns, is_async)
make_if(dep, make_var(dep, ns), ns, is_async, depth+1)
for dep in node.kw_dependencies.values()
)
deps = indent(deps, " ")
if node.cache:
cache = CACHE.format(var=node_var, key=node_key)
else:
cache = "# no cache"

args = [make_var(dep, ns) for dep in node.dependencies]
kwargs = {
Expand All @@ -192,6 +203,7 @@ def make_if(
)

if node.cache:
cache = CACHE.format(var=node_var, key=node_key)
body_str = indent(body_str, " ")
return IF_TEMPLATE.format(
var=node_var,
Expand All @@ -201,14 +213,14 @@ def make_if(
cache=cache,
)
else:
return "\n".join([deps, body_str, cache])
return "\n".join([deps, body_str])


def make_func(
node: Node, ns: dict[Any, str], func_name: str, is_async: bool,
) -> str:
node_var = make_var(node, ns)
body = make_if(node, node_var, ns, is_async)
body = make_if(node, node_var, ns, is_async, 0)
body = indent(body, " ")
return FUNC_TEMPLATE.format(
async_="async " if is_async else "",
Expand Down
43 changes: 24 additions & 19 deletions src/dishka/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import time
from collections.abc import Callable
from linecache import cache
from typing import Any, TypeVar, get_args, get_origin

from pydantic.v1 import compiled

from ._adaptix.type_tools.fundamentals import get_type_vars
from .container_objects import CompiledFactory
Expand All @@ -12,7 +13,6 @@
from .entities.factory_type import FactoryType
from .entities.key import DependencyKey
from .entities.scope import BaseScope
from .factory_compiler import compile_factory
from .graph_compiler import Node, compile_graph


Expand Down Expand Up @@ -153,10 +153,12 @@ def _specialize_generic(
)


def make_node(registry: Registry, key: DependencyKey) -> Node:
def make_node(registry: Registry, key: DependencyKey, cache: dict| None = None) -> Node:
if cache is None:
cache = {}
factory = registry.get_factory(key)
if not factory:
return Node(
node = Node(
provides=key,
scope=registry.scope,
type_=None,
Expand All @@ -165,18 +167,21 @@ def make_node(registry: Registry, key: DependencyKey) -> Node:
cache=False,
source=None,
)
return Node(
provides=factory.provides,
scope=factory.scope,
source=factory.source,
type_=factory.type,
cache=factory.cache,
dependencies=[
make_node(registry, dep)
for dep in factory.dependencies
],
kw_dependencies={
key: make_node(registry, dep)
for key, dep in factory.kw_dependencies.items()
},
)
else:
node = Node(
provides=factory.provides,
scope=factory.scope,
source=factory.source,
type_=factory.type,
cache=factory.cache,
dependencies=[
make_node(registry, dep, cache)
for dep in factory.dependencies
],
kw_dependencies={
key: make_node(registry, dep, cache)
for key, dep in factory.kw_dependencies.items()
},
)
cache[key] = node
return node

0 comments on commit 2a7c085

Please sign in to comment.