diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d4ccba1..1e07ae8 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -40,12 +40,14 @@ lint = [ "ruff==0.4.3", ] check = [ - "pyright==1.1.366", + "pyright==1.1.367", + "pytest == 8.2.0", # import pytest in tests ] test = [ "pytest==8.2.0", "coverage==7.5.1", "pytest-mock==3.14.0", + "trio == 0.25.1" ] dev = [ "ipython==8.25.0", @@ -198,8 +200,8 @@ extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"] ban-relative-imports = "all" [tool.ruff.lint.per-file-ignores] -# Tests can use magic values, assertions, and relative imports -"tests/**/*" = ["PLR2004", "S101", "TID252"] +# Tests can use magic values, assertions, relative imports, print, and unused args (mock) +"tests/**/*" = ["PLR2004", "S101", "TID252","T201", "ARG001", "ARG002"] "**/migrations/**/*" = ["F401", "ISC001"] [tool.pytest.ini_options] diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 1343938..b30a9a8 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,16 +1,26 @@ import datetime import os +import urllib.parse +import uuid +from collections.abc import AsyncGenerator +from http import HTTPStatus from io import BytesIO from pathlib import Path +from typing import Any import pytest # pyright: ignore [reportMissingImports] -from fastapi.testclient import TestClient +import requests +from httpx import AsyncClient +from starlette.testclient import TestClient from api.database import Session -from api.database.models import Archive, File, Project, User +from api.database.models import Archive, ArchiveConfig, File, Project, User from api.entrypoint import app from api.files import save_file from api.routes.archives import ArchiveStatus +from api.s3 import s3_storage + +pytestmark = pytest.mark.asyncio(scope="package") @pytest.fixture() @@ -33,6 +43,19 @@ def client(): return TestClient(app) +@pytest.fixture(scope="module") # pyright: ignore +async def aclient() -> AsyncGenerator[AsyncClient, Any]: + async with AsyncClient(app=app, base_url="http://localhost") as client: + yield client + + +@pytest.fixture() +async def alogged_in_client(user_id: str): + async with AsyncClient(app=app, base_url="http://localhost") as client: + client.cookies = {"user_id": str(user_id)} + yield client + + @pytest.fixture def non_existent_project_id(): return "94e430c6-8888-456a-9440-c10e4a04627c" @@ -48,23 +71,28 @@ def missing_user_cookie(missing_user_id): return {"user_id": missing_user_id} -@pytest.fixture +@pytest.fixture() def test_project_name(): return "test_project_name" -@pytest.fixture +@pytest.fixture() +def test_expiring_project_name(): + return "test_expiring_project_name" + + +@pytest.fixture() def test_archive_name(): return "test_archive_name.zim" -@pytest.fixture +@pytest.fixture() def missing_archive_id(): return "55a345a6-20d2-40a7-b85a-7ec37e55b986" @pytest.fixture() -def logged_in_client(client, user_id) -> str: +def logged_in_client(client, user_id: str) -> str: cookie = {"user_id": str(user_id)} client.cookies = cookie return client @@ -147,9 +175,34 @@ def project_id(test_project_name, user_id): created_id = new_project.id yield created_id with Session.begin() as session: - user = session.get(User, created_id) + project = session.get(Project, created_id) + if project: + session.delete(project) + + +@pytest.fixture() +def expiring_project_id(test_expiring_project_name, user_id): + now = datetime.datetime.now(datetime.UTC) + new_project = Project( + name=test_expiring_project_name, + created_on=now, + expire_on=now + datetime.timedelta(minutes=30), + files=[], + archives=[], + ) + with Session.begin() as session: + user = session.get(User, user_id) if user: - session.delete(user) + user.projects.append(new_project) + session.add(new_project) + session.flush() + session.refresh(new_project) + created_id = new_project.id + yield created_id + with Session.begin() as session: + project = session.get(Project, created_id) + if project: + session.delete(project) @pytest.fixture() @@ -158,7 +211,16 @@ def archive_id(test_archive_name, project_id): new_archive = Archive( created_on=now, status=ArchiveStatus.PENDING, - config={"filename": test_archive_name}, + config=ArchiveConfig.init_with( + filename=test_archive_name, + title="A Title", + description="A Description", + name="a_name", + creator="a creator", + publisher="a publisher", + languages="eng", + tags=[], + ), filesize=None, requested_on=None, completed_on=None, @@ -177,6 +239,135 @@ def archive_id(test_archive_name, project_id): created_id = new_archive.id yield created_id with Session.begin() as session: - archives = session.get(Archive, created_id) - if archives: - session.delete(archives) + archive = session.get(Archive, created_id) + if archive: + session.delete(archive) + + +@pytest.fixture() +def expiring_archive_id(test_archive_name, expiring_project_id): + now = datetime.datetime.now(datetime.UTC) + new_archive = Archive( + created_on=now, + status=ArchiveStatus.PENDING, + config=ArchiveConfig.init_with( + filename=test_archive_name, + title="A Title", + description="A Description", + name="a_name", + creator="a creator", + publisher="a publisher", + languages="eng", + tags=[], + ), + filesize=None, + requested_on=None, + completed_on=None, + download_url=None, + collection_json_path=None, + zimfarm_task_id=None, + email=None, + ) + with Session.begin() as session: + project = session.get(Project, expiring_project_id) + if project: + project.archives.append(new_archive) + session.add(new_archive) + session.flush() + session.refresh(new_archive) + created_id = new_archive.id + yield created_id + with Session.begin() as session: + archive = session.get(Archive, created_id) + if archive: + session.delete(archive) + + +class SuccessStorage: + + def upload_file(*args, **kwargs): ... + + def upload_fileobj(*args, **kwargs): ... + + def set_object_autodelete_on(*args, **kwargs): ... + + def has_object(*args, **kwargs): + return True + + def check_credentials(*args, **kwargs): + return True + + def delete_object(*args, **kwargs): ... + + +@pytest.fixture +def successful_s3_upload_file(monkeypatch): + """Requests.get() mocked to return {'mock_key':'mock_response'}.""" + + monkeypatch.setattr(s3_storage, "_storage", SuccessStorage()) + yield True + + +class SuccessfulRequestResponse: + status_code = HTTPStatus.OK + text = "text" + + @staticmethod + def raise_for_status(): ... + + +class SuccessfulAuthResponse(SuccessfulRequestResponse): + @staticmethod + def json(): + return { + "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9." + "eyJpc3MiOiJkaXNwYXRjaGVyIiwiZXhwIj", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "aea891db-090b-4cbb-6qer-57c0928b42e6", + } + + +class ScheduleCreatedResponse(SuccessfulRequestResponse): + status_code = HTTPStatus.CREATED + + @staticmethod + def json(): + return {"_id": uuid.uuid4().hex} + + +class TaskRequestedResponse(SuccessfulRequestResponse): + status_code = HTTPStatus.CREATED + + @staticmethod + def json(): + return {"requested": [uuid.uuid4().hex]} + + +class ScheduleDeletedResponse(SuccessfulRequestResponse): + + @staticmethod + def json(): + return {} + + +@pytest.fixture +def successful_zimfarm_request_task(monkeypatch): + """Requests.get() mocked to return {'mock_key':'mock_response'}.""" + + def requests_post(**kwargs): + uri = urllib.parse.urlparse(kwargs.get("url")) + if uri.path == "/v1/auth/authorize": + return SuccessfulAuthResponse() + if uri.path == "/v1/schedules/": + return ScheduleCreatedResponse() + if uri.path == "/v1/requested-tasks/": + return TaskRequestedResponse() + raise ValueError(f"Unhandled {kwargs}") + + def requests_delete(*args, **kwargs): + return ScheduleDeletedResponse() + + monkeypatch.setattr(requests, "post", requests_post) + monkeypatch.setattr(requests, "delete", requests_delete) + yield True diff --git a/backend/tests/routes/test_archives.py b/backend/tests/routes/test_archives.py index eae4d78..99da268 100644 --- a/backend/tests/routes/test_archives.py +++ b/backend/tests/routes/test_archives.py @@ -1,6 +1,9 @@ import uuid from http import HTTPStatus +import pytest +from httpx import AsyncClient + from api.constants import constants @@ -71,6 +74,7 @@ def test_update_archive_correct_data(logged_in_client, project_id, archive_id): "creator": "test_creator", "languages": "en", "tags": ["test_tags"], + "illustration": "", }, } response = logged_in_client.patch( @@ -214,3 +218,30 @@ def test_upload_illustration_without_wrong_authorization( files=file, ) assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.anyio +async def test_request_archive_not_ready(alogged_in_client, project_id, archive_id): + response = await alogged_in_client.post( + f"{constants.api_version_prefix}/projects/" + f"{project_id}/archives/{archive_id}/request" + ) + assert response.status_code == HTTPStatus.CONFLICT + + +@pytest.mark.anyio +async def test_request_archive_ready( + alogged_in_client: AsyncClient, + archive_id, + project_id, + expiring_project_id, + expiring_archive_id, + successful_s3_upload_file, + successful_zimfarm_request_task, +): + + response = await alogged_in_client.post( + f"{constants.api_version_prefix}/projects/" + f"{expiring_project_id}/archives/{expiring_archive_id}/request" + ) + assert response.status_code == HTTPStatus.CREATED diff --git a/backend/tests/routes/test_projects.py b/backend/tests/routes/test_projects.py index 485ea3e..bd24749 100644 --- a/backend/tests/routes/test_projects.py +++ b/backend/tests/routes/test_projects.py @@ -28,6 +28,14 @@ def test_create_project_wrong_authorization(client, missing_user_cookie): assert response.status_code == HTTPStatus.UNAUTHORIZED +def test_get_all_projects_no_data(logged_in_client): + response = logged_in_client.get(f"{constants.api_version_prefix}/projects") + json_result = response.json() + assert response.status_code == HTTPStatus.OK + assert json_result is not None + assert len(json_result) == 0 + + def test_get_all_projects_correct_data(logged_in_client, project_id): response = logged_in_client.get(f"{constants.api_version_prefix}/projects") json_result = response.json()