Skip to content

Commit

Permalink
Enhance AuthMiddleware, introduce @login_not_required decorator and…
Browse files Browse the repository at this point in the history
… `allow_routes`, deprecate `allow_paths` (#474)

* Enhance AuthMiddleware, introduce `@login_not_required` decorator and `allow_routes`, deprecate `allow_paths`

* Update pyproject.toml

* Fix CI

* Fix ci
  • Loading branch information
jowilf authored Jan 16, 2024
1 parent d30fb64 commit f124fec
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 56 deletions.
1 change: 1 addition & 0 deletions docs/api/auth/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
members:
- BaseAuthProvider
- AuthProvider
- login_not_required
14 changes: 6 additions & 8 deletions docs/tutorial/authentication/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,15 @@ Additionally, you can override these methods depending on your needs:
from typing import Optional

from starlette.datastructures import URL
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from starlette.routing import Route
from starlette_admin import BaseAdmin
from starlette_admin.auth import AdminUser, AuthMiddleware, AuthProvider
from starlette_admin.auth import (
AdminUser,
AuthProvider,
login_not_required,
)

from authlib.integrations.starlette_client import OAuth

Expand Down Expand Up @@ -181,6 +184,7 @@ class MyAuthProvider(AuthProvider):
)
)

@login_not_required
async def handle_auth_callback(self, request: Request):
auth0 = oauth.create_client("auth0")
token = await auth0.authorize_access_token(request)
Expand All @@ -198,12 +202,6 @@ class MyAuthProvider(AuthProvider):
name="authorize_auth0",
)
)

def get_middleware(self, admin: "BaseAdmin") -> Middleware:
return Middleware(
AuthMiddleware, provider=self, allow_paths=["/auth0/authorize"]
)

```

For a working example, have a look
Expand Down
13 changes: 6 additions & 7 deletions examples/authlib/provider.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Optional

from starlette.datastructures import URL
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from starlette.routing import Route
from starlette_admin import BaseAdmin
from starlette_admin.auth import AdminUser, AuthMiddleware, AuthProvider
from starlette_admin.auth import (
AdminUser,
AuthProvider,
login_not_required,
)

from authlib.integrations.starlette_client import OAuth

Expand Down Expand Up @@ -56,6 +59,7 @@ async def render_logout(self, request: Request, admin: BaseAdmin) -> Response:
)
)

@login_not_required
async def handle_auth_callback(self, request: Request):
auth0 = oauth.create_client("auth0")
token = await auth0.authorize_access_token(request)
Expand All @@ -73,8 +77,3 @@ def setup_admin(self, admin: "BaseAdmin"):
name="authorize_auth0",
)
)

def get_middleware(self, admin: "BaseAdmin") -> Middleware:
return Middleware(
AuthMiddleware, provider=self, allow_paths=["/auth0/authorize"]
)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ sqla_version = ["sqla14", "sqla2"]
matrix.sqla_version.dependencies = [
{ value = "SQLAlchemy[asyncio] >=2.0, <2.1", if = ["sqla2"] },
{ value = "SQLAlchemy[asyncio] >=1.4, <1.5", if = ["sqla14"] },
{ value = "starlette>=0.32,<0.33", if = ["sqla14"] },
]
matrix.sqla_version.scripts = [
{ key = "all", value = 'coverage run -m pytest tests --ignore=tests/odmantic', if = ["sqla2"] },
Expand Down
109 changes: 71 additions & 38 deletions starlette_admin/auth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, Sequence
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union

from packaging import version
from starlette import __version__ as starlette_version
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.routing import Route
from starlette.routing import Match, Mount, Route, WebSocketRoute
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_422_UNPROCESSABLE_ENTITY,
Expand All @@ -28,6 +28,16 @@
from starlette.types import ASGIApp


def login_not_required(
endpoint: Callable[..., Any],
) -> Callable[..., Any]:
"""Decorators for endpoints that do not require login."""

endpoint._login_not_required = True # type: ignore[attr-defined]

return endpoint


@dataclass
class AdminUser:
username: str = field(default_factory=lambda: _("Administrator"))
Expand All @@ -48,6 +58,11 @@ class BaseAuthProvider(ABC):
login_path: The path for the login page.
logout_path: The path for the logout page.
allow_paths: A list of paths that are allowed without authentication.
allow_routes: A list of route names that are allowed without authentication.
Warning:
- The usage of `allow_paths` is deprecated. It is recommended to use `allow_routes`
that specifies the route names instead.
"""

Expand All @@ -56,10 +71,19 @@ def __init__(
login_path: str = "/login",
logout_path: str = "/logout",
allow_paths: Optional[Sequence[str]] = None,
allow_routes: Optional[Sequence[str]] = None,
) -> None:
self.login_path = login_path
self.logout_path = logout_path
self.allow_paths = allow_paths
self.allow_routes = allow_routes

if allow_paths:
warnings.warn(
"`allow_paths` is deprecated. Use `allow_routes` instead.",
DeprecationWarning,
stacklevel=2,
)

@abstractmethod
def setup_admin(self, admin: "BaseAdmin") -> None:
Expand Down Expand Up @@ -270,7 +294,7 @@ def get_logout_route(self, admin: "BaseAdmin") -> Route:

def setup_admin(self, admin: "BaseAdmin") -> None:
"""
Setup the admin interface by adding necessary middleware and routes.
Set up the admin interface by adding necessary middleware and routes.
"""
admin.middlewares.append(self.get_middleware(admin=admin))
login_route = self.get_login_route(admin=admin)
Expand All @@ -286,49 +310,58 @@ def __init__(
app: ASGIApp,
provider: "BaseAuthProvider",
allow_paths: Optional[Sequence[str]] = None,
allow_routes: Optional[Sequence[str]] = None,
) -> None:
super().__init__(app)
self.provider = provider
self.allow_paths = list(allow_paths) if allow_paths is not None else []
self.allow_paths.extend(
[
self.provider.login_path,
"/statics/css/tabler.min.css",
"/statics/css/fontawesome.min.css",
"/statics/js/vendor/jquery.min.js",
"/statics/js/vendor/tabler.min.js",
"/statics/js/vendor/js.cookie.min.js",
]
) # Allow static files needed for the login page
self.allow_paths.extend(
self.provider.allow_paths if self.provider.allow_paths is not None else []
)

self.allow_routes = list(allow_routes) if allow_routes is not None else []
self.allow_routes.extend(["login", "statics"])
self.allow_routes.extend(
self.provider.allow_routes if self.provider.allow_routes is not None else []
)

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
request_path = request.scope["path"]
if version.parse(starlette_version) >= version.parse("0.33"):
"""In Starlette version 0.33, there's a change in the implementation of request.scope["path"],
which impacts this middleware. Discussions about this issue can be found at
https://github.com/encode/starlette/discussions/2361 and https://github.com/encode/starlette/pull/2400
The following line provides a temporary fix to address this change."""
_route_path_name = (
"root_path"
if version.parse(starlette_version) >= version.parse("0.35")
else "route_root_path"
)
request_path = request_path[len(request.scope.get("root_path")) :] # type: ignore[arg-type]
"""This middleware checks if the requested admin endpoint requires login.
If login is required, it redirects to the login page when the user is
not authenticated.
if request_path not in self.allow_paths and not (
await self.provider.is_authenticated(request)
):
# TODO: Improve the implementation in the future to eliminate the need for request.scope['path']
return RedirectResponse(
"{url}?{query_params}".format(
url=request.url_for(request.app.state.ROUTE_NAME + ":login"),
query_params=urlencode({"next": str(request.url)}),
),
status_code=HTTP_303_SEE_OTHER,
Endpoints are authorized without login if:
- They are decorated with `@login_not_required`
- Their path is in `allow_paths`
- Their name is in `allow_routes`
- The user is already authenticated
"""
_admin_app: Starlette = request.scope["app"]
current_route: Optional[Union[Route, Mount, WebSocketRoute]] = None
for route in _admin_app.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
assert isinstance(route, (Route, Mount, WebSocketRoute))
current_route = route
break
if (
(current_route is not None and current_route.path in self.allow_paths)
or (current_route is not None and current_route.name in self.allow_routes)
or (
current_route is not None
and hasattr(current_route, "endpoint")
and getattr(current_route.endpoint, "_login_not_required", False)
)
return await call_next(request)
or await self.provider.is_authenticated(request)
):
return await call_next(request)
return RedirectResponse(
"{url}?{query_params}".format(
url=request.url_for(request.app.state.ROUTE_NAME + ":login"),
query_params=urlencode({"next": str(request.url)}),
),
status_code=HTTP_303_SEE_OTHER,
)
38 changes: 36 additions & 2 deletions tests/auth_provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Optional

from starlette.requests import Request
from starlette.responses import Response
from starlette_admin.auth import AdminConfig, AdminUser, AuthProvider
from starlette.responses import PlainTextResponse, Response
from starlette.routing import Route
from starlette_admin import BaseAdmin
from starlette_admin.auth import (
AdminConfig,
AdminUser,
AuthProvider,
login_not_required,
)
from starlette_admin.exceptions import FormValidationError, LoginFailed

users = {
Expand Down Expand Up @@ -53,3 +60,30 @@ def get_admin_config(self, request: Request) -> Optional[AdminConfig]:
async def logout(self, request: Request, response: Response):
response.delete_cookie("session")
return response

@login_not_required
async def public_route_async(self, request: Request):
return PlainTextResponse("async public route")

@login_not_required
def public_route_sync(self, request: Request):
return PlainTextResponse("sync public route")

def setup_admin(self, admin: "BaseAdmin"):
super().setup_admin(admin)
admin.routes.extend(
[
Route(
"/public_sync",
self.public_route_sync,
methods=["GET"],
name="public_sync",
),
Route(
"/public_async",
self.public_route_async,
methods=["GET"],
name="public_async",
),
]
)
17 changes: 17 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ async def test_custom_login_path(self):
== "http://testserver/admin/custom-login?next=http%3A%2F%2Ftestserver%2Fadmin%2F"
)

def test_deprecated_allow_paths(self):
with pytest.warns(DeprecationWarning, match="`allow_paths` is deprecated"):
MyAuthProvider(allow_paths=["/custom-path"])

@pytest.mark.asyncio
async def test_invalid_login(self):
admin = BaseAdmin(auth_provider=MyAuthProvider())
Expand Down Expand Up @@ -284,6 +288,19 @@ async def test_access_model_view_delete(self, client):
)
assert response.status_code == 200

@pytest.mark.asyncio
@pytest.mark.parametrize(
"route_path,response_text",
[
("public_sync", "sync public route"),
("public_async", "async public route"),
],
)
async def test_public_route(self, client, route_path, response_text):
response = await client.get(f"/admin/{route_path}")
assert response.status_code == 200
assert response.text == response_text


class TestFieldAccess:
def setup_method(self, method):
Expand Down

0 comments on commit f124fec

Please sign in to comment.