diff --git a/plum/dispatcher.py b/plum/dispatcher.py index afd7f35..89bde09 100644 --- a/plum/dispatcher.py +++ b/plum/dispatcher.py @@ -1,7 +1,9 @@ import sys -from dataclasses import dataclass, field +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field, replace from functools import partial -from typing import Any, Dict, Optional, Tuple, TypeVar, Union, overload +from itertools import chain +from typing import Any, Dict, Optional, Tuple, TypeVar, Union, final, overload from .function import Function from .overload import get_overloads @@ -19,7 +21,39 @@ @dataclass(frozen=True, **_dataclass_kw_args) -class Dispatcher: +class AbstractDispatcher(metaclass=ABCMeta): + """An abstract dispatcher.""" + + @overload + def __call__(self, method: T, precedence: int = ...) -> T: ... + + @overload + def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ... + + @abstractmethod + def __call__( + self, method: Optional[T] = None, precedence: int = 0 + ) -> Union[T, Callable[[T], T]]: ... + + @abstractmethod + def abstract(self, method: Callable) -> Function: + """Decorator for an abstract function definition. The abstract function + definition does not implement any methods.""" + + @abstractmethod + def multi( + self, *signatures: Union[Signature, Tuple[TypeHint, ...]] + ) -> Callable[[Callable], Function]: + """Decorator to register multiple signatures at once.""" + + @abstractmethod + def clear_cache(self) -> None: + """Clear cache.""" + + +@final +@dataclass(frozen=True, **_dataclass_kw_args) +class Dispatcher(AbstractDispatcher): """A namespace for functions. Args: @@ -140,11 +174,160 @@ def _add_method( f.register(method, signature, precedence) return f - def clear_cache(self): + def clear_cache(self) -> None: """Clear cache.""" for f in self.functions.values(): f.clear_cache() + def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle": + if not isinstance(other, AbstractDispatcher): + raise ValueError(f"Cannot combine `Dispatcher` with `{type(other)}`.") + return DispatcherBundle.from_dispatchers(self, other) + + +@final +@dataclass(frozen=True, **_dataclass_kw_args) +class DispatcherBundle(AbstractDispatcher): + """A bundle of dispatchers. + + Examples + -------- + >>> from plum import Dispatcher, DispatcherBundle + + >>> dispatch1 = Dispatcher() + >>> dispatch2 = Dispatcher() + + >>> dispatchbundle = dispatch1 | dispatch2 + + Some Notes: + + At least one dispatcher must be provided to `DispatcherBundle`. + + >>> try: + ... DispatcherBundle() + ... except ValueError as e: + ... print(e) + At least one dispatcher must be provided to DispatcherBundle. + + + A `DispatcherBundle` can be created from a sequence of dispatchers. + + >>> dispatchbundle = DispatcherBundle.from_dispatchers(dispatch1, dispatch2) + + A nested `DispatcherBundle` can be flattened. + + >>> dispatch3 = Dispatcher() + >>> dispatchbundle = DispatcherBundle.from_dispatchers(dispatchbundle, dispatch3) + >>> dispatchbundle + DispatcherBundle(dispatchers=( + , + ))>, + ) + ) + + + >>> dispatchbundle = dispatchbundle.flatten() + DispatcherBundle(dispatchers=( + , + , + ) + ) + + DispatchBundles can be combined with `|`. + + >>> dispatch4 = Dispatcher() + >>> dispatchbundle1 = dispatch1 | dispatch2 + >>> dispatchbundle2 = dispatch3 | dispatch4 + >>> dispatchbundle = dispatchbundle1 | dispatchbundle2 + >>> dispatchbundle + DispatcherBundle(dispatchers=( + , + ))>, + , + ))>) + ) + + """ + + dispatchers: Tuple[AbstractDispatcher, ...] + + def __post_init__(self) -> None: + if not self.dispatchers: + msg = "At least one dispatcher must be provided to DispatcherBundle." + raise ValueError(msg) + + @classmethod + def from_dispatchers(cls, *dispatchers: AbstractDispatcher) -> "DispatcherBundle": + """Create a `DispatcherBundle` from a sequence of dispatchers. + + This also flattens nested `DispatcherBundle`s. + """ + + return cls(dispatchers).flatten() + + def flatten(self) -> "DispatcherBundle": + """Flatten the bundle.""" + + def as_seq(x: AbstractDispatcher) -> Tuple[AbstractDispatcher, ...]: + return x.dispatchers if isinstance(x, DispatcherBundle) else (x,) + + return replace( + self, dispatchers=tuple(chain.from_iterable(map(as_seq, self.dispatchers))) + ) + + @overload + def __call__(self, method: T, precedence: int = ...) -> T: ... + + @overload + def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ... + + def __call__( + self, method: Optional[T] = None, precedence: int = 0 + ) -> Union[T, Callable[[T], T]]: + for dispatcher in self.dispatchers: + f = dispatcher(method, precedence=precedence) + return f + + def abstract(self, method: Callable) -> Function: + """Decorator for an abstract function definition. The abstract function + definition does not implement any methods.""" + for dispatcher in self.dispatchers: + f = dispatcher.abstract(method) + return f + + def multi( + self, *signatures: Union[Signature, Tuple[TypeHint, ...]] + ) -> Callable[[Callable], Function]: + """Decorator to register multiple signatures at once. + + Args: + *signatures (tuple or :class:`.signature.Signature`): Signatures to + register. + + Returns: + function: Decorator. + """ + + def decorator(method: Callable) -> Function: + for dispatcher in self.dispatchers: + f = dispatcher.multi(*signatures)(method) + return f + + return decorator + + def clear_cache(self) -> None: + """Clear cache.""" + for dispatcher in self.dispatchers: + dispatcher.clear_cache() + + def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle": + if not isinstance(other, AbstractDispatcher): + return NotImplemented + return self.from_dispatchers(self, other) + def clear_all_cache(): """Clear all cache, including the cache of subclass checks. This should be called diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 5d861a5..5a5d5df 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -1,3 +1,6 @@ +from numbers import Number +from types import SimpleNamespace + import pytest from plum import Dispatcher @@ -70,3 +73,38 @@ def f(x): assert f.__doc__ == "Docs" assert f.methods == [] + + +def test_multiple_dispatchers_on_same_function(): + dispatch1 = Dispatcher() + dispatch2 = Dispatcher() + + @dispatch1.abstract + def f(x: Number, y: Number): + return x - 2 * y + + @dispatch2.abstract + def f(x: Number, y: Number): + return x - y + + @(dispatch2 | dispatch1) + def f(x: int, y: float): + return x + y + + @dispatch1 + def f(x: str): + return x + + ns1 = SimpleNamespace(f=f) + + @dispatch2 + def f(x: int): + return x + + ns2 = SimpleNamespace(f=f) + + assert ns1.f("a") == "a" + assert ns1.f(1, 1.0) == 2.0 + + assert ns2.f(1) == 1 + assert ns2.f(1, 1.0) == 2.0