diff --git a/fps/app.py b/fps/app.py index 8ad2a0b..cf203db 100644 --- a/fps/app.py +++ b/fps/app.py @@ -271,6 +271,72 @@ def _load_routers(app: FastAPI) -> None: logger.info("No plugin API router to load") +def _load_middlewares(app: FastAPI) -> None: + + pm = _get_pluggin_manager(HookType.MIDDLEWARE) + + grouped_middlewares = _grouped_hookimpls_results(pm.hook.middleware) + + if grouped_middlewares: + + pkg_names = {get_pkg_name(p, strip_fps=False) for p in grouped_middlewares} + logger.info(f"Loading middlewares from plugin package(s) {pkg_names}") + + middleware_dict = {} + for p, middlewares in grouped_middlewares.items(): + p_name = Config.plugin_name(p) + plugin_config = Config.from_name(p_name) + + disabled = ( + plugin_config + and not plugin_config.enabled + or p_name in Config(FPSConfig).disabled_plugins + or ( + Config(FPSConfig).enabled_plugins + and p_name not in Config(FPSConfig).enabled_plugins + ) + ) + if not middlewares or disabled: + disabled_msg = " (disabled)" if disabled else "" + logger.info( + f"No middleware registered for plugin '{p_name}'{disabled_msg}" + ) + continue + + logger.info(f"Registered middleware(s) for plugin '{p_name}':") + for middleware in middlewares: + logger.info( + f"Middleware: {middleware.__module__}.{middleware.__qualname__}" + ) + + middleware_dict.update( + { + f"{middleware.__module__}.{middleware.__qualname__}": middleware + for middleware in middlewares + } + ) + + middleware_cnt = 0 + for middleware in Config(FPSConfig).middlewares: + middleware_class_path = middleware.class_path + if middleware_class_path not in middleware_dict: + logger.warning(f"Unknown middleware {middleware_class_path}") + continue + + logger.info(f"Adding middleware {middleware_class_path}") + middleware_class = middleware_dict[middleware_class_path] + app.add_middleware( + middleware_class, + **middleware.kwargs, + ) + middleware_cnt += 1 + + logger.info(f"{middleware_cnt} middleware(s) added") + + else: + logger.info("No plugin middleware to load") + + def create_app(): logging.getLogger("fps") @@ -283,6 +349,7 @@ def create_app(): _load_routers(app) _load_exceptions_handlers(app) + _load_middlewares(app) Config.check_not_used_sections() diff --git a/fps/config.py b/fps/config.py index 6b54485..f89cba4 100644 --- a/fps/config.py +++ b/fps/config.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict from types import ModuleType -from typing import Dict, List, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import toml from pydantic import BaseModel, create_model, validator @@ -25,6 +25,11 @@ def create_default_plugin_model(plugin_name: str): return create_model(f"{plugin_name}Model", __base__=PluginModel) +class MiddlewareModel(BaseModel): + class_path: str + kwargs: Dict[str, Any] = {} + + class FPSConfig(BaseModel): # fastapi title: str = "FPS" @@ -35,6 +40,9 @@ class FPSConfig(BaseModel): enabled_plugins: List[str] = [] disabled_plugins: List[str] = [] + # plugin middlewares + middlewares: List[MiddlewareModel] = [] + @validator("enabled_plugins", "disabled_plugins") def plugins_format(cls, plugins): warnings = [p for p in plugins if p.startswith("[") or p.endswith("]")] diff --git a/fps/hooks.py b/fps/hooks.py index 835f64b..5a22f13 100644 --- a/fps/hooks.py +++ b/fps/hooks.py @@ -12,6 +12,7 @@ class HookType(Enum): ROUTER = "fps_router" CONFIG = "fps_config" EXCEPTION = "fps_exception" + MIDDLEWARE = "fps_middleware" @pluggy.HookspecMarker(HookType.ROUTER.value) @@ -75,3 +76,17 @@ def plugin_name_callback() -> str: return pluggy.HookimplMarker(HookType.CONFIG.value)( function=plugin_name_callback, specname="plugin_name" ) + + +@pluggy.HookspecMarker(HookType.MIDDLEWARE.value) +def middleware() -> type: + pass + + +def register_middleware(m: type): + def middleware_callback() -> type: + return m + + return pluggy.HookimplMarker(HookType.MIDDLEWARE.value)( + function=middleware_callback, specname="middleware" + ) diff --git a/plugins/uvicorn/fps_uvicorn/cli.py b/plugins/uvicorn/fps_uvicorn/cli.py index 5c7a54a..fa09f75 100644 --- a/plugins/uvicorn/fps_uvicorn/cli.py +++ b/plugins/uvicorn/fps_uvicorn/cli.py @@ -2,6 +2,7 @@ import os import threading import webbrowser +from pathlib import Path from typing import Any, Dict, List import toml @@ -86,7 +87,7 @@ def store_extra_options(options: Dict[str, Any]): f_name = "fps_cli_args.toml" with open(f_name, "w") as f: toml.dump(opts, f) - os.environ["FPS_CLI_CONFIG_FILE"] = f_name + os.environ["FPS_CLI_CONFIG_FILE"] = str(Path(f_name).resolve()) @app.command(