Skip to content

Commit

Permalink
#444: mad_hatter hooks cache as a dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
pieroit committed Sep 13, 2023
1 parent 0b389eb commit b94d6cd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
31 changes: 21 additions & 10 deletions core/cat/mad_hatter/mad_hatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, ccat):

self.plugins = {} # plugins dictionary

self.hooks = [] # list of active plugins hooks
self.hooks = {} # dict of active plugins hooks ( hook_name -> [CatHook, CatHook, ...])
self.tools = [] # list of active plugins tools

self.active_plugins = []
Expand Down Expand Up @@ -116,7 +116,7 @@ def load_plugin(self, plugin_path):
def sync_hooks_and_tools(self):

# emptying tools and hooks
self.hooks = []
self.hooks = {}
self.tools = []

for _, plugin in self.plugins.items():
Expand All @@ -128,11 +128,18 @@ def sync_hooks_and_tools(self):
# Prepare the tool to be used in the Cat (setting the cat instance, adding properties)
t.augment_tool(self.ccat)

self.hooks += plugin.hooks
# cache tools
self.tools += plugin.tools

# sort hooks by priority
self.hooks.sort(key=lambda x: x.priority, reverse=True)
# cache hooks (indexed by hook name)
for h in plugin.hooks:
if h.name not in self.hooks.keys():
self.hooks[h.name] = []
self.hooks[h.name].append(h)

# sort each hooks list by priority
for hook_name in self.hooks.keys():
self.hooks[hook_name].sort(key=lambda x: x.priority, reverse=True)

# check if plugin exists
def plugin_exists(self, plugin_id):
Expand Down Expand Up @@ -238,9 +245,13 @@ def toggle_plugin(self, plugin_id):

# execute requested hook
def execute_hook(self, hook_name, *args):
for h in self.hooks:
if hook_name == h.name:
return h.function(*args, cat=self.ccat)

# every hook must have a default in core_plugin
raise Exception(f"Hook {hook_name} not present in any plugin")
# check if hook is supported
if hook_name not in self.hooks.keys():
raise Exception(f"Hook {hook_name} not present in any plugin")

# run hooks
for h in self.hooks[hook_name]:
return h.function(*args, cat=self.ccat)
# TODO: should be run as a pipe, not return immediately

25 changes: 15 additions & 10 deletions core/tests/mad_hatter/test_mad_hatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def test_instantiation_discovery(mad_hatter):
assert "core_plugin" in mad_hatter.load_active_plugins_from_db()

# finds hooks
assert len(mad_hatter.hooks) > 0
for h in mad_hatter.hooks:
assert len(mad_hatter.hooks.keys()) > 0
for hook_name, hooks_list in mad_hatter.hooks.items():
assert len(hooks_list) == 1 # core plugin implements each hook
h = hooks_list[0]
assert isinstance(h, CatHook)
assert h.plugin_id == "core_plugin"
assert type(h.name) == str
Expand Down Expand Up @@ -86,12 +88,14 @@ def test_plugin_install(mad_hatter: MadHatter, plugin_is_flat):
assert new_tool.plugin_id == "mock_plugin"

# found tool and hook have been cached
assert new_tool in mad_hatter.tools
assert new_hook in mad_hatter.hooks

# new hook has correct priority and has been sorted by mad_hatter as first
assert new_hook.priority == 2
assert id(new_hook) == id(mad_hatter.hooks[0]) # same object in memory!
assert id(new_tool) == id(mad_hatter.tools[1]) # same object in memory!
mock_hook_name = "before_cat_sends_message"
assert len(mad_hatter.hooks[mock_hook_name]) == 2
cached_hook = mad_hatter.hooks[mock_hook_name][0] # correctly sorted by priority
assert cached_hook.name == mock_hook_name
assert cached_hook.plugin_id == "mock_plugin"
assert cached_hook.priority == 2
assert id(new_hook) == id(cached_hook) # same object in memory!

# list of active plugins in DB is correct
active_plugins = mad_hatter.load_active_plugins_from_db()
Expand Down Expand Up @@ -130,8 +134,9 @@ def test_plugin_uninstall(mad_hatter: MadHatter, plugin_is_flat):
assert "mock_plugin" not in mad_hatter.plugins.keys()
# plugin cache updated (only core_plugin stuff)
assert len(mad_hatter.tools) == 1 # default tool
for h in mad_hatter.hooks:
assert h.plugin_id == "core_plugin"
for h_name, h_list in mad_hatter.hooks.items():
assert len(h_list) == 1
assert h_list[0].plugin_id == "core_plugin"

# list of active plugins in DB is correct
active_plugins = mad_hatter.load_active_plugins_from_db()
Expand Down

0 comments on commit b94d6cd

Please sign in to comment.