Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add register_middleware hook #50

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions fps/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,72 @@ def _load_routers(app: FastAPI) -> None:
logger.info("No plugin API router to load")


def _load_middlewares(app: FastAPI) -> None:

pm = _get_pluggin_manager(HookType.MIDDLEWARE)

grouped_middlewares = _grouped_hookimpls_results(pm.hook.middleware)

if grouped_middlewares:

pkg_names = {get_pkg_name(p, strip_fps=False) for p in grouped_middlewares}
logger.info(f"Loading middlewares from plugin package(s) {pkg_names}")

middleware_dict = {}
for p, middlewares in grouped_middlewares.items():
p_name = Config.plugin_name(p)
plugin_config = Config.from_name(p_name)

disabled = (
plugin_config
and not plugin_config.enabled
or p_name in Config(FPSConfig).disabled_plugins
or (
Config(FPSConfig).enabled_plugins
and p_name not in Config(FPSConfig).enabled_plugins
)
)
if not middlewares or disabled:
disabled_msg = " (disabled)" if disabled else ""
logger.info(
f"No middleware registered for plugin '{p_name}'{disabled_msg}"
)
continue

logger.info(f"Registered middleware(s) for plugin '{p_name}':")
for middleware in middlewares:
logger.info(
f"Middleware: {middleware.__module__}.{middleware.__qualname__}"
)

middleware_dict.update(
{
f"{middleware.__module__}.{middleware.__qualname__}": middleware
for middleware in middlewares
}
)

middleware_cnt = 0
for middleware in Config(FPSConfig).middlewares:
middleware_class_path = middleware.class_path
if middleware_class_path not in middleware_dict:
logger.warning(f"Unknown middleware {middleware_class_path}")
continue

logger.info(f"Adding middleware {middleware_class_path}")
middleware_class = middleware_dict[middleware_class_path]
app.add_middleware(
middleware_class,
**middleware.kwargs,
)
middleware_cnt += 1

logger.info(f"{middleware_cnt} middleware(s) added")

else:
logger.info("No plugin middleware to load")


def create_app():

logging.getLogger("fps")
Expand All @@ -283,6 +349,7 @@ def create_app():

_load_routers(app)
_load_exceptions_handlers(app)
_load_middlewares(app)

Config.check_not_used_sections()

Expand Down
10 changes: 9 additions & 1 deletion fps/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from collections import OrderedDict
from types import ModuleType
from typing import Dict, List, Tuple, Type
from typing import Any, Dict, List, Tuple, Type

import toml
from pydantic import BaseModel, create_model, validator
Expand All @@ -25,6 +25,11 @@ def create_default_plugin_model(plugin_name: str):
return create_model(f"{plugin_name}Model", __base__=PluginModel)


class MiddlewareModel(BaseModel):
class_path: str
kwargs: Dict[str, Any] = {}


class FPSConfig(BaseModel):
# fastapi
title: str = "FPS"
Expand All @@ -35,6 +40,9 @@ class FPSConfig(BaseModel):
enabled_plugins: List[str] = []
disabled_plugins: List[str] = []

# plugin middlewares
middlewares: List[MiddlewareModel] = []

@validator("enabled_plugins", "disabled_plugins")
def plugins_format(cls, plugins):
warnings = [p for p in plugins if p.startswith("[") or p.endswith("]")]
Expand Down
15 changes: 15 additions & 0 deletions fps/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class HookType(Enum):
ROUTER = "fps_router"
CONFIG = "fps_config"
EXCEPTION = "fps_exception"
MIDDLEWARE = "fps_middleware"


@pluggy.HookspecMarker(HookType.ROUTER.value)
Expand Down Expand Up @@ -75,3 +76,17 @@ def plugin_name_callback() -> str:
return pluggy.HookimplMarker(HookType.CONFIG.value)(
function=plugin_name_callback, specname="plugin_name"
)


@pluggy.HookspecMarker(HookType.MIDDLEWARE.value)
def middleware() -> type:
pass


def register_middleware(m: type):
def middleware_callback() -> type:
return m

return pluggy.HookimplMarker(HookType.MIDDLEWARE.value)(
function=middleware_callback, specname="middleware"
)
3 changes: 2 additions & 1 deletion plugins/uvicorn/fps_uvicorn/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import threading
import webbrowser
from pathlib import Path
from typing import Any, Dict, List

import toml
Expand Down Expand Up @@ -86,7 +87,7 @@ def store_extra_options(options: Dict[str, Any]):
f_name = "fps_cli_args.toml"
with open(f_name, "w") as f:
toml.dump(opts, f)
os.environ["FPS_CLI_CONFIG_FILE"] = f_name
os.environ["FPS_CLI_CONFIG_FILE"] = str(Path(f_name).resolve())


@app.command(
Expand Down