Skip to content

Commit

Permalink
Upgrade all deps (#2550)
Browse files Browse the repository at this point in the history
* Upgrade all deps

* More realistic test for middleware

* Use plain ASGI middleware instead of BaseHttpMiddleware

* More updates

* Remove fixtrue
  • Loading branch information
jotare authored Oct 21, 2024
1 parent 25d42f2 commit 93e623f
Show file tree
Hide file tree
Showing 3 changed files with 1,473 additions and 1,268 deletions.
19 changes: 10 additions & 9 deletions nucliadb_telemetry/src/nucliadb_telemetry/fastapi/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

from nucliadb_telemetry import context

from .utils import get_path_template


class ContextInjectorMiddleware(BaseHTTPMiddleware):
class ContextInjectorMiddleware:
"""
Automatically inject context values for the current request's path parameters
For example:
- `/api/v1/kb/{kbid}` would inject a context value for `kbid`
"""

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
found_path_template = get_path_template(request.scope)
if found_path_template.match:
context.add_context(found_path_template.scope.get("path_params", {})) # type: ignore
def __init__(self, app):
self.app = app

return await call_next(request)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
found_path_template = get_path_template(scope)
if found_path_template.match:
context.add_context(found_path_template.scope.get("path_params", {})) # type: ignore

return await self.app(scope, receive, send)
29 changes: 9 additions & 20 deletions nucliadb_telemetry/tests/unit/fastapi/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest.mock import AsyncMock
from unittest.mock import patch

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient

from nucliadb_telemetry import context
from nucliadb_telemetry.fastapi.context import ContextInjectorMiddleware

app = FastAPI()
Expand All @@ -34,23 +34,12 @@ def get_kb(kbid: str):

@pytest.mark.asyncio
async def test_context_injected():
scope = {
"app": app,
"path": "/api/v1/kb/123",
"method": "GET",
"type": "http",
}
app.add_middleware(ContextInjectorMiddleware)

mdlw = ContextInjectorMiddleware(app)
transport = ASGITransport(app=app) # type: ignore
client = AsyncClient(transport=transport, base_url="http://test/api/v1")

found_ctx = {}

async def receive(*args, **kwargs):
found_ctx.update(context.get_context())
return {
"type": "http.disconnect",
}

await mdlw(scope, receive, AsyncMock())

assert found_ctx == {"kbid": "123"}
with patch("nucliadb_telemetry.fastapi.context.context.add_context") as add_context:
await client.get("/kb/123")
assert add_context.call_count == 1
assert add_context.call_args[0][0] == {"kbid": "123"}
Loading

0 comments on commit 93e623f

Please sign in to comment.