diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index 4c729d6f2..8b7702ba4 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -11,7 +11,7 @@ import aiodns import aiohttp from aiohttp import web -from aiohttp.web_exceptions import HTTPNotFound +from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound from aleph_message.exceptions import UnknownHashError from aleph_message.models import ItemHash, MessageType from pydantic import ValidationError @@ -65,7 +65,13 @@ async def run_code_from_path(request: web.Request) -> web.Response: path = request.match_info["suffix"] path = path if path.startswith("/") else f"/{path}" - message_ref = ItemHash(request.match_info["ref"]) + try: + message_ref = ItemHash(request.match_info["ref"]) + except UnknownHashError as e: + raise HTTPBadRequest( + reason="Invalid message reference", text=f"Invalid message reference: {request.match_info['ref']}" + ) from e + pool: VmPool = request.app["vm_pool"] return await run_code_on_request(message_ref, path, pool, request) @@ -98,8 +104,10 @@ async def run_code_from_hostname(request: web.Request) -> web.Response: try: message_ref = ItemHash(await get_ref_from_dns(domain=f"_aleph-id.{request.host}")) logger.debug(f"Using DNS TXT record to obtain '{message_ref}'") - except aiodns.error.DNSError as error: - raise HTTPNotFound(reason="Invalid message reference") from error + except aiodns.error.DNSError: + return HTTPNotFound(reason="Invalid message reference") + except UnknownHashError: + return HTTPNotFound(reason="Invalid message reference") pool = request.app["vm_pool"] return await run_code_on_request(message_ref, path, pool, request) diff --git a/tests/supervisor/views/test_run_code.py b/tests/supervisor/views/test_run_code.py new file mode 100644 index 000000000..639a8f7bf --- /dev/null +++ b/tests/supervisor/views/test_run_code.py @@ -0,0 +1,46 @@ +import pytest +from aiohttp import ClientResponseError, web +from aiohttp.test_utils import make_mocked_request +from aiohttp.web_exceptions import HTTPBadRequest +from aleph_message.exceptions import UnknownHashError +from aleph_message.models import ItemHash + +from aleph.vm.conf import settings +from aleph.vm.orchestrator.views import run_code_from_path + + +@pytest.mark.asyncio +async def test_run_code_from_invalid_path(aiohttp_client): + """ + Test that the run_code_from_path endpoint raises the right + error on invalid paths. + """ + item_hash = "invalid-item-hash" + with pytest.raises(UnknownHashError): + assert ItemHash(item_hash).is_storage(item_hash) + + app = web.Application() + + app.router.add_route("*", "/vm/{ref}{suffix:.*}", run_code_from_path), + client = await aiohttp_client(app) + + invalid_hash_request: web.Request = make_mocked_request( + "GET", + "/vm/" + item_hash, + match_info={ + "ref": item_hash, + "suffix": "/some/suffix", + }, + headers={"Host": settings.DOMAIN_NAME}, + app=app, + ) + with pytest.raises(HTTPBadRequest): + await run_code_from_path(invalid_hash_request) + + # Calling the view from an HTTP client should result in a Bad Request error. + resp = await client.get("/vm/" + item_hash + "/some/suffix") + assert resp.status == HTTPBadRequest.status_code + text = await resp.text() + assert text == f"Invalid message reference: {item_hash}" + with pytest.raises(ClientResponseError): + resp.raise_for_status()