diff --git a/synapseclient/client.py b/synapseclient/client.py index ce25aa5ab..61bcf73c5 100644 --- a/synapseclient/client.py +++ b/synapseclient/client.py @@ -114,6 +114,7 @@ is_integer, is_json, require_param, + validate_submission_id, ) from synapseclient.core.version_check import version_check @@ -4807,12 +4808,15 @@ def _POST_paginated(self, uri: str, body, **kwargs): if next_page_token is None: break - def getSubmission(self, id, **kwargs): + def getSubmission( + self, id: typing.Union[str, int, collections.abc.Mapping], **kwargs + ) -> Submission: """ - Gets a [synapseclient.evaluation.Submission][] object by its id. + Gets a [synapseclient.evaluation.Submission][] object based on a given ID + or previous [synapseclient.evaluation.Submission][] object. Arguments: - id: The id of the submission to retrieve + id: The ID of the submission to retrieve or a [synapseclient.evaluation.Submission][] object Returns: A [synapseclient.evaluation.Submission][] object @@ -4823,7 +4827,7 @@ def getSubmission(self, id, **kwargs): on the *downloadFile*, *downloadLocation*, and *ifcollision* parameters """ - submission_id = id_of(id) + submission_id = validate_submission_id(id) uri = Submission.getURI(submission_id) submission = Submission(**self.restGET(uri)) @@ -4852,18 +4856,20 @@ def getSubmission(self, id, **kwargs): return submission - def getSubmissionStatus(self, submission): + def getSubmissionStatus( + self, submission: typing.Union[str, int, collections.abc.Mapping] + ) -> SubmissionStatus: """ - Downloads the status of a Submission. + Downloads the status of a Submission given its ID or previous [synapseclient.evaluation.Submission][] object. Arguments: - submission: The submission to lookup + submission: The submission to lookup (ID or [synapseclient.evaluation.Submission][] object) Returns: A [synapseclient.evaluation.SubmissionStatus][] object """ - submission_id = id_of(submission) + submission_id = validate_submission_id(submission) uri = SubmissionStatus.getURI(submission_id) val = self.restGET(uri) return SubmissionStatus(**val) diff --git a/synapseclient/core/utils.py b/synapseclient/core/utils.py index 7af693729..d1cefaa0b 100644 --- a/synapseclient/core/utils.py +++ b/synapseclient/core/utils.py @@ -11,6 +11,7 @@ import hashlib import importlib import inspect +import logging import numbers import os import platform @@ -30,6 +31,8 @@ import requests from opentelemetry import trace +from synapseclient.core.logging_setup import DEFAULT_LOGGER_NAME + if TYPE_CHECKING: from synapseclient.models import File, Folder, Project @@ -47,6 +50,10 @@ SLASH_PREFIX_REGEX = re.compile(r"\/[A-Za-z]:") +# Set up logging +LOGGER_NAME = DEFAULT_LOGGER_NAME +LOGGER = logging.getLogger(LOGGER_NAME) + def md5_for_file( filename: str, block_size: int = 2 * MB, callback: typing.Callable = None @@ -242,6 +249,43 @@ def id_of(obj: typing.Union[str, collections.abc.Mapping, numbers.Number]) -> st raise ValueError("Invalid parameters: couldn't find id of " + str(obj)) +def validate_submission_id( + submission_id: typing.Union[str, int, collections.abc.Mapping] +) -> str: + """ + Ensures that a given submission ID is either an integer or a string that + can be converted to an integer. Version notation is not supported for submission + IDs, therefore decimals are not allowed. + + Arguments: + submission_id: The submission ID to validate + + Returns: + The submission ID as a string + + """ + if isinstance(submission_id, int): + return str(submission_id) + elif isinstance(submission_id, str) and submission_id.isdigit(): + return submission_id + elif isinstance(submission_id, collections.abc.Mapping): + syn_id = _get_from_members_items_or_properties(submission_id, "id") + if syn_id is not None: + return validate_submission_id(syn_id) + else: + try: + int_submission_id = int(float(submission_id)) + except ValueError: + raise ValueError( + f"Submission ID '{submission_id}' is not a valid submission ID. Please use digits only." + ) + LOGGER.warning( + f"Submission ID '{submission_id}' contains decimals which are not supported. " + f"Submission ID will be converted to '{int_submission_id}'." + ) + return str(int_submission_id) + + def concrete_type_of(obj: collections.abc.Mapping): """ Return the concrete type of an object representing a Synapse entity. diff --git a/tests/unit/synapseclient/core/unit_test_utils.py b/tests/unit/synapseclient/core/unit_test_utils.py index 5b521bcea..30478f7fe 100644 --- a/tests/unit/synapseclient/core/unit_test_utils.py +++ b/tests/unit/synapseclient/core/unit_test_utils.py @@ -1,6 +1,7 @@ # unit tests for utils.py import base64 +import logging import os import re import tempfile @@ -100,6 +101,49 @@ def __init__(self, id_attr_name: str, id: str) -> None: assert utils.id_of(foo) == "123" +@pytest.mark.parametrize( + "input_value, expected_output, expected_warning", + [ + # Test 1: Valid inputs + ("123", "123", None), + (123, "123", None), + ({"id": "222"}, "222", None), + # Test 2: Invalid inputs that should be corrected + ( + "123.0", + "123", + "Submission ID '123.0' contains decimals which are not supported", + ), + ( + 123.0, + "123", + "Submission ID '123.0' contains decimals which are not supported", + ), + ( + {"id": "999.222"}, + "999", + "Submission ID '999.222' contains decimals which are not supported", + ), + ], +) +def test_validate_submission_id(input_value, expected_output, expected_warning, caplog): + with caplog.at_level(logging.WARNING): + assert utils.validate_submission_id(input_value) == expected_output + if expected_warning: + assert expected_warning in caplog.text + else: + assert not caplog.text + + +def test_validate_submission_id_letters_input() -> None: + letters_input = "syn123" + expected_error = f"Submission ID '{letters_input}' is not a valid submission ID. Please use digits only." + with pytest.raises(ValueError) as err: + utils.validate_submission_id(letters_input) + + assert str(err.value) == expected_error + + # TODO: Add a test for is_synapse_id_str(...) # https://sagebionetworks.jira.com/browse/SYNPY-1425 diff --git a/tests/unit/synapseclient/unit_test_client.py b/tests/unit/synapseclient/unit_test_client.py index 3489d9526..e7810efbe 100644 --- a/tests/unit/synapseclient/unit_test_client.py +++ b/tests/unit/synapseclient/unit_test_client.py @@ -7,6 +7,7 @@ import logging import os import tempfile +import typing import urllib.request as urllib_request import uuid from pathlib import Path @@ -59,6 +60,7 @@ ) from synapseclient.core.models.dict_object import DictObject from synapseclient.core.upload import upload_functions +from synapseclient.evaluation import Submission, SubmissionStatus GET_FILE_HANDLE_FOR_DOWNLOAD = ( "synapseclient.core.download.download_functions.get_file_handle_for_download_async" @@ -2995,6 +2997,156 @@ def test_get_submission_with_annotations(syn: Synapse) -> None: assert evaluation_id == response["evaluationId"] +def run_get_submission_test( + syn: Synapse, + submission_id: typing.Union[str, int], + expected_id: str, + should_warn: bool = False, + caplog=None, +) -> None: + """ + Common code for test_get_submission_valid_id and test_get_submission_invalid_id. + Generates a dummy submission dictionary for regression testing, mocks the API calls, + and validates the expected output for getSubmission. For invalid submission IDs, this + will check that a warning was logged for the user before transforming their input. + + Arguments: + syn: Synapse object + submission_id: Submission ID to test + expected_id: Submission ID that should be returned + should_warn: Whether or not a warning should be logged + caplog: pytest caplog fixture + + Returns: + None + + """ + evaluation_id = (98765,) + submission = { + "evaluationId": evaluation_id, + "entityId": submission_id, + "versionNumber": 1, + "entityBundleJSON": json.dumps({}), + } + + with patch.object(syn, "restGET") as restGET, patch.object( + syn, "_getWithEntityBundle" + ) as get_entity: + restGET.return_value = submission + + if should_warn: + with caplog.at_level(logging.WARNING): + syn.getSubmission(submission_id) + assert f"contains decimals which are not supported" in caplog.text + else: + syn.getSubmission(submission_id) + + restGET.assert_called_once_with(f"/evaluation/submission/{expected_id}") + get_entity.assert_called_once_with( + entityBundle={}, + entity=submission_id, + submission=str(expected_id), + ) + + +@pytest.mark.parametrize( + "submission_id, expected_id", + [("123", "123"), (123, "123"), ({"id": 123}, "123"), ({"id": "123"}, "123")], +) +def test_get_submission_valid_id(syn: Synapse, submission_id, expected_id) -> None: + """Test getSubmission with valid submission ID""" + run_get_submission_test(syn, submission_id, expected_id) + + +@pytest.mark.parametrize( + "submission_id, expected_id", + [ + ("123.0", "123"), + (123.0, "123"), + ({"id": 123.0}, "123"), + ({"id": "123.0"}, "123"), + ], +) +def test_get_submission_invalid_id( + syn: Synapse, submission_id, expected_id, caplog +) -> None: + """Test getSubmission with invalid submission ID""" + run_get_submission_test( + syn, submission_id, expected_id, should_warn=True, caplog=caplog + ) + + +def test_get_submission_and_submission_status_interchangeability( + syn: Synapse, caplog +) -> None: + """Test interchangeability of getSubmission and getSubmissionStatus.""" + + # Establish some dummy variables to work with + evaluation_id = 98765 + submission_id = 9745366.0 + expected_submission_id = "9745366" + + # Establish an expected return for `getSubmissionStatus` + submission_status_return = { + "id": expected_submission_id, + "etag": "000", + "status": "RECEIVED", + } + + # Establish an expected return for `getSubmission` + submission_return = { + "id": expected_submission_id, + "evaluationId": evaluation_id, + "entityId": expected_submission_id, + "versionNumber": 1, + "entityBundleJSON": json.dumps({}), + } + + # Let's mock all the API calls made within these two methods + with patch.object(syn, "restGET") as restGET, patch.object( + Submission, "getURI" + ) as get_submission_uri, patch.object( + SubmissionStatus, "getURI" + ) as get_status_uri, patch.object( + syn, "_getWithEntityBundle" + ): + get_submission_uri.return_value = ( + f"/evaluation/submission/{expected_submission_id}" + ) + get_status_uri.return_value = ( + f"/evaluation/submission/{expected_submission_id}/status" + ) + + # Establish a return for all the calls to restGET we will be making in this test + restGET.side_effect = [ + # Step 1 call to `getSubmission` + submission_return, + # Step 2 call to `getSubmissionStatus` + submission_status_return, + ] + + # Step 1: Call `getSubmission` with float ID + restGET.return_value = submission_return + submission_result = syn.getSubmission(submission_id) + + # Step 2: Call `getSubmissionStatus` with the `Submission` object from above + restGET.reset_mock() + restGET.return_value = submission_status_return + submission_status_result = syn.getSubmissionStatus(submission_result) + + # Validate that getSubmission and getSubmissionStatus are called with correct URIs + # in `getURI` calls + get_submission_uri.assert_called_once_with(expected_submission_id) + get_status_uri.assert_called_once_with(expected_submission_id) + + # Validate final output is as expected + assert ( + submission_result["id"] + == submission_status_result["id"] + == expected_submission_id + ) + + class TestTableSnapshot: def test__create_table_snapshot(self, syn: Synapse) -> None: """Testing creating table snapshots"""