Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 Group logic concerning the cache in one single method #96

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 77 additions & 69 deletions plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,14 +316,44 @@ def _enhance_exception(self, e: SomeExceptionType) -> SomeExceptionType:
message = str(e)
return type(e)(prefix + message[0].lower() + message[1:])

def _resolve_method_with_cache(
self,
args: Union[Tuple[object, ...], Signature, None] = None,
types: Optional[Tuple[TypeHint, ...]] = None,
) -> Tuple[Callable, TypeHint]:
if args is None and types is None:
raise ValueError("args and types cannot be None, this should not happen")

# Before attempting to use the cache, resolve any unresolved registrations. Use
# an `if`-statement to speed up the common case.
if self._pending:
self._resolve_pending_registrations()

if types is None:
# Attempt to use the cache based on the types of the arguments.
types = tuple(map(type, args))
try:
return self._cache[types]
except KeyError:
if args is None:
args = Signature(*(resolve_type_hint(t) for t in types))

# Cache miss. Run the resolver based on the arguments.
method, return_type = self.resolve_method(args)
# If the resolver is faithful,
# then we can perform caching using the types of
# the arguments. If the resolver is not faithful, then we cannot.
if self._resolver.is_faithful:
self._cache[types] = method, return_type
return method, return_type

def resolve_method(
self, target: Union[Tuple[object, ...], Signature], types: Tuple[TypeHint]
self, target: Union[Tuple[object, ...], Signature]
) -> Tuple[Callable, TypeHint]:
"""Find the method and return type for arguments.

Args:
target (object): Target.
types (tuple[type, ...]): Types of the arguments.

Returns:
function: Method.
Expand All @@ -342,69 +372,57 @@ def resolve_method(

except NotFoundLookupError as e:
e = self._enhance_exception(e) # Specify this function.
method, return_type = self._handle_not_found_lookup_error(e)

if not self.owner:
# Not in a class. Nothing we can do.
raise e
return method, return_type

def _handle_not_found_lookup_error(
self, ex: NotFoundLookupError
) -> Tuple[Callable, TypeHint]:
if not self.owner:
# Not in a class. Nothing we can do.
raise ex

# In a class. Walk through the classes in the class's MRO, except for
# this class, and try to get the method.
method = None
return_type = object

for c in self.owner.__mro__[1:]:
# Skip the top of the type hierarchy given by `object` and `type`.
# We do not suddenly want to fall back to any unexpected default
# behaviour.
if c in {object, type}:
continue

# We need to check `c.__dict__` here instead of using `hasattr`
# since e.g. `c.__le__` will return even if `c` does not implement
# `__le__`!
if self._f.__name__ in c.__dict__:
method = getattr(c, self._f.__name__)
else:
# In a class. Walk through the classes in the class's MRO, except for
# this class, and try to get the method.
# For some reason, coverage fails to catch the `continue`
# below. Add the do-nothing `_ = None` fixes this.
# TODO: Remove this once coverage properly catches this.
_ = None
continue

# Ignore abstract methods.
if getattr(method, "__isabstractmethod__", False):
method = None
return_type = object

for c in self.owner.__mro__[1:]:
# Skip the top of the type hierarchy given by `object` and `type`.
# We do not suddenly want to fall back to any unexpected default
# behaviour.
if c in {object, type}:
continue

# We need to check `c.__dict__` here instead of using `hasattr`
# since e.g. `c.__le__` will return even if `c` does not implement
# `__le__`!
if self._f.__name__ in c.__dict__:
method = getattr(c, self._f.__name__)
else:
# For some reason, coverage fails to catch the `continue`
# below. Add the do-nothing `_ = None` fixes this.
# TODO: Remove this once coverage properly catches this.
_ = None
continue

# Ignore abstract methods.
if getattr(method, "__isabstractmethod__", False):
method = None
continue

# We found a good candidate. Break.
break

if not method:
# If no method has been found after walking through the MRO, raise
# the original exception.
raise e

# If the resolver is faithful, then we can perform caching using the types of
# the arguments. If the resolver is not faithful, then we cannot.
if self._resolver.is_faithful:
self._cache[types] = method, return_type
continue

# We found a good candidate. Break.
break

if not method:
# If no method has been found after walking through the MRO, raise
# the original exception.
raise ex
return method, return_type

def __call__(self, *args, **kw_args):
# Before attempting to use the cache, resolve any unresolved registrations. Use
# an `if`-statement to speed up the common case.
if self._pending:
self._resolve_pending_registrations()

# Attempt to use the cache based on the types of the arguments.
types = tuple(map(type, args))
try:
method, return_type = self._cache[types]
except KeyError:
# Cache miss. Run the resolver based on the arguments.
method, return_type = self.resolve_method(args, types)

method, return_type = self._resolve_method_with_cache(args=args)
return _convert(method(*args, **kw_args), return_type)

def invoke(self, *types: TypeHint) -> Callable:
Expand All @@ -416,17 +434,7 @@ def invoke(self, *types: TypeHint) -> Callable:
Returns:
function: Method.
"""
# Do this before attempting to cache. See above.
if self._pending:
self._resolve_pending_registrations()

# Attempt to use the cache based on the types.
try:
method, return_type = self._cache[types]
except KeyError:
# Cache miss. Run the resolver based on the types.
sig_types = Signature(*(resolve_type_hint(t) for t in types))
method, return_type = self.resolve_method(sig_types, types)
method, return_type = self._resolve_method_with_cache(types=types)

@wraps(self._f)
def wrapped_method(*args, **kw_args):
Expand Down
Loading