From 5015630223326434b405a8ee4b5d5f75fac47f0a Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Wed, 30 Oct 2024 23:08:49 +0000 Subject: [PATCH] WIP commit msg --- airlock/views.py | 9 +- jobserver/github.py | 131 ++++++++++++------ jobserver/views/projects.py | 5 +- jobserver/views/repos.py | 27 ++-- jobserver/views/workspaces.py | 7 +- staff/views/dashboards/copiloting.py | 9 +- tests/fakes.py | 70 +++++++++- tests/unit/airlock/test_views.py | 16 +-- tests/unit/jobserver/views/test_projects.py | 13 +- tests/unit/jobserver/views/test_repos.py | 22 +-- tests/unit/jobserver/views/test_workspaces.py | 17 +-- .../staff/views/dashboards/test_copiloting.py | 19 +-- tests/verification/test_github.py | 103 ++++++++++---- tests/verification/test_opencodelists.py | 15 +- tests/verification/test_utils.py | 102 ++++++++++++++ tests/verification/utils.py | 49 +++++-- 16 files changed, 434 insertions(+), 180 deletions(-) create mode 100644 tests/verification/test_utils.py diff --git a/airlock/views.py b/airlock/views.py index 92236542f..bc9f243c4 100644 --- a/airlock/views.py +++ b/airlock/views.py @@ -1,13 +1,12 @@ from dataclasses import dataclass from enum import Enum -from requests.exceptions import HTTPError from rest_framework.authentication import SessionAuthentication from rest_framework.decorators import api_view, authentication_classes from rest_framework.response import Response from jobserver.api.authentication import get_backend_from_token -from jobserver.github import _get_github_api +from jobserver.github import GitHubError, _get_github_api from jobserver.models import User, Workspace from .config import ORG_OUTPUT_CHECKING_REPOS @@ -122,7 +121,7 @@ def create_issue(airlock_event: AirlockEvent, github_api=None): airlock_event.repo, github_api, ) - except HTTPError as e: + except GitHubError as e: raise NotificationError(f"Error creating GitHub issue: {e}") @@ -138,7 +137,7 @@ def close_issue(airlock_event: AirlockEvent, github_api=None): airlock_event.repo, github_api, ) - except HTTPError as e: + except GitHubError as e: raise NotificationError(f"Error closing GitHub issue: {e}") @@ -155,7 +154,7 @@ def update_issue(airlock_event: AirlockEvent, github_api=None, notify_slack=Fals github_api, notify_slack=notify_slack, ) - except HTTPError as e: + except GitHubError as e: raise NotificationError(f"Error creating GitHub issue comment: {e}") diff --git a/jobserver/github.py b/jobserver/github.py index d7437eb85..4de780545 100644 --- a/jobserver/github.py +++ b/jobserver/github.py @@ -21,17 +21,49 @@ "User-Agent": "OpenSAFELY Jobs", } +# Clients should catch GitHubError to handle common, often transient, connection +# issues gracefully. Some HTTPError status codes indicate specific API errors +# related to remote state (e.g., attempting to create an object that already +# exists). These cases should raise specific HTTPError exceptions from this +# module, allowing clients to handle such errors without relying on internal +# implementation details. + class GitHubError(Exception): - """Base exception to target all other exceptions we define here""" + """Base exception for this module. A problem contacting or using the GitHub + API.""" + + +class Timeout(GitHubError): + """A request to the GitHub API timed out.""" + + +class ConnectionException(GitHubError): + """A connection error occurred while contacting the GitHub API.""" + + # ConnectionError is a Python default exception class, so let's avoid using + # that name. Otherwise we might have mirrored the name of + # requests.exceptions.ConnectionError. + + +class HTTPError(GitHubError): + """An HTTP request with an error status code was returned by the GitHub + API.""" + +class RepoAlreadyExists(HTTPError): + """Tried to create a repo that already existed.""" -class RepoAlreadyExists(GitHubError): - """An API call failed as the repo to be created already exists.""" +class RepoNotYetCreated(HTTPError): + """Tried to delete a repo that did not already exist.""" -class RepoNotYetCreated(GitHubError): - """An API call failed as the repo to be deleted already exists.""" + # Attach request and response because some unit tests inspect them.""" + + def __init__(self, request, response): + self.request = request + self.response = response + super().__init__(f"") class GitHubAPI: @@ -51,7 +83,7 @@ class GitHubAPI: def __init__(self, _session=session, token=None): """ - Initialise the wrapper with a session and maybe token + Initialise the wrapper with a session and maybe token. We pass in the session here so that tests can pass in a fake object to test internals. @@ -73,31 +105,39 @@ def _put(self, *args, **kwargs): def _request(self, method, *args, **kwargs): """ - Thin wrapper of requests.Session._request() - - This wrapper exists solely to inject the Authorization header if a - token has been set on the current instance and that headers hasn't - already been set in a given requests headers. - - This solves a tension between using an application-level session object - and wanting GitHubAPI instance-level authentication. We want to - support the use of different tokens for typical running (eg in prod), - verification tests (eg in CI), and the ability to query the API without - a token (less likely but can be useful) so we can't just set the header - on the session when it's defined at the module level. However if we - set it on the session then it persists beyond the life time of a given - GitHubAPI instance. + Make a request to the remote GitHub API. + + A wrapper for `requests.Session.request` that injects an + `Authorization` header if a token is set on the API instance and not + already included in the request headers. + + This design allows for instance-level authentication with different + tokens (e.g., for production, CI verification tests, or unauthenticated + queries) without setting the header globally on the session. + + Raises locally-defined Exceptions for common connection errors. """ headers = kwargs.pop("headers", {}) if self.token and "Authorization" not in headers: headers = headers | {"Authorization": f"bearer {self.token}"} - return self.session.request(method, *args, headers=headers, **kwargs) + try: + return self.session.request(method, *args, headers=headers, **kwargs) + except requests.Timeout as exc: + raise Timeout(exc) + except requests.ConnectionError as exc: + raise ConnectionException(exc) + + def _raise_for_status(self, request): + try: + request.raise_for_status() + except requests.HTTPError as exc: + raise HTTPError(exc.request, exc.response) def _get_query_page(self, *, query, session, cursor, **kwargs): """ - Get a page of the given query + Get a page of the given query. This uses the GraphQL API to avoid making O(N) calls to GitHub's (v3) REST API. The passed cursor is a GraphQL cursor [1] allowing us to call this @@ -116,7 +156,7 @@ def _get_query_page(self, *, query, session, cursor, **kwargs): print(r.headers) print(r.content) - r.raise_for_status() + self._raise_for_status(r) results = r.json() # In some cases graphql will return a 200 response when there are errors. @@ -133,7 +173,7 @@ def _get_query_page(self, *, query, session, cursor, **kwargs): def _iter_query_results(self, query, **kwargs): """ - Get results from a GraphQL query + Get results from a GraphQL query. Given a GraphQL query, return all results across one or more pages as a single generator. We currently assume all results live under @@ -158,7 +198,7 @@ def _iter_query_results(self, query, **kwargs): if not data["pageInfo"]["hasNextPage"]: break - # update the cursor we pass into the GraphQL query + # Update the cursor we pass into the GraphQL query. cursor = data["pageInfo"]["endCursor"] # pragma: no cover def _url(self, path_segments, query_args=None): @@ -192,7 +232,7 @@ def add_repo_to_team(self, team, org, repo): } r = self._put(url, headers=headers, json=payload) - r.raise_for_status() + self._raise_for_status(r) return @@ -225,7 +265,7 @@ def create_issue(self, org, repo, title, body, labels): } r = self._post(url, headers=headers, json=payload) - r.raise_for_status() + self._raise_for_status(r) return r.json() @@ -262,7 +302,7 @@ def get_issue_number_from_title( } r = self._get(url, headers=headers, params=payload) - r.raise_for_status() + self._raise_for_status(r) results = r.json() count = results["total_count"] @@ -307,7 +347,7 @@ def create_issue_comment( } r = self._post(url, headers=headers, json=payload) - r.raise_for_status() + self._raise_for_status(r) return r.json() @@ -320,7 +360,7 @@ def _change_issue_state(self, org, repo, issue_number, to_state): url = self._url(path_segments) r = self._post(url, headers=headers, json=payload) - r.raise_for_status() + self._raise_for_status(r) def close_issue(self, org, repo, title_text, comment=None, latest=True): if settings.DEBUG: # pragma: no cover @@ -354,7 +394,7 @@ def close_issue(self, org, repo, title_text, comment=None, latest=True): } r = self._post(url, headers=headers, json=payload) - r.raise_for_status() + self._raise_for_status(r) if comment is not None: self.create_issue_comment( @@ -379,11 +419,10 @@ def create_repo(self, org, repo): "Accept": "application/vnd.github.v3+json", } r = self._post(url, headers=headers, json=payload) + if r.status_code == 422: + raise RepoAlreadyExists() - try: - r.raise_for_status() - except requests.HTTPError as e: - raise RepoAlreadyExists from e + self._raise_for_status(r) return r.json() @@ -415,15 +454,15 @@ def delete_repo(self, org, repo): # pragma: no cover print(r.content) if r.status_code == 403: - # it's possible for us to create and then attempt to delete a repo + # It's possible for us to create and then attempt to delete a repo # faster than GitHub can create it on disk, so lets wait and retry - # if that's happened - # Note: 403 isn't just used for this state + # if that's happened. + # Note: 403 isn't just used for this state. msg = "Repository cannot be deleted until it is done being created on disk." if msg in r.json().get("message", ""): raise RepoNotYetCreated() - r.raise_for_status() + self._raise_for_status(r) def get_branch(self, org, repo, branch): path_segments = [ @@ -443,7 +482,7 @@ def get_branch(self, org, repo, branch): if r.status_code == 404: return - r.raise_for_status() + self._raise_for_status(r) return r.json() @@ -464,7 +503,7 @@ def get_branches(self, org, repo): if r.status_code == 404: return [] - r.raise_for_status() + self._raise_for_status(r) return r.json() @@ -484,7 +523,7 @@ def get_tag_sha(self, org, repo, tag): } r = self._get(url, headers=headers) - r.raise_for_status() + self._raise_for_status(r) return r.json()["object"]["sha"] @@ -507,7 +546,7 @@ def get_file(self, org, repo, branch, filepath="project.yaml"): if r.status_code == 404: return - r.raise_for_status() + self._raise_for_status(r) return r.text @@ -527,7 +566,7 @@ def get_repo(self, org, repo): if r.status_code == 404: return - r.raise_for_status() + self._raise_for_status(r) return r.json() @@ -587,7 +626,7 @@ def get_repos_with_branches(self, org): topics = [n["topic"]["name"] for n in repo["repositoryTopics"]["nodes"]] if "non-research" in topics: - continue # ignore non-research repos + continue # Ignore non-research repos. yield { "name": repo["name"], @@ -682,7 +721,7 @@ def set_repo_topics(self, org, repo, topics): } r = self._put(url, headers=headers, json=payload) - r.raise_for_status() + self._raise_for_status(r) return r.json() diff --git a/jobserver/views/projects.py b/jobserver/views/projects.py index b40ed1251..34180437f 100644 --- a/jobserver/views/projects.py +++ b/jobserver/views/projects.py @@ -2,7 +2,6 @@ import itertools import operator -import requests from django.core.exceptions import PermissionDenied from django.db.models import Min, OuterRef, Subquery from django.db.models.functions import Least, Lower @@ -17,7 +16,7 @@ from jobserver.utils import set_from_qs from ..authorization import has_permission, permissions -from ..github import _get_github_api +from ..github import GitHubError, _get_github_api from ..models import Job, JobRequest, Project, PublishRequest, Repo, Snapshot @@ -173,7 +172,7 @@ def get_repo(repo, ctx): is_private = self.get_github_api().get_repo_is_private( repo.owner, repo.name ) - except (requests.HTTPError, requests.Timeout, requests.ConnectionError): + except GitHubError: is_private = None span = trace.get_current_span() span.set_attribute("repo_owner", repo.owner) diff --git a/jobserver/views/repos.py b/jobserver/views/repos.py index 981f76048..a35a34b81 100644 --- a/jobserver/views/repos.py +++ b/jobserver/views/repos.py @@ -1,7 +1,6 @@ import itertools from urllib.parse import quote, unquote -import requests from django.contrib import messages from django.contrib.auth.decorators import login_required from django.http import Http404 @@ -16,7 +15,7 @@ send_repo_signed_off_notification_to_researchers, send_repo_signed_off_notification_to_staff, ) -from ..github import _get_github_api +from ..github import GitHubError, _get_github_api from ..models import Org, Project, ProjectMembership, Repo from ..slacks import notify_copilots_of_repo_sign_off @@ -164,7 +163,7 @@ def render_to_response(self): try: is_private = github_api.get_repo_is_private(self.repo.owner, self.repo.name) - except requests.HTTPError: + except GitHubError: is_private = None repo = { @@ -177,19 +176,27 @@ def render_to_response(self): "url": self.repo.url, } - workspaces = [build_workspace(w, github_api) for w in self.workspaces] + try: + workspaces = [build_workspace(w, github_api) for w in self.workspaces] + except GitHubError: + workspaces = [] # build up a list of branches without a workspace - branches = [ - b["name"] for b in github_api.get_branches(self.repo.owner, self.repo.name) - ] - workspace_branches = [w["branch"] for w in workspaces] - branches = [b for b in branches if b not in workspace_branches] + + try: + branches = [ + b["name"] + for b in github_api.get_branches(self.repo.owner, self.repo.name) + ] + workspace_branches = [w["branch"] for w in workspaces] + branches = [b for b in branches if b not in workspace_branches] + except GitHubError: + branches = [] workspaces_signed_off = not self.workspaces.filter(signed_off_at=None).exists() # TODO: when we have dealt with all the cross-project repos and are - # enforcing repos can't be used acrosss projects this check can be + # enforcing repos can't be used across projects this check can be # skipped. projects = Project.objects.filter(workspaces__repo=self.repo).distinct() if projects.count() == 1: diff --git a/jobserver/views/workspaces.py b/jobserver/views/workspaces.py index 7d2215659..d9a57d7f2 100644 --- a/jobserver/views/workspaces.py +++ b/jobserver/views/workspaces.py @@ -1,6 +1,5 @@ from datetime import timedelta -import requests from csp.decorators import csp_exempt from django.contrib import messages from django.core.exceptions import PermissionDenied @@ -27,7 +26,7 @@ WorkspaceEditForm, WorkspaceNotificationsToggleForm, ) -from ..github import _get_github_api +from ..github import GitHubError, _get_github_api from ..models import ( Backend, Job, @@ -134,7 +133,7 @@ def dispatch(self, request, *args, **kwargs): self.repos_with_branches = list( self.get_github_api().get_repos_with_branches(gh_org) ) - except requests.HTTPError: + except GitHubError: # gracefully handle not being able to access GitHub's API msg = ( "An error occurred while retrieving the list of repositories from GitHub, " @@ -215,7 +214,7 @@ def get(self, request, *args, **kwargs): repo_is_private = self.get_github_api().get_repo_is_private( workspace.repo.owner, workspace.repo.name ) - except (requests.ConnectionError, requests.HTTPError): + except GitHubError: repo_is_private = None show_publish_repo_warning = ( is_member diff --git a/staff/views/dashboards/copiloting.py b/staff/views/dashboards/copiloting.py index ca8869265..3b196a47f 100644 --- a/staff/views/dashboards/copiloting.py +++ b/staff/views/dashboards/copiloting.py @@ -1,6 +1,5 @@ import itertools -import requests import structlog from csp.decorators import csp_exempt from django.conf import settings @@ -12,7 +11,7 @@ from jobserver.authorization import StaffAreaAdministrator from jobserver.authorization.decorators import require_role -from jobserver.github import _get_github_api +from jobserver.github import GitHubError, _get_github_api from jobserver.models import Project, ReleaseFile, Repo @@ -45,10 +44,10 @@ def build_repos_by_project(projects, get_github_api=_get_github_api): try: github_repos = list(get_github_api().get_repos_with_status_and_url(repo_orgs)) - except requests.HTTPError: - # if the GitHub API is down log some details but don't block the page + except GitHubError as exc: logger.exception( - "Failed to get repo status and URL from GitHub API", repo_orgs=repo_orgs + f"Failed to get repo status and URL from GitHub API due to: {exc}", + repo_orgs=repo_orgs, ) return {} diff --git a/tests/fakes.py b/tests/fakes.py index 771597e96..7728d625d 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -3,8 +3,13 @@ from django.utils import timezone +from jobserver.github import GitHubError + class FakeGitHubAPI: + """Fake GitHubAPI that returns reasonable values for each corresponding + public function.""" + def add_repo_to_team(self, team, org, repo): return @@ -56,7 +61,7 @@ def get_branch_sha(self, org, repo, branch): def get_tag_sha(self, org, repo, tag): return "test_sha" - def get_file(self, org, repo, branch, filepath=None): + def get_file(self, org, repo, branch, filepath="project.yaml"): return textwrap.dedent( """ actions: @@ -160,6 +165,69 @@ def set_repo_topics(self, org, repo, topics): } +class FakeGitHubAPIWithErrors: + """Fake GitHubAPI that returns an error for each corresponding public + function.""" + + def add_repo_to_team(self, team, org, repo): + raise GitHubError() + + def create_issue(self, org, repo, title, body, labels): + # Some unit tests want to check the message. + raise GitHubError("An error occurred") + + def get_issue_number_from_title( + self, org, repo, title_text, latest=True, state=None + ): + raise GitHubError() + + def create_issue_comment( + self, org, repo, title_text, body, latest=True, issue_number=1 + ): + # Some unit tests want to check the message. + raise GitHubError("An error occurred") + + def close_issue(self, org, repo, title_text, comment=None, latest=True): + # Some unit tests want to check the message. + raise GitHubError("An error occurred") + + def create_repo(self, org, repo): + raise GitHubError() + + def get_branch(self, org, repo, branch): + raise GitHubError() + + def get_branches(self, org, repo): + raise GitHubError() + + def get_branch_sha(self, org, repo, branch): + raise GitHubError() + + def get_tag_sha(self, org, repo, tag): + raise GitHubError() + + def get_file(self, org, repo, branch, filepath="project.yaml"): + raise GitHubError() + + def get_repo(self, org, repo): + raise GitHubError() + + def get_repo_is_private(self, org, repo): + raise GitHubError() + + def get_repos_with_branches(self, org): + raise GitHubError() + + def get_repos_with_dates(self, org): + raise GitHubError() + + def get_repos_with_status_and_url(self, orgs): + raise GitHubError() + + def set_repo_topics(self, org, repo, topics): + raise GitHubError() + + class FakeOpenCodelistsAPI: def get_codelists(self, coding_system): return [ diff --git a/tests/unit/airlock/test_views.py b/tests/unit/airlock/test_views.py index 7cf973b92..5171236c1 100644 --- a/tests/unit/airlock/test_views.py +++ b/tests/unit/airlock/test_views.py @@ -1,7 +1,6 @@ from unittest.mock import patch import pytest -from requests.exceptions import HTTPError from airlock.views import AirlockEvent, EventType, airlock_event_view from tests.factories import ( @@ -11,18 +10,7 @@ UserFactory, WorkspaceFactory, ) -from tests.fakes import FakeGitHubAPI - - -class FakeGithubApiWithError: - def create_issue(*args, **kwargs): - raise HTTPError("An error occurred") - - def create_issue_comment(*args, **kwargs): - raise HTTPError("An error occurred") - - def close_issue(*args, **kwargs): - raise HTTPError("An error occurred") +from tests.fakes import FakeGitHubAPI, FakeGitHubAPIWithErrors @pytest.mark.parametrize( @@ -205,7 +193,7 @@ def test_api_post_release_request_default_org_and_repo(mock_create_issue, api_rf ("bad_event_type", None, "Unknown event type 'BAD_EVENT_TYPE'"), ], ) -@patch("airlock.views._get_github_api", FakeGithubApiWithError) +@patch("airlock.views._get_github_api", FakeGitHubAPIWithErrors) def test_api_airlock_event_error(api_rf, event_type, updates, error): author = UserFactory() user = UserFactory() diff --git a/tests/unit/jobserver/views/test_projects.py b/tests/unit/jobserver/views/test_projects.py index 73df1b763..4e15a19e2 100644 --- a/tests/unit/jobserver/views/test_projects.py +++ b/tests/unit/jobserver/views/test_projects.py @@ -1,5 +1,4 @@ import pytest -import requests from django.contrib.auth.models import AnonymousUser from django.core.exceptions import PermissionDenied from django.http import Http404 @@ -14,6 +13,8 @@ ProjectEventLog, ProjectReportList, ) +from tests.fakes import FakeGitHubAPI, FakeGitHubAPIWithErrors +from tests.utils import minutes_ago from ....factories import ( JobFactory, @@ -27,8 +28,6 @@ UserFactory, WorkspaceFactory, ) -from ....fakes import FakeGitHubAPI -from ....utils import minutes_ago @pytest.mark.parametrize("user", [UserFactory, AnonymousUser]) @@ -173,7 +172,7 @@ def test_projectdetail_with_multiple_releases(rf, freezer): assert snapshot4 not in snapshots -def test_projectdetail_with_no_github(rf): +def test_projectdetail_with_github_error(rf): project = ProjectFactory(org=OrgFactory()) WorkspaceFactory( project=project, repo=RepoFactory(url="https://github.com/owner/repo") @@ -183,11 +182,7 @@ def test_projectdetail_with_no_github(rf): request = rf.get("/") request.user = UserFactory() - class BrokenGitHubAPI: - def get_repo_is_private(self, *args): - raise requests.HTTPError - - response = ProjectDetail.as_view(get_github_api=BrokenGitHubAPI)( + response = ProjectDetail.as_view(get_github_api=FakeGitHubAPIWithErrors)( request, project_slug=project.slug ) diff --git a/tests/unit/jobserver/views/test_repos.py b/tests/unit/jobserver/views/test_repos.py index b78d12362..4ff0dda1c 100644 --- a/tests/unit/jobserver/views/test_repos.py +++ b/tests/unit/jobserver/views/test_repos.py @@ -1,13 +1,13 @@ from urllib.parse import quote import pytest -import requests from django.contrib.auth.models import AnonymousUser from django.contrib.messages.storage.fallback import FallbackStorage from django.http import Http404 from django.utils import timezone from jobserver.views.repos import RepoHandler, SignOffRepo +from tests.fakes import FakeGitHubAPI, FakeGitHubAPIWithErrors from ....factories import ( OrgFactory, @@ -16,7 +16,6 @@ UserFactory, WorkspaceFactory, ) -from ....fakes import FakeGitHubAPI def test_repohandler_with_broken_repo_url(rf): @@ -160,31 +159,20 @@ def test_signoffrepo_get_success_with_broken_github(rf, project_membership): project_membership(project=project, user=user) - workspaces = WorkspaceFactory.create_batch(5, project=project, repo=repo) + WorkspaceFactory.create_batch(5, project=project, repo=repo) WorkspaceFactory.create_batch(5, project=project) request = rf.get("/") request.user = user - class BrokenGitHubAPI: - def get_branch(self, owner, repo, branch): - return {} - - def get_branches(self, owner, repo): - return [] - - def get_repo_is_private(self, owner, repo): - raise requests.HTTPError() - - response = SignOffRepo.as_view(get_github_api=BrokenGitHubAPI)( + response = SignOffRepo.as_view(get_github_api=FakeGitHubAPIWithErrors)( request, repo_url=repo.quoted_url ) assert response.status_code == 200 - expected = {w["name"] for w in response.context_data["workspaces"]} - assert {w.name for w in workspaces} == expected - + assert response.context_data["workspaces"] == [] + assert response.context_data["branches"] == [] assert response.context_data["repo"]["is_private"] is None assert response.context_data["repo"]["name"] == "name" assert response.context_data["repo"]["status"] == "public" diff --git a/tests/unit/jobserver/views/test_workspaces.py b/tests/unit/jobserver/views/test_workspaces.py index dffb1c3a4..9a02ed41b 100644 --- a/tests/unit/jobserver/views/test_workspaces.py +++ b/tests/unit/jobserver/views/test_workspaces.py @@ -2,7 +2,6 @@ from datetime import timedelta import pytest -import requests from django.contrib.auth.models import AnonymousUser from django.contrib.messages.storage.fallback import FallbackStorage from django.core.exceptions import PermissionDenied @@ -24,6 +23,8 @@ WorkspaceNotificationsToggle, WorkspaceOutputList, ) +from tests.fakes import FakeGitHubAPI, FakeGitHubAPIWithErrors +from tests.utils import minutes_ago from ....factories import ( AnalysisRequestFactory, @@ -41,8 +42,6 @@ UserFactory, WorkspaceFactory, ) -from ....fakes import FakeGitHubAPI -from ....utils import minutes_ago # this is what defines "private" @@ -180,11 +179,7 @@ def test_workspacecreate_without_github(rf, project_membership, user, role_facto messages = FallbackStorage(request) request._messages = messages - class BrokenGitHubAPI: - def get_repos_with_branches(self, *args): - raise requests.HTTPError - - response = WorkspaceCreate.as_view(get_github_api=BrokenGitHubAPI)( + response = WorkspaceCreate.as_view(get_github_api=FakeGitHubAPIWithErrors)( request, project_slug=project.slug ) @@ -612,11 +607,7 @@ def test_workspacedetail_with_no_github(rf): request = rf.get("/") request.user = UserFactory() - class BrokenGitHubAPI: - def get_repo_is_private(self, *args): - raise requests.HTTPError - - response = WorkspaceDetail.as_view(get_github_api=BrokenGitHubAPI)( + response = WorkspaceDetail.as_view(get_github_api=FakeGitHubAPIWithErrors)( request, project_slug=workspace.project.slug, workspace_slug=workspace.name, diff --git a/tests/unit/staff/views/dashboards/test_copiloting.py b/tests/unit/staff/views/dashboards/test_copiloting.py index c3b93cdc4..b1e680e0a 100644 --- a/tests/unit/staff/views/dashboards/test_copiloting.py +++ b/tests/unit/staff/views/dashboards/test_copiloting.py @@ -1,7 +1,6 @@ from datetime import UTC, datetime import pytest -import requests from django.core.exceptions import PermissionDenied from jobserver.models import Project @@ -10,6 +9,7 @@ MissingGitHubReposError, build_repos_by_project, ) +from tests.fakes import FakeGitHubAPI, FakeGitHubAPIWithErrors from .....factories import ( JobFactory, @@ -21,7 +21,6 @@ UserFactory, WorkspaceFactory, ) -from .....fakes import FakeGitHubAPI def test_build_repos_by_project_missing_github_repos(): @@ -42,12 +41,9 @@ def test_build_repos_by_project_with_broken_github_api(): projects = Project.objects.all() - class BrokenGitHubAPI: - def get_repos_with_status_and_url(self, orgs): - # simulate the GitHub API being down - raise requests.HTTPError() - - assert build_repos_by_project(projects, get_github_api=BrokenGitHubAPI) == {} + assert ( + build_repos_by_project(projects, get_github_api=FakeGitHubAPIWithErrors) == {} + ) def test_copiloting_success(rf, staff_area_administrator): @@ -101,12 +97,7 @@ def test_copiloting_with_broken_github_api(rf, staff_area_administrator): request = rf.get("/") request.user = staff_area_administrator - class BrokenGitHubAPI: - def get_repos_with_status_and_url(self, orgs): - # simulate the GitHub API being down - raise requests.HTTPError() - - response = Copiloting.as_view(get_github_api=BrokenGitHubAPI)(request) + response = Copiloting.as_view(get_github_api=FakeGitHubAPIWithErrors)(request) assert response.status_code == 200 diff --git a/tests/verification/test_github.py b/tests/verification/test_github.py index 5f11331ce..359c8f2f9 100644 --- a/tests/verification/test_github.py +++ b/tests/verification/test_github.py @@ -1,18 +1,42 @@ import pytest +import requests.exceptions import stamina from environs import Env -from requests.exceptions import HTTPError - -from jobserver.github import GitHubAPI, RepoAlreadyExists, RepoNotYetCreated +from requests.models import Response + +from jobserver.github import ( + ConnectionException, + GitHubAPI, + HTTPError, + RepoAlreadyExists, + RepoNotYetCreated, + Timeout, +) from jobserver.models.common import new_ulid_str -from ..fakes import FakeGitHubAPI -from .utils import compare +from ..fakes import FakeGitHubAPI, FakeGitHubAPIWithErrors +from .utils import assert_deep_type_equality, assert_public_method_signature_equality pytestmark = [pytest.mark.verification, pytest.mark.disable_db] +def test_fake_public_method_signatures(): + assert_public_method_signature_equality( + GitHubAPI, + FakeGitHubAPI, + ignored_methods=["delete_repo"], + ) + + +def test_fake_with_errors_public_method_signatures(): + assert_public_method_signature_equality( + GitHubAPI, + FakeGitHubAPIWithErrors, + ignored_methods=["delete_repo"], + ) + + @pytest.fixture def clear_topics(github_api): args = [ @@ -60,7 +84,7 @@ def test_create_issue(enable_network, github_api): real = github_api.create_issue(*args) fake = FakeGitHubAPI().create_issue(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -72,7 +96,7 @@ def test_get_issue_number(enable_network, github_api): real = github_api.get_issue_number_from_title(*args) fake = FakeGitHubAPI().get_issue_number_from_title(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -110,7 +134,7 @@ def test_create_issue_comment(enable_network, github_api): real = github_api.create_issue_comment(*args) fake = FakeGitHubAPI().create_issue_comment(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -133,7 +157,7 @@ def test_close_issue(enable_network, github_api, comment): real = github_api.close_issue(*args) fake = FakeGitHubAPI().close_issue(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -161,8 +185,7 @@ def test_create_repo(enable_network, github_api): real = github_api.create_repo(*args) fake = FakeGitHubAPI().create_repo(*args) - # does the fake work as expected? - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -183,7 +206,7 @@ def test_get_branch(enable_network, github_api): real = github_api.get_branch(*args) fake = FakeGitHubAPI().get_branch(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -198,7 +221,7 @@ def test_get_branches(enable_network, github_api): real = github_api.get_branches(*args) fake = FakeGitHubAPI().get_branches(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -209,7 +232,7 @@ def test_get_branches_with_unknown_org(enable_network, github_api): real = github_api.get_branches(*args) fake = FakeGitHubAPI().get_branches(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real == [] @@ -220,7 +243,7 @@ def test_get_branches_with_unknown_repo(enable_network, github_api): real = github_api.get_branches(*args) fake = FakeGitHubAPI().get_branches(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real == [] @@ -234,7 +257,7 @@ def test_get_branch_sha(enable_network, github_api): assert real == "71650d527c9288f90aa01d089f5a9884b683f7ed" # get_branch_sha is a wrapper for get_branch to extract just the SHA as a - # string so no need to throw compare() at it + # string so no need to throw assert_deep_type_equality() at it assert isinstance(fake, str) @@ -265,7 +288,7 @@ def test_get_default_file(enable_network, github_api): real = github_api.get_file(*args) fake = FakeGitHubAPI().get_file(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) def test_get_file(enable_network, github_api): @@ -274,7 +297,7 @@ def test_get_file(enable_network, github_api): real = github_api.get_file(*args) fake = FakeGitHubAPI().get_file(*args, filepath="README.md") - compare(fake, real) + assert_deep_type_equality(fake, real) def test_get_file_missing_project_yml(enable_network, github_api): @@ -293,7 +316,7 @@ def test_get_repo(enable_network, github_api): real = github_api.get_repo(*args) fake = FakeGitHubAPI().get_repo(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) def test_get_repo_is_private(enable_network, github_api): @@ -317,7 +340,7 @@ def test_get_repos_with_branches(enable_network, github_api): real = list(github_api.get_repos_with_branches(*args)) fake = FakeGitHubAPI().get_repos_with_branches(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) def test_get_repos_with_dates(enable_network, github_api): @@ -326,7 +349,7 @@ def test_get_repos_with_dates(enable_network, github_api): real = list(github_api.get_repos_with_dates(*args)) fake = FakeGitHubAPI().get_repos_with_dates(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) def test_get_repos_with_status_and_url(enable_network, github_api): @@ -335,7 +358,7 @@ def test_get_repos_with_status_and_url(enable_network, github_api): real = list(github_api.get_repos_with_status_and_url(args)) fake = FakeGitHubAPI().get_repos_with_status_and_url(args) - compare(fake, real) + assert_deep_type_equality(fake, real) def test_graphql_error_handling(enable_network, github_api): @@ -367,8 +390,7 @@ def test_set_repo_topics(enable_network, github_api, clear_topics): real = github_api.set_repo_topics(*args) fake = FakeGitHubAPI().set_repo_topics(*args) - # does the fake work as expected? - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -388,3 +410,36 @@ def test_unauthenticated_request(enable_network): assert True else: raise + + +def test_timeout(enable_network, github_api, monkeypatch): + def mock_request(*args, **kwargs): + raise requests.exceptions.Timeout() + + monkeypatch.setattr(github_api.session, "request", mock_request) + + with pytest.raises(Timeout): + github_api.get_repo("opensafely-testing", "github-api-testing") + + +def test_connection_error(enable_network, github_api, monkeypatch): + def mock_request(*args, **kwargs): + raise requests.exceptions.ConnectionError() + + monkeypatch.setattr(github_api.session, "request", mock_request) + + with pytest.raises(ConnectionException): + github_api.get_repo("opensafely-testing", "github-api-testing") + + +def test_http_error(enable_network, github_api, monkeypatch): + def mock_request(*args, **kwargs): + mock_response = Response() + mock_response.status_code = 403 + mock_response._content = b"Not allowed" + return mock_response + + monkeypatch.setattr(github_api.session, "request", mock_request) + + with pytest.raises(HTTPError): + github_api.get_repo("opensafely-testing", "github-api-testing") diff --git a/tests/verification/test_opencodelists.py b/tests/verification/test_opencodelists.py index ae551db26..5694a0e24 100644 --- a/tests/verification/test_opencodelists.py +++ b/tests/verification/test_opencodelists.py @@ -3,7 +3,7 @@ from jobserver.opencodelists import OpenCodelistsAPI from ..fakes import FakeOpenCodelistsAPI -from .utils import compare +from .utils import assert_deep_type_equality, assert_public_method_signature_equality pytestmark = [pytest.mark.verification, pytest.mark.disable_db] @@ -15,13 +15,20 @@ def opencodelists_api(): return OpenCodelistsAPI() +def test_fake_public_method_signatures(): + assert_public_method_signature_equality( + OpenCodelistsAPI, + FakeOpenCodelistsAPI, + ) + + def test_get_codelists(enable_network, opencodelists_api): args = ["snomedct"] real = opencodelists_api.get_codelists(*args) fake = FakeOpenCodelistsAPI().get_codelists(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real is not None @@ -32,7 +39,7 @@ def test_get_codelists_with_unknown_coding_system(enable_network, opencodelists_ real = opencodelists_api.get_codelists(*args) fake = FakeOpenCodelistsAPI().get_codelists(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) assert real == [] @@ -43,4 +50,4 @@ def test_check_codelists(enable_network, opencodelists_api): real = opencodelists_api.check_codelists(*args) fake = FakeOpenCodelistsAPI().check_codelists(*args) - compare(fake, real) + assert_deep_type_equality(fake, real) diff --git a/tests/verification/test_utils.py b/tests/verification/test_utils.py new file mode 100644 index 000000000..9408bde29 --- /dev/null +++ b/tests/verification/test_utils.py @@ -0,0 +1,102 @@ +"""Tests of this test package's utils module.""" + +import pytest + +from .utils import assert_public_method_signature_equality + + +##################################################################### +# Classes for tests of assert_public_method_signature_equality + + +class ClassA: + def method_one(self, x, y): + pass + + def method_two(self, z): + pass + + +# The same as A. +class ClassB: + def method_one(self, x, y): + pass + + def method_two(self, z): + pass + + +# The same as A, but a method has a different signature. +class ClassC: + # Different signature to A.method_one. + def method_one(self, x): + pass + + def method_two(self, z): + pass + + +# The same as A, but with an extra method. +class ClassD: + def method_one(self, x, y): + pass + + def method_two(self, z): + pass + + def method_three(self): + pass + + +# The same as A, but with a private method. +class ClassE: + def method_one(self, x, y): + pass + + def method_two(self, z): + pass + + def _private_method(self): + pass + + +##################################################################### + + +class TestPublicMethodSignatureEquality: + def test_identity(cls): + """Test when applied against the same class.""" + for cls in (ClassA, ClassB, ClassC, ClassD, ClassE): + assert_public_method_signature_equality(cls, cls) + + def test_matching_methods(cls): + """Test when methods match between classes.""" + assert_public_method_signature_equality(ClassA, ClassB) + + def test_signature_mismatch(cls): + """Test when method signatures differ between classes.""" + with pytest.raises(AssertionError, match="signature mismatch"): + assert_public_method_signature_equality(ClassA, ClassC) + + def test_extra_method(cls): + """Test when the second class has an extra method.""" + with pytest.raises(AssertionError, match="methods mismatch"): + assert_public_method_signature_equality(ClassA, ClassD) + + def test_missing_method(cls): + """Test when a method is missing in the second class.""" + with pytest.raises(AssertionError, match="methods mismatch"): + assert_public_method_signature_equality(ClassD, ClassA) + + def test_ignore_methods(cls): + """Test when certain methods are ignored in the comparison.""" + assert_public_method_signature_equality( + ClassA, ClassD, ignored_methods=["method_three"] + ) + + def test_ignore_private_methods(cls): + """Test that extra private methods are ignored.""" + assert_public_method_signature_equality(ClassA, ClassE) + + +##################################################################### diff --git a/tests/verification/utils.py b/tests/verification/utils.py index 0cbf56a6e..d63ee1c11 100644 --- a/tests/verification/utils.py +++ b/tests/verification/utils.py @@ -1,27 +1,54 @@ -def compare(fake, real): +import inspect + + +def assert_deep_type_equality(fake, real): """ - Compare outputs of Fake* instances to those from API instances + Do two objects have the same types, compared deeply? + + Fake API instances return dummy data from the methods they implement. We + want to be able to validate that at least the returned types are the same. - Fake API instances return partial, non-real data from the methods they - implement. For tests we haven't found the need to have those responses be - real data, but we still what their shape and values to be correct in terms - of the API response schemas. This function checks the correctness of those - values. + For example, if `fake` is a list, is `real` a list, and do the + corresponding elements of those lists also have the same types? + + Works for str, int, list, dict, recursively. These are currently the types + we expect the relevant API endpoints to return. """ assert type(fake) is type(real) if isinstance(fake, list): for x, y in zip(fake, real): - compare(x, y) + assert_deep_type_equality(x, y) return if isinstance(fake, str | int): return for key, value in fake.items(): - assert key in real - assert isinstance(value, type(real[key])) + assert key in real and isinstance(value, type(real[key])) if isinstance(value, dict): - compare(fake[key], real[key]) + assert_deep_type_equality(fake[key], real[key]) + + +def _get_public_method_signatures(cls, ignored_methods): + return { + name: inspect.signature(method) + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction) + if not name.startswith("_") and name not in ignored_methods + } + + +def assert_public_method_signature_equality(first, second, ignored_methods=None): + if not ignored_methods: + ignored_methods = [] + first_methods = _get_public_method_signatures(first, ignored_methods) + second_methods = _get_public_method_signatures(second, ignored_methods) + + assert set(first_methods) == set(second_methods), "methods mismatch" + + # Check if every public method of the first class is in the second class + # with the same signature. + for name, sig in first_methods.items(): + assert second_methods[name] == sig, "signature mismatch"