Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: issue #980 #982

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions core/cat/mad_hatter/mad_hatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ def toggle_plugin(self, plugin_id):
# Execute hook on plugin deactivation
# Deactivation hook must happen before actual deactivation,
# otherwise the hook will not be available in _plugin_overrides anymore
for hook in self.plugins[plugin_id]._plugin_overrides:
if hook.name == "deactivated":
hook.function(self.plugins[plugin_id])
if "deactivated" in self.plugins[plugin_id]._plugin_overrides:
self.plugins[plugin_id]._plugin_overrides["deactivated"].function(self.plugins[plugin_id])

# Deactivate the plugin
self.plugins[plugin_id].deactivate()
Expand All @@ -221,9 +220,8 @@ def toggle_plugin(self, plugin_id):
# Execute hook on plugin activation
# Activation hook must happen before actual activation,
# otherwise the hook will still not be available in _plugin_overrides
for hook in self.plugins[plugin_id]._plugin_overrides:
if hook.name == "activated":
hook.function(self.plugins[plugin_id])
if "activated" in self.plugins[plugin_id]._plugin_overrides:
self.plugins[plugin_id]._plugin_overrides["activated"].function(self.plugins[plugin_id])

# Add the plugin in the list of active plugins
self.active_plugins.append(plugin_id)
Expand Down
41 changes: 18 additions & 23 deletions core/cat/mad_hatter/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __init__(self, plugin_path: str):
self._forms: List[CatForm] = [] # list of plugin forms
self._endpoints: List[CustomEndpoint] = [] # list of plugin endpoints

# list of @plugin decorated functions overriding default plugin behaviour
self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access
# list of @plugin decorated functions overriding default plugin behaviour
self._plugin_overrides = {}

# plugin starts deactivated
self._active = False
Expand Down Expand Up @@ -101,40 +101,37 @@ def deactivate(self):
self._tools = []
self._forms = []
self._deactivate_endpoints()
self._plugin_overrides = []
self._plugin_overrides = {}
self._active = False

# get plugin settings JSON schema
def settings_schema(self):
# is "settings_schema" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "settings_schema":
return h.function()
else:
# if the "settings_schema" is not defined but
# "settings_model" is it get the schema from the model
if h.name == "settings_model":
return h.function().model_json_schema()
if "settings_schema" in self._plugin_overrides:
return self._plugin_overrides["settings_schema"].function()
else:
# if the "settings_schema" is not defined but
# "settings_model" is it get the schema from the model
if "settings_model" in self._plugin_overrides:
return self._plugin_overrides["settings_model"].function().model_json_schema()

# default schema (empty)
return PluginSettingsModel.model_json_schema()

# get plugin settings Pydantic model
def settings_model(self):
# is "settings_model" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "settings_model":
return h.function()
if "settings_model" in self._plugin_overrides:
return self._plugin_overrides["settings_model"].function()

# default schema (empty)
return PluginSettingsModel

# load plugin settings
def load_settings(self):
# is "settings_load" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "load_settings":
return h.function()
if "load_settings" in self._plugin_overrides:
return self._plugin_overrides["load_settings"].function()

# by default, plugin settings are saved inside the plugin folder
# in a JSON file called settings.json
Expand All @@ -159,9 +156,8 @@ def load_settings(self):
# save plugin settings
def save_settings(self, settings: Dict):
# is "settings_save" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "save_settings":
return h.function(settings)
if "save_settings" in self._plugin_overrides:
return self._plugin_overrides["save_settings"].function(settings)

# by default, plugin settings are saved inside the plugin folder
# in a JSON file called settings.json
Expand Down Expand Up @@ -331,9 +327,8 @@ def _load_decorated_functions(self):
self._tools = list(map(self._clean_tool, tools))
self._forms = list(map(self._clean_form, forms))
self._endpoints = list(map(self._clean_endpoint, endpoints))
self._plugin_overrides = list(
map(self._clean_plugin_override, plugin_overrides)
)
self._plugin_overrides = {override.name: override for override in list(map(self._clean_plugin_override, plugin_overrides))}


def plugin_specific_error_message(self):
name = self.manifest.get("name")
Expand Down