Skip to content

Commit

Permalink
Tests for Archive request endpoint
Browse files Browse the repository at this point in the history
- introducing a new AsynClient for proper stacktrace in tests errors (all tests should be converted to use it!)
- better ArchiveConfig fixture
- mock (only success for now) of requests calls in zimfarm
- mock (only success for now) of calls to S3
- added test of empty project list (actually tests that fixture works OK)
  • Loading branch information
rgaudin committed Jun 14, 2024
1 parent 95a6a7c commit 42b6b46
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 15 deletions.
8 changes: 5 additions & 3 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ lint = [
"ruff==0.4.3",
]
check = [
"pyright==1.1.366",
"pyright==1.1.367",
"pytest == 8.2.0", # import pytest in tests
]
test = [
"pytest==8.2.0",
"coverage==7.5.1",
"pytest-mock==3.14.0",
"trio == 0.25.1"
]
dev = [
"ipython==8.25.0",
Expand Down Expand Up @@ -198,8 +200,8 @@ extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"]
ban-relative-imports = "all"

[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]
# Tests can use magic values, assertions, relative imports, print, and unused args (mock)
"tests/**/*" = ["PLR2004", "S101", "TID252","T201", "ARG001", "ARG002"]
"**/migrations/**/*" = ["F401", "ISC001"]

[tool.pytest.ini_options]
Expand Down
215 changes: 203 additions & 12 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import datetime
import os
import urllib.parse
import uuid
from collections.abc import AsyncGenerator
from http import HTTPStatus
from io import BytesIO
from pathlib import Path
from typing import Any

import pytest # pyright: ignore [reportMissingImports]
from fastapi.testclient import TestClient
import requests
from httpx import AsyncClient
from starlette.testclient import TestClient

from api.database import Session
from api.database.models import Archive, File, Project, User
from api.database.models import Archive, ArchiveConfig, File, Project, User
from api.entrypoint import app
from api.files import save_file
from api.routes.archives import ArchiveStatus
from api.s3 import s3_storage

pytestmark = pytest.mark.asyncio(scope="package")


@pytest.fixture()
Expand All @@ -33,6 +43,19 @@ def client():
return TestClient(app)


@pytest.fixture(scope="module") # pyright: ignore
async def aclient() -> AsyncGenerator[AsyncClient, Any]:
async with AsyncClient(app=app, base_url="http://localhost") as client:
yield client


@pytest.fixture()
async def alogged_in_client(user_id: str):
async with AsyncClient(app=app, base_url="http://localhost") as client:
client.cookies = {"user_id": str(user_id)}
yield client


@pytest.fixture
def non_existent_project_id():
return "94e430c6-8888-456a-9440-c10e4a04627c"
Expand All @@ -48,23 +71,28 @@ def missing_user_cookie(missing_user_id):
return {"user_id": missing_user_id}


@pytest.fixture
@pytest.fixture()
def test_project_name():
return "test_project_name"


@pytest.fixture
@pytest.fixture()
def test_expiring_project_name():
return "test_expiring_project_name"


@pytest.fixture()
def test_archive_name():
return "test_archive_name.zim"


@pytest.fixture
@pytest.fixture()
def missing_archive_id():
return "55a345a6-20d2-40a7-b85a-7ec37e55b986"


@pytest.fixture()
def logged_in_client(client, user_id) -> str:
def logged_in_client(client, user_id: str) -> str:
cookie = {"user_id": str(user_id)}
client.cookies = cookie
return client
Expand Down Expand Up @@ -147,9 +175,34 @@ def project_id(test_project_name, user_id):
created_id = new_project.id
yield created_id
with Session.begin() as session:
user = session.get(User, created_id)
project = session.get(Project, created_id)
if project:
session.delete(project)


@pytest.fixture()
def expiring_project_id(test_expiring_project_name, user_id):
now = datetime.datetime.now(datetime.UTC)
new_project = Project(
name=test_expiring_project_name,
created_on=now,
expire_on=now + datetime.timedelta(minutes=30),
files=[],
archives=[],
)
with Session.begin() as session:
user = session.get(User, user_id)
if user:
session.delete(user)
user.projects.append(new_project)
session.add(new_project)
session.flush()
session.refresh(new_project)
created_id = new_project.id
yield created_id
with Session.begin() as session:
project = session.get(Project, created_id)
if project:
session.delete(project)


@pytest.fixture()
Expand All @@ -158,7 +211,16 @@ def archive_id(test_archive_name, project_id):
new_archive = Archive(
created_on=now,
status=ArchiveStatus.PENDING,
config={"filename": test_archive_name},
config=ArchiveConfig.init_with(
filename=test_archive_name,
title="A Title",
description="A Description",
name="a_name",
creator="a creator",
publisher="a publisher",
languages="eng",
tags=[],
),
filesize=None,
requested_on=None,
completed_on=None,
Expand All @@ -177,6 +239,135 @@ def archive_id(test_archive_name, project_id):
created_id = new_archive.id
yield created_id
with Session.begin() as session:
archives = session.get(Archive, created_id)
if archives:
session.delete(archives)
archive = session.get(Archive, created_id)
if archive:
session.delete(archive)


@pytest.fixture()
def expiring_archive_id(test_archive_name, expiring_project_id):
now = datetime.datetime.now(datetime.UTC)
new_archive = Archive(
created_on=now,
status=ArchiveStatus.PENDING,
config=ArchiveConfig.init_with(
filename=test_archive_name,
title="A Title",
description="A Description",
name="a_name",
creator="a creator",
publisher="a publisher",
languages="eng",
tags=[],
),
filesize=None,
requested_on=None,
completed_on=None,
download_url=None,
collection_json_path=None,
zimfarm_task_id=None,
email=None,
)
with Session.begin() as session:
project = session.get(Project, expiring_project_id)
if project:
project.archives.append(new_archive)
session.add(new_archive)
session.flush()
session.refresh(new_archive)
created_id = new_archive.id
yield created_id
with Session.begin() as session:
archive = session.get(Archive, created_id)
if archive:
session.delete(archive)


class SuccessStorage:

def upload_file(*args, **kwargs): ...

def upload_fileobj(*args, **kwargs): ...

def set_object_autodelete_on(*args, **kwargs): ...

def has_object(*args, **kwargs):
return True

def check_credentials(*args, **kwargs):
return True

def delete_object(*args, **kwargs): ...


@pytest.fixture
def successful_s3_upload_file(monkeypatch):
"""Requests.get() mocked to return {'mock_key':'mock_response'}."""

monkeypatch.setattr(s3_storage, "_storage", SuccessStorage())
yield True


class SuccessfulRequestResponse:
status_code = HTTPStatus.OK
text = "text"

@staticmethod
def raise_for_status(): ...


class SuccessfulAuthResponse(SuccessfulRequestResponse):
@staticmethod
def json():
return {
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9."
"eyJpc3MiOiJkaXNwYXRjaGVyIiwiZXhwIj",
"token_type": "bearer",
"expires_in": 3600,
"refresh_token": "aea891db-090b-4cbb-6qer-57c0928b42e6",
}


class ScheduleCreatedResponse(SuccessfulRequestResponse):
status_code = HTTPStatus.CREATED

@staticmethod
def json():
return {"_id": uuid.uuid4().hex}


class TaskRequestedResponse(SuccessfulRequestResponse):
status_code = HTTPStatus.CREATED

@staticmethod
def json():
return {"requested": [uuid.uuid4().hex]}


class ScheduleDeletedResponse(SuccessfulRequestResponse):

@staticmethod
def json():
return {}


@pytest.fixture
def successful_zimfarm_request_task(monkeypatch):
"""Requests.get() mocked to return {'mock_key':'mock_response'}."""

def requests_post(**kwargs):
uri = urllib.parse.urlparse(kwargs.get("url"))
if uri.path == "/v1/auth/authorize":
return SuccessfulAuthResponse()
if uri.path == "/v1/schedules/":
return ScheduleCreatedResponse()
if uri.path == "/v1/requested-tasks/":
return TaskRequestedResponse()
raise ValueError(f"Unhandled {kwargs}")

def requests_delete(*args, **kwargs):
return ScheduleDeletedResponse()

monkeypatch.setattr(requests, "post", requests_post)
monkeypatch.setattr(requests, "delete", requests_delete)
yield True
31 changes: 31 additions & 0 deletions backend/tests/routes/test_archives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import uuid
from http import HTTPStatus

import pytest
from httpx import AsyncClient

from api.constants import constants


Expand Down Expand Up @@ -71,6 +74,7 @@ def test_update_archive_correct_data(logged_in_client, project_id, archive_id):
"creator": "test_creator",
"languages": "en",
"tags": ["test_tags"],
"illustration": "",
},
}
response = logged_in_client.patch(
Expand Down Expand Up @@ -214,3 +218,30 @@ def test_upload_illustration_without_wrong_authorization(
files=file,
)
assert response.status_code == HTTPStatus.UNAUTHORIZED


@pytest.mark.anyio
async def test_request_archive_not_ready(alogged_in_client, project_id, archive_id):
response = await alogged_in_client.post(
f"{constants.api_version_prefix}/projects/"
f"{project_id}/archives/{archive_id}/request"
)
assert response.status_code == HTTPStatus.CONFLICT


@pytest.mark.anyio
async def test_request_archive_ready(
alogged_in_client: AsyncClient,
archive_id,
project_id,
expiring_project_id,
expiring_archive_id,
successful_s3_upload_file,
successful_zimfarm_request_task,
):

response = await alogged_in_client.post(
f"{constants.api_version_prefix}/projects/"
f"{expiring_project_id}/archives/{expiring_archive_id}/request"
)
assert response.status_code == HTTPStatus.CREATED
8 changes: 8 additions & 0 deletions backend/tests/routes/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def test_create_project_wrong_authorization(client, missing_user_cookie):
assert response.status_code == HTTPStatus.UNAUTHORIZED


def test_get_all_projects_no_data(logged_in_client):
response = logged_in_client.get(f"{constants.api_version_prefix}/projects")
json_result = response.json()
assert response.status_code == HTTPStatus.OK
assert json_result is not None
assert len(json_result) == 0


def test_get_all_projects_correct_data(logged_in_client, project_id):
response = logged_in_client.get(f"{constants.api_version_prefix}/projects")
json_result = response.json()
Expand Down

0 comments on commit 42b6b46

Please sign in to comment.