Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran committed Aug 1, 2024
1 parent c9a8419 commit 0ad4879
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
18 changes: 13 additions & 5 deletions nucliadb_telemetry/src/nucliadb_telemetry/fastapi/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,17 @@ def get(self, carrier: dict, key: str) -> typing.Optional[typing.List[str]]:
# ASGI header keys are in lower case
key = key.lower()
decoded = [
_value.decode("utf8") for (_key, _value) in headers if _key.decode("utf8").lower() == key
_value.decode("utf8", errors="replace")
for (_key, _value) in headers
if _key.decode("utf8", errors="replace").lower() == key
]
if not decoded:
return None
return decoded

def keys(self, carrier: dict) -> typing.List[str]:
headers = carrier.get("headers") or []
return [_key.decode("utf8") for (_key, _) in headers]
return [_key.decode("utf8", errors="replace") for (_key, _) in headers]


asgi_getter = ASGIGetter()
Expand Down Expand Up @@ -125,7 +127,7 @@ def collect_request_attributes(scope):
query_string = scope.get("query_string")
if query_string and http_url:
if isinstance(query_string, bytes):
query_string = query_string.decode("utf8")
query_string = query_string.decode("utf8", errors="replace")
http_url += "?" + urllib.parse.unquote(query_string)

result = {
Expand Down Expand Up @@ -167,7 +169,10 @@ def collect_custom_request_headers_attributes(scope):
)

# Decode headers before processing.
headers = {_key.decode("utf8"): _value.decode("utf8") for (_key, _value) in scope.get("headers")}
headers = {
_key.decode("utf8", errors="replace"): _value.decode("utf8", errors="replace")
for (_key, _value) in scope.get("headers")
}

return sanitize.sanitize_header_values(
headers,
Expand All @@ -186,7 +191,10 @@ def collect_custom_response_headers_attributes(message):
)

# Decode headers before processing.
headers = {_key.decode("utf8"): _value.decode("utf8") for (_key, _value) in message.get("headers")}
headers = {
_key.decode("utf8", errors="replace"): _value.decode("utf8", errors="replace")
for (_key, _value) in message.get("headers")
}

return sanitize.sanitize_header_values(
headers,
Expand Down
10 changes: 9 additions & 1 deletion nucliadb_telemetry/tests/unit/fastapi/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import pytest
from opentelemetry.trace import format_trace_id

from nucliadb_telemetry.fastapi.tracing import CaptureTraceIdMiddleware
from nucliadb_telemetry.fastapi.tracing import (
CaptureTraceIdMiddleware,
collect_custom_request_headers_attributes,
)


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -59,3 +62,8 @@ async def test_capture_trace_id_middleware_appends_trace_id_header_to_exposed(tr
response = await mdw.dispatch(request, call_next)

assert response.headers["Access-Control-Expose-Headers"] == "Foo-Bar,X-Header,X-NUCLIA-TRACE-ID"


def test_collect_custom_request_headers_attributes():
scope = {"headers": [[b"x-filename", b"Synth\xe8ses\\3229-navigation.pdf"]]}
collect_custom_request_headers_attributes(scope)

0 comments on commit 0ad4879

Please sign in to comment.