Skip to content

Commit

Permalink
feat: different dispatch instances on same function
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Jul 28, 2024
1 parent 70212b7 commit 2583d45
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 5 deletions.
244 changes: 239 additions & 5 deletions plum/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import sys
from dataclasses import dataclass, field
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field, replace
from functools import partial
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, overload
from itertools import chain
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, final, overload

from .function import Function
from .overload import get_overloads
from .signature import Signature
from .util import Callable, TypeHint, get_class, is_in_class

__all__ = ["Dispatcher", "dispatch", "clear_all_cache"]
__all__ = [
"AbstractDispatcher",
"Dispatcher",
"DispatcherBundle",
"dispatch",
"clear_all_cache",
]

T = TypeVar("T", bound=Callable[..., Any])

Expand All @@ -19,7 +27,39 @@


@dataclass(frozen=True, **_dataclass_kw_args)
class Dispatcher:
class AbstractDispatcher(metaclass=ABCMeta):
"""An abstract dispatcher."""

@overload
def __call__(self, method: T, precedence: int = ...) -> T: ...

@overload
def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ...

@abstractmethod
def __call__(
self, method: Optional[T] = None, precedence: int = 0
) -> Union[T, Callable[[T], T]]: ...

@abstractmethod
def abstract(self, method: Callable) -> Function:
"""Decorator for an abstract function definition. The abstract function
definition does not implement any methods."""

@abstractmethod
def multi(
self, *signatures: Union[Signature, Tuple[TypeHint, ...]]
) -> Callable[[Callable], Function]:
"""Decorator to register multiple signatures at once."""

@abstractmethod
def clear_cache(self) -> None:
"""Clear cache."""


@final
@dataclass(frozen=True, **_dataclass_kw_args)
class Dispatcher(AbstractDispatcher):
"""A namespace for functions.
Args:
Expand Down Expand Up @@ -140,11 +180,205 @@ def _add_method(
f.register(method, signature, precedence)
return f

def clear_cache(self):
def clear_cache(self) -> None:
"""Clear cache."""
for f in self.functions.values():
f.clear_cache()

def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle":
if not isinstance(other, AbstractDispatcher):
raise ValueError(f"Cannot combine `Dispatcher` with `{type(other)}`.")
return DispatcherBundle.from_dispatchers(self, other)


@final
@dataclass(frozen=True, **_dataclass_kw_args)
class DispatcherBundle(AbstractDispatcher):
"""A bundle of dispatchers.
Examples
--------
A DispatcherBundle allows for different dispatchers to share a method, even
when their other methods are different. In this example `f(int, int)` is
shared between `dispatch1` and `dispatch2`, while the following methods are
unique to each dispatcher.
>>> from plum import Dispatcher
>>> from types import SimpleNamespace
In one namespace:
>>> dispatch1 = Dispatcher()
>>> @dispatch1
... def f(x: int, y: float) -> int:
... return x + int(y)
>>> ns1 = SimpleNamespace(f=f)
In another namespace:
>>> dispatch2 = Dispatcher()
>>> @dispatch2
... def f(x: float, y: int) -> float:
... return x + float(y)
>>> ns2 = SimpleNamespace(f=f)
Here we want to share the `f(int, int)` method between `dispatch1` and
`dispatch2`. This can be done with a `DispatcherBundle`, which combines the
dispatchers. A `DispatcherBundle` can be created with the `|` operator.
>>> @(dispatch1 | dispatch2)
... def f(x: int, y: int) -> int:
... return x + y
The function `f` is registered in both dispatchers.
>>> dispatch1.functions
{'f': ...}
>>> dispatch2.functions
{'f': ...}
.. note::
The :class:`plum.Function` object depends on the dispatch order. Here
`dispatch1` is the first dispatcher and `dispatch2` is the second.
Therefore, the returned function is the one registered last, which is
the one in `dispatch2`.
In application:
>>> ns1.f(1, 2) # From dispatch1/2, depending on the namespace.
3
>>> ns2.f(1, 2)
3
>>> ns1.f(1, 2.0) # From dispatch1.
3
>>> ns2.f(1.0, 2) # From dispatch2.
3.0
At least one dispatcher must be provided to `DispatcherBundle`.
>>> from plum import DispatcherBundle
>>> try:
... DispatcherBundle(())
... except ValueError as e:
... print(e)
At least one dispatcher must be provided to DispatcherBundle.
A `DispatcherBundle` can be created from a sequence of dispatchers.
>>> dispatchbundle = DispatcherBundle.from_dispatchers(dispatch1, dispatch2)
A nested `DispatcherBundle` can be flattened.
>>> dispatch3 = Dispatcher()
>>> dispatchbundle = DispatcherBundle((dispatchbundle, dispatch3))
>>> dispatchbundle
DispatcherBundle(dispatchers=(DispatcherBundle(dispatchers=(Dispatcher(...), Dispatcher(...))), Dispatcher(...)))
>>> dispatchbundle = dispatchbundle.flatten()
>>> dispatchbundle
DispatcherBundle(dispatchers=(Dispatcher(...), Dispatcher(...), Dispatcher(...)))
:class:`plum.DispatcherBundle`s can be combined with `|`. They are flattened
automatically.
>>> dispatch4 = Dispatcher()
>>> dispatchbundle1 = dispatch1 | dispatch2
>>> dispatchbundle2 = dispatch3 | dispatch4
>>> dispatchbundle = dispatchbundle1 | dispatchbundle2
>>> dispatchbundle
DispatcherBundle(dispatchers=(Dispatcher(...), Dispatcher(...), Dispatcher(...), Dispatcher(...)))
""" # noqa: E501

dispatchers: Tuple[AbstractDispatcher, ...]

def __post_init__(self) -> None:
if not self.dispatchers:
msg = "At least one dispatcher must be provided to DispatcherBundle."
raise ValueError(msg)

@classmethod
def from_dispatchers(cls, *dispatchers: AbstractDispatcher) -> "DispatcherBundle":
"""Create a `DispatcherBundle` from a sequence of dispatchers.
This also flattens nested `DispatcherBundle`s.
"""

return cls(dispatchers).flatten()

def flatten(self) -> "DispatcherBundle":
"""Flatten the bundle."""

def as_seq(x: AbstractDispatcher) -> Tuple[AbstractDispatcher, ...]:
return x.dispatchers if isinstance(x, DispatcherBundle) else (x,)

return replace(
self, dispatchers=tuple(chain.from_iterable(map(as_seq, self.dispatchers)))
)

@overload
def __call__(self, method: T, precedence: int = ...) -> T: ...

@overload
def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ...

def __call__(
self, method: Optional[T] = None, precedence: int = 0
) -> Union[T, Callable[[T], T]]:
for dispatcher in self.dispatchers:
f = dispatcher(method, precedence=precedence)
return f

def abstract(self, method: Callable) -> Function:
"""Decorator for an abstract function definition. The abstract function
definition does not implement any methods."""
for dispatcher in self.dispatchers:
f = dispatcher.abstract(method)
return f

def multi(
self, *signatures: Union[Signature, Tuple[TypeHint, ...]]
) -> Callable[[Callable], Function]:
"""Decorator to register multiple signatures at once.
Args:
*signatures (tuple or :class:`.signature.Signature`): Signatures to
register.
Returns:
function: Decorator.
"""

def decorator(method: Callable) -> Function:
for dispatcher in self.dispatchers:
f = dispatcher.multi(*signatures)(method)
return f

return decorator

def clear_cache(self) -> None:
"""Clear cache."""
for dispatcher in self.dispatchers:
dispatcher.clear_cache()

def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle":
if not isinstance(other, AbstractDispatcher):
return NotImplemented
return self.from_dispatchers(self, other)


def clear_all_cache():
"""Clear all cache, including the cache of subclass checks. This should be called
Expand Down
38 changes: 38 additions & 0 deletions tests/test_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from numbers import Number
from types import SimpleNamespace

import pytest

from plum import Dispatcher
Expand Down Expand Up @@ -70,3 +73,38 @@ def f(x):

assert f.__doc__ == "Docs"
assert f.methods == []


def test_multiple_dispatchers_on_same_function():
dispatch1 = Dispatcher()
dispatch2 = Dispatcher()

@dispatch1.abstract
def f(x: Number, y: Number):
return x - 2 * y

@dispatch2.abstract
def f(x: Number, y: Number):
return x - y

@(dispatch2 | dispatch1)
def f(x: int, y: float):
return x + y

@dispatch1
def f(x: str):
return x

ns1 = SimpleNamespace(f=f)

@dispatch2
def f(x: int):
return x

ns2 = SimpleNamespace(f=f)

assert ns1.f("a") == "a"
assert ns1.f(1, 1.0) == 2.0

assert ns2.f(1) == 1
assert ns2.f(1, 1.0) == 2.0

0 comments on commit 2583d45

Please sign in to comment.