From 230fe9fd359beecaa0904adba42059da83b65157 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Thu, 17 Aug 2023 21:17:25 +0000 Subject: [PATCH 1/2] :wrench: Group logic concerning the cache in one single method --- plum/function.py | 146 +++++++++++++++++++++++++---------------------- 1 file changed, 77 insertions(+), 69 deletions(-) diff --git a/plum/function.py b/plum/function.py index 150d0d68..723419b2 100644 --- a/plum/function.py +++ b/plum/function.py @@ -316,14 +316,44 @@ def _enhance_exception(self, e: SomeExceptionType) -> SomeExceptionType: message = str(e) return type(e)(prefix + message[0].lower() + message[1:]) + def _resolve_method_with_cache( + self, + args: Union[Tuple[object, ...], Signature, None] = None, + types: Optional[Tuple[TypeHint, ...]] = None, + ) -> Tuple[Callable, TypeHint]: + if args is None and types is None: + raise ValueError("args and types cannot be None, this should not happen") + + # Before attempting to use the cache, resolve any unresolved registrations. Use + # an `if`-statement to speed up the common case. + if self._pending: + self._resolve_pending_registrations() + + if types is None: + # Attempt to use the cache based on the types of the arguments. + types = tuple(map(type, args)) + try: + return self._cache[types] + except KeyError: + if args is None: + args = Signature(*(resolve_type_hint(t) for t in types)) + + # Cache miss. Run the resolver based on the arguments. + method, return_type = self.resolve_method(args) + # If the resolver is faithful, + # then we can perform caching using the types of + # the arguments. If the resolver is not faithful, then we cannot. + if self._resolver.is_faithful: + self._cache[types] = method, return_type + return method, return_type + def resolve_method( - self, target: Union[Tuple[object, ...], Signature], types: Tuple[TypeHint] + self, target: Union[Tuple[object, ...], Signature] ) -> Tuple[Callable, TypeHint]: """Find the method and return type for arguments. Args: target (object): Target. - types (tuple[type, ...]): Types of the arguments. Returns: function: Method. @@ -342,69 +372,57 @@ def resolve_method( except NotFoundLookupError as e: e = self._enhance_exception(e) # Specify this function. + method, return_type = self._handle_not_found_lookup_error(e) - if not self.owner: - # Not in a class. Nothing we can do. - raise e + return method, return_type + + def _handle_not_found_lookup_error( + self, ex: NotFoundLookupError + ) -> Tuple[Callable, TypeHint]: + if not self.owner: + # Not in a class. Nothing we can do. + raise ex + + # In a class. Walk through the classes in the class's MRO, except for + # this class, and try to get the method. + method = None + return_type = object + + for c in self.owner.__mro__[1:]: + # Skip the top of the type hierarchy given by `object` and `type`. + # We do not suddenly want to fall back to any unexpected default + # behaviour. + if c in {object, type}: + continue + + # We need to check `c.__dict__` here instead of using `hasattr` + # since e.g. `c.__le__` will return even if `c` does not implement + # `__le__`! + if self._f.__name__ in c.__dict__: + method = getattr(c, self._f.__name__) else: - # In a class. Walk through the classes in the class's MRO, except for - # this class, and try to get the method. + # For some reason, coverage fails to catch the `continue` + # below. Add the do-nothing `_ = None` fixes this. + # TODO: Remove this once coverage properly catches this. + _ = None + continue + + # Ignore abstract methods. + if getattr(method, "__isabstractmethod__", False): method = None - return_type = object - - for c in self.owner.__mro__[1:]: - # Skip the top of the type hierarchy given by `object` and `type`. - # We do not suddenly want to fall back to any unexpected default - # behaviour. - if c in {object, type}: - continue - - # We need to check `c.__dict__` here instead of using `hasattr` - # since e.g. `c.__le__` will return even if `c` does not implement - # `__le__`! - if self._f.__name__ in c.__dict__: - method = getattr(c, self._f.__name__) - else: - # For some reason, coverage fails to catch the `continue` - # below. Add the do-nothing `_ = None` fixes this. - # TODO: Remove this once coverage properly catches this. - _ = None - continue - - # Ignore abstract methods. - if getattr(method, "__isabstractmethod__", False): - method = None - continue - - # We found a good candidate. Break. - break - - if not method: - # If no method has been found after walking through the MRO, raise - # the original exception. - raise e - - # If the resolver is faithful, then we can perform caching using the types of - # the arguments. If the resolver is not faithful, then we cannot. - if self._resolver.is_faithful: - self._cache[types] = method, return_type + continue + # We found a good candidate. Break. + break + + if not method: + # If no method has been found after walking through the MRO, raise + # the original exception. + raise ex return method, return_type def __call__(self, *args, **kw_args): - # Before attempting to use the cache, resolve any unresolved registrations. Use - # an `if`-statement to speed up the common case. - if self._pending: - self._resolve_pending_registrations() - - # Attempt to use the cache based on the types of the arguments. - types = tuple(map(type, args)) - try: - method, return_type = self._cache[types] - except KeyError: - # Cache miss. Run the resolver based on the arguments. - method, return_type = self.resolve_method(args, types) - + method, return_type = self._resolve_method_with_cache(args=args) return _convert(method(*args, **kw_args), return_type) def invoke(self, *types: TypeHint) -> Callable: @@ -416,17 +434,7 @@ def invoke(self, *types: TypeHint) -> Callable: Returns: function: Method. """ - # Do this before attempting to cache. See above. - if self._pending: - self._resolve_pending_registrations() - - # Attempt to use the cache based on the types. - try: - method, return_type = self._cache[types] - except KeyError: - # Cache miss. Run the resolver based on the types. - sig_types = Signature(*(resolve_type_hint(t) for t in types)) - method, return_type = self.resolve_method(sig_types, types) + method, return_type = self._resolve_method_with_cache(types=types) @wraps(self._f) def wrapped_method(*args, **kw_args): From 762b2b129ff32c0dc17b1784031dfb939f5b3084 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 18 Aug 2023 12:13:05 +0000 Subject: [PATCH 2/2] Add small unit test --- tests/test_function.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_function.py b/tests/test_function.py index aa09ddf2..96b1089b 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -87,6 +87,16 @@ def f(x): assert Function(f, owner="A").owner is A +def test_resolve_method_with_cache_no_arguments(): + def f(x): + pass + + with pytest.raises(ValueError) as err: + Function(f)._resolve_method_with_cache() + + assert "args and types cannot be None" in str(err) + + @pytest.fixture() def owner_transfer(): # Save and clear.