Skip to content

Commit

Permalink
feat: different dispatch instances on same function
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Jul 28, 2024
1 parent 70212b7 commit 0793cf8
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 4 deletions.
191 changes: 187 additions & 4 deletions plum/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
from dataclasses import dataclass, field
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field, replace
from functools import partial
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, overload
from itertools import chain
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, final, overload

from .function import Function
from .overload import get_overloads
Expand All @@ -19,7 +21,39 @@


@dataclass(frozen=True, **_dataclass_kw_args)
class Dispatcher:
class AbstractDispatcher(metaclass=ABCMeta):
"""An abstract dispatcher."""

@overload
def __call__(self, method: T, precedence: int = ...) -> T: ...

@overload
def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ...

@abstractmethod
def __call__(
self, method: Optional[T] = None, precedence: int = 0
) -> Union[T, Callable[[T], T]]: ...

@abstractmethod
def abstract(self, method: Callable) -> Function:
"""Decorator for an abstract function definition. The abstract function
definition does not implement any methods."""

@abstractmethod
def multi(
self, *signatures: Union[Signature, Tuple[TypeHint, ...]]
) -> Callable[[Callable], Function]:
"""Decorator to register multiple signatures at once."""

@abstractmethod
def clear_cache(self) -> None:
"""Clear cache."""


@final
@dataclass(frozen=True, **_dataclass_kw_args)
class Dispatcher(AbstractDispatcher):
"""A namespace for functions.
Args:
Expand Down Expand Up @@ -140,11 +174,160 @@ def _add_method(
f.register(method, signature, precedence)
return f

def clear_cache(self):
def clear_cache(self) -> None:
"""Clear cache."""
for f in self.functions.values():
f.clear_cache()

def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle":
if not isinstance(other, AbstractDispatcher):
raise ValueError(f"Cannot combine `Dispatcher` with `{type(other)}`.")
return DispatcherBundle.from_dispatchers(self, other)


@final
@dataclass(frozen=True, **_dataclass_kw_args)
class DispatcherBundle(AbstractDispatcher):
"""A bundle of dispatchers.
Examples
--------
>>> from plum import Dispatcher, DispatcherBundle
>>> dispatch1 = Dispatcher()
>>> dispatch2 = Dispatcher()
>>> dispatchbundle = dispatch1 | dispatch2
Some Notes:
At least one dispatcher must be provided to `DispatcherBundle`.
>>> try:
... DispatcherBundle()
... except ValueError as e:
... print(e)
At least one dispatcher must be provided to DispatcherBundle.
A `DispatcherBundle` can be created from a sequence of dispatchers.
>>> dispatchbundle = DispatcherBundle.from_dispatchers(dispatch1, dispatch2)
A nested `DispatcherBundle` can be flattened.
>>> dispatch3 = Dispatcher()
>>> dispatchbundle = DispatcherBundle.from_dispatchers(dispatchbundle, dispatch3)
>>> dispatchbundle
DispatcherBundle(dispatchers=(
<DispatcherBundle(dispatchers=(
<Dispatcher functions={}, classes={} warn_redefinition=False>,
<Dispatcher functions={}, classes={} warn_redefinition=False>))>,
<Dispatcher functions={}, classes={} warn_redefinition=False>)
)
>>> dispatchbundle = dispatchbundle.flatten()
DispatcherBundle(dispatchers=(
<Dispatcher functions={}, classes={} warn_redefinition=False>,
<Dispatcher functions={}, classes={} warn_redefinition=False>,
<Dispatcher functions={}, classes={} warn_redefinition=False>)
)
DispatchBundles can be combined with `|`.
>>> dispatch4 = Dispatcher()
>>> dispatchbundle1 = dispatch1 | dispatch2
>>> dispatchbundle2 = dispatch3 | dispatch4
>>> dispatchbundle = dispatchbundle1 | dispatchbundle2
>>> dispatchbundle
DispatcherBundle(dispatchers=(
<DispatcherBundle(dispatchers=(
<Dispatcher functions={}, classes={} warn_redefinition=False>,
<Dispatcher functions={}, classes={} warn_redefinition=False>))>,
<DispatcherBundle(dispatchers=(
<Dispatcher functions={}, classes={} warn_redefinition=False>,
<Dispatcher functions={}, classes={} warn_redefinition=False>))>)
)
"""

dispatchers: Tuple[AbstractDispatcher, ...]

def __post_init__(self) -> None:
if not self.dispatchers:
msg = "At least one dispatcher must be provided to DispatcherBundle."
raise ValueError(msg)

@classmethod
def from_dispatchers(cls, *dispatchers: AbstractDispatcher) -> "DispatcherBundle":
"""Create a `DispatcherBundle` from a sequence of dispatchers.
This also flattens nested `DispatcherBundle`s.
"""

return cls(dispatchers).flatten()

def flatten(self) -> "DispatcherBundle":
"""Flatten the bundle."""

def as_seq(x: AbstractDispatcher) -> Tuple[AbstractDispatcher, ...]:
return x.dispatchers if isinstance(x, DispatcherBundle) else (x,)

return replace(
self, dispatchers=tuple(chain.from_iterable(map(as_seq, self.dispatchers)))
)

@overload
def __call__(self, method: T, precedence: int = ...) -> T: ...

@overload
def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ...

def __call__(
self, method: Optional[T] = None, precedence: int = 0
) -> Union[T, Callable[[T], T]]:
for dispatcher in self.dispatchers:
f = dispatcher(method, precedence=precedence)
return f

def abstract(self, method: Callable) -> Function:
"""Decorator for an abstract function definition. The abstract function
definition does not implement any methods."""
for dispatcher in self.dispatchers:
f = dispatcher.abstract(method)
return f

def multi(
self, *signatures: Union[Signature, Tuple[TypeHint, ...]]
) -> Callable[[Callable], Function]:
"""Decorator to register multiple signatures at once.
Args:
*signatures (tuple or :class:`.signature.Signature`): Signatures to
register.
Returns:
function: Decorator.
"""

def decorator(method: Callable) -> Function:
for dispatcher in self.dispatchers:
f = dispatcher.multi(*signatures)(method)
return f

return decorator

def clear_cache(self) -> None:
"""Clear cache."""
for dispatcher in self.dispatchers:
dispatcher.clear_cache()

def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle":
if not isinstance(other, AbstractDispatcher):
return NotImplemented
return self.from_dispatchers(self, other)


def clear_all_cache():
"""Clear all cache, including the cache of subclass checks. This should be called
Expand Down
39 changes: 39 additions & 0 deletions tests/test_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from numbers import Number
from types import SimpleNamespace

import pytest

from plum import Dispatcher
Expand Down Expand Up @@ -70,3 +73,39 @@ def f(x):

assert f.__doc__ == "Docs"
assert f.methods == []


def test_multiple_dispatchers_on_same_function():
dispatch1 = Dispatcher()
dispatch2 = Dispatcher()

@dispatch1.abstract
def f(x: Number, y: Number):
return x - 2 * y

@dispatch2.abstract
def f(x: Number, y: Number):
return x - y

@dispatch2
@dispatch1
def f(x: int, y: float):
return x + y

@dispatch1
def f(x: str):
return x

ns1 = SimpleNamespace(f=f)

@dispatch2
def f(x: int):
return x

ns2 = SimpleNamespace(f=f)

assert ns1.f("a") == "a"
assert ns1.f(1, 1.0) == 2.0

assert ns2.f(1) == 1
assert ns2.f(1, 1.0) == 2.0

0 comments on commit 0793cf8

Please sign in to comment.