Skip to content

Commit

Permalink
[SYNPY-1513] Validate input submission ID in getSubmission(...) (#1135)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaymedina authored Oct 1, 2024
1 parent 56ebcd0 commit d9f3786
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 8 deletions.
22 changes: 14 additions & 8 deletions synapseclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
is_integer,
is_json,
require_param,
validate_submission_id,
)
from synapseclient.core.version_check import version_check

Expand Down Expand Up @@ -4807,12 +4808,15 @@ def _POST_paginated(self, uri: str, body, **kwargs):
if next_page_token is None:
break

def getSubmission(self, id, **kwargs):
def getSubmission(
self, id: typing.Union[str, int, collections.abc.Mapping], **kwargs
) -> Submission:
"""
Gets a [synapseclient.evaluation.Submission][] object by its id.
Gets a [synapseclient.evaluation.Submission][] object based on a given ID
or previous [synapseclient.evaluation.Submission][] object.
Arguments:
id: The id of the submission to retrieve
id: The ID of the submission to retrieve or a [synapseclient.evaluation.Submission][] object
Returns:
A [synapseclient.evaluation.Submission][] object
Expand All @@ -4823,7 +4827,7 @@ def getSubmission(self, id, **kwargs):
on the *downloadFile*, *downloadLocation*, and *ifcollision* parameters
"""

submission_id = id_of(id)
submission_id = validate_submission_id(id)
uri = Submission.getURI(submission_id)
submission = Submission(**self.restGET(uri))

Expand Down Expand Up @@ -4852,18 +4856,20 @@ def getSubmission(self, id, **kwargs):

return submission

def getSubmissionStatus(self, submission):
def getSubmissionStatus(
self, submission: typing.Union[str, int, collections.abc.Mapping]
) -> SubmissionStatus:
"""
Downloads the status of a Submission.
Downloads the status of a Submission given its ID or previous [synapseclient.evaluation.Submission][] object.
Arguments:
submission: The submission to lookup
submission: The submission to lookup (ID or [synapseclient.evaluation.Submission][] object)
Returns:
A [synapseclient.evaluation.SubmissionStatus][] object
"""

submission_id = id_of(submission)
submission_id = validate_submission_id(submission)
uri = SubmissionStatus.getURI(submission_id)
val = self.restGET(uri)
return SubmissionStatus(**val)
Expand Down
44 changes: 44 additions & 0 deletions synapseclient/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import hashlib
import importlib
import inspect
import logging
import numbers
import os
import platform
Expand All @@ -30,6 +31,8 @@
import requests
from opentelemetry import trace

from synapseclient.core.logging_setup import DEFAULT_LOGGER_NAME

if TYPE_CHECKING:
from synapseclient.models import File, Folder, Project

Expand All @@ -47,6 +50,10 @@

SLASH_PREFIX_REGEX = re.compile(r"\/[A-Za-z]:")

# Module-level logger for these utilities, looked up under the default logger
# name imported from synapseclient.core.logging_setup.
LOGGER_NAME = DEFAULT_LOGGER_NAME
LOGGER = logging.getLogger(LOGGER_NAME)


def md5_for_file(
filename: str, block_size: int = 2 * MB, callback: typing.Callable = None
Expand Down Expand Up @@ -242,6 +249,43 @@ def id_of(obj: typing.Union[str, collections.abc.Mapping, numbers.Number]) -> st
raise ValueError("Invalid parameters: couldn't find id of " + str(obj))


def validate_submission_id(
    submission_id: typing.Union[str, int, collections.abc.Mapping]
) -> str:
    """
    Ensures that a given submission ID is either an integer or a string that
    can be converted to an integer. Version notation is not supported for submission
    IDs, therefore decimals are not allowed: a value with a decimal part is
    truncated to its integer part and a warning is logged.

    Arguments:
        submission_id: The submission ID to validate. May be an integer, a string,
                       or a mapping whose "id" entry holds the submission ID.

    Returns:
        The submission ID as a string

    Raises:
        ValueError: If the value (or the "id" found in a mapping) cannot be
                    interpreted as a number, or if a mapping has no "id".
    """

    def _invalid_id_error() -> ValueError:
        # Built lazily so the (possibly large) input is only formatted on failure.
        return ValueError(
            f"Submission ID '{submission_id}' is not a valid submission ID. Please use digits only."
        )

    # bool is a subclass of int; str(True) would produce "True", which is not an ID.
    if isinstance(submission_id, bool):
        raise _invalid_id_error()

    if isinstance(submission_id, int):
        return str(submission_id)

    # str.isdigit() alone also accepts characters such as superscript digits,
    # so the string is additionally required to be ASCII.
    if (
        isinstance(submission_id, str)
        and submission_id.isascii()
        and submission_id.isdigit()
    ):
        return submission_id

    if isinstance(submission_id, collections.abc.Mapping):
        syn_id = _get_from_members_items_or_properties(submission_id, "id")
        if syn_id is None:
            # Previously this case fell through and silently returned None.
            raise _invalid_id_error()
        return validate_submission_id(syn_id)

    try:
        int_submission_id = int(float(submission_id))
    except (TypeError, ValueError, OverflowError) as err:
        # TypeError: non-numeric objects (e.g. None, list);
        # OverflowError: infinities; ValueError: non-numeric text or NaN.
        raise _invalid_id_error() from err

    LOGGER.warning(
        "Submission ID '%s' contains decimals which are not supported. "
        "Submission ID will be converted to '%s'.",
        submission_id,
        int_submission_id,
    )
    return str(int_submission_id)


def concrete_type_of(obj: collections.abc.Mapping):
"""
Return the concrete type of an object representing a Synapse entity.
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/synapseclient/core/unit_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# unit tests for utils.py

import base64
import logging
import os
import re
import tempfile
Expand Down Expand Up @@ -100,6 +101,49 @@ def __init__(self, id_attr_name: str, id: str) -> None:
assert utils.id_of(foo) == "123"


@pytest.mark.parametrize(
    "input_value, expected_output, expected_warning",
    [
        # Inputs that are already valid: returned as strings, nothing logged
        ("123", "123", None),
        (123, "123", None),
        ({"id": "222"}, "222", None),
        # Inputs carrying a decimal part: truncated, and a warning is logged
        (
            "123.0",
            "123",
            "Submission ID '123.0' contains decimals which are not supported",
        ),
        (
            123.0,
            "123",
            "Submission ID '123.0' contains decimals which are not supported",
        ),
        (
            {"id": "999.222"},
            "999",
            "Submission ID '999.222' contains decimals which are not supported",
        ),
    ],
)
def test_validate_submission_id(input_value, expected_output, expected_warning, caplog):
    """Check the normalized ID and the presence/absence of the decimal warning."""
    with caplog.at_level(logging.WARNING):
        actual_output = utils.validate_submission_id(input_value)

    assert actual_output == expected_output
    if expected_warning is None:
        assert not caplog.text
    else:
        assert expected_warning in caplog.text


def test_validate_submission_id_letters_input() -> None:
    """An ID containing letters must be rejected with the digits-only message."""
    bad_id = "syn123"

    with pytest.raises(ValueError) as raised:
        utils.validate_submission_id(bad_id)

    message = str(raised.value)
    assert message == (
        f"Submission ID '{bad_id}' is not a valid submission ID. Please use digits only."
    )


# TODO: Add a test for is_synapse_id_str(...)
# https://sagebionetworks.jira.com/browse/SYNPY-1425

Expand Down
152 changes: 152 additions & 0 deletions tests/unit/synapseclient/unit_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import tempfile
import typing
import urllib.request as urllib_request
import uuid
from pathlib import Path
Expand Down Expand Up @@ -59,6 +60,7 @@
)
from synapseclient.core.models.dict_object import DictObject
from synapseclient.core.upload import upload_functions
from synapseclient.evaluation import Submission, SubmissionStatus

GET_FILE_HANDLE_FOR_DOWNLOAD = (
"synapseclient.core.download.download_functions.get_file_handle_for_download_async"
Expand Down Expand Up @@ -2995,6 +2997,156 @@ def test_get_submission_with_annotations(syn: Synapse) -> None:
assert evaluation_id == response["evaluationId"]


def run_get_submission_test(
    syn: Synapse,
    submission_id: typing.Union[str, int, float, typing.Mapping],
    expected_id: str,
    should_warn: bool = False,
    caplog=None,
) -> None:
    """
    Common code for test_get_submission_valid_id and test_get_submission_invalid_id.
    Generates a dummy submission dictionary for regression testing, mocks the API calls,
    and validates the expected output for getSubmission. For invalid submission IDs, this
    will check that a warning was logged for the user before transforming their input.

    Arguments:
        syn: Synapse object
        submission_id: Submission ID to test (string, number, or mapping with an "id")
        expected_id: Submission ID that should be used in the REST call
        should_warn: Whether or not a warning should be logged
        caplog: pytest caplog fixture; required when should_warn is True

    Returns:
        None

    Raises:
        ValueError: If should_warn is True but no caplog fixture was supplied
    """
    if should_warn and caplog is None:
        # Fail with a clear message instead of an AttributeError on None below.
        raise ValueError("caplog is required when should_warn is True")

    # Plain integer; the original trailing comma made this a one-element tuple.
    evaluation_id = 98765
    submission = {
        "evaluationId": evaluation_id,
        "entityId": submission_id,
        "versionNumber": 1,
        "entityBundleJSON": json.dumps({}),
    }

    with patch.object(syn, "restGET") as restGET, patch.object(
        syn, "_getWithEntityBundle"
    ) as get_entity:
        restGET.return_value = submission

        if should_warn:
            with caplog.at_level(logging.WARNING):
                syn.getSubmission(submission_id)
            assert "contains decimals which are not supported" in caplog.text
        else:
            syn.getSubmission(submission_id)

        # The REST call must use the normalized ID, not the raw input.
        restGET.assert_called_once_with(f"/evaluation/submission/{expected_id}")
        get_entity.assert_called_once_with(
            entityBundle={},
            entity=submission_id,
            submission=str(expected_id),
        )


@pytest.mark.parametrize(
    "submission_id, expected_id",
    [
        ("123", "123"),
        (123, "123"),
        ({"id": 123}, "123"),
        ({"id": "123"}, "123"),
    ],
)
def test_get_submission_valid_id(syn: Synapse, submission_id, expected_id) -> None:
    """getSubmission accepts well-formed IDs given as str, int, or mapping."""
    run_get_submission_test(
        syn=syn, submission_id=submission_id, expected_id=expected_id
    )


@pytest.mark.parametrize(
    "submission_id, expected_id",
    [("123.0", "123"), (123.0, "123"), ({"id": 123.0}, "123"), ({"id": "123.0"}, "123")],
)
def test_get_submission_invalid_id(
    syn: Synapse, submission_id, expected_id, caplog
) -> None:
    """getSubmission corrects decimal IDs and logs a warning while doing so."""
    run_get_submission_test(
        syn=syn,
        submission_id=submission_id,
        expected_id=expected_id,
        should_warn=True,
        caplog=caplog,
    )


def test_get_submission_and_submission_status_interchangeability(
    syn: Synapse, caplog
) -> None:
    """
    Test interchangeability of getSubmission and getSubmissionStatus.

    A float submission ID is passed to getSubmission, and the resulting
    Submission object is then passed straight to getSubmissionStatus. Both
    calls must resolve to the same normalized (digits-only) submission ID.
    """

    # Establish some dummy variables to work with
    evaluation_id = 98765
    submission_id = 9745366.0
    expected_submission_id = "9745366"

    # Establish an expected return for `getSubmissionStatus`
    submission_status_return = {
        "id": expected_submission_id,
        "etag": "000",
        "status": "RECEIVED",
    }

    # Establish an expected return for `getSubmission`
    submission_return = {
        "id": expected_submission_id,
        "evaluationId": evaluation_id,
        "entityId": expected_submission_id,
        "versionNumber": 1,
        "entityBundleJSON": json.dumps({}),
    }

    # Let's mock all the API calls made within these two methods
    with patch.object(syn, "restGET") as restGET, patch.object(
        Submission, "getURI"
    ) as get_submission_uri, patch.object(
        SubmissionStatus, "getURI"
    ) as get_status_uri, patch.object(
        syn, "_getWithEntityBundle"
    ):
        get_submission_uri.return_value = (
            f"/evaluation/submission/{expected_submission_id}"
        )
        get_status_uri.return_value = (
            f"/evaluation/submission/{expected_submission_id}/status"
        )

        # One queued response per restGET call, consumed in order. side_effect
        # takes precedence over return_value, so it is the only configuration
        # needed for both steps.
        restGET.side_effect = [
            # Step 1 call to `getSubmission`
            submission_return,
            # Step 2 call to `getSubmissionStatus`
            submission_status_return,
        ]

        # Step 1: Call `getSubmission` with float ID; the decimal must be reported
        with caplog.at_level(logging.WARNING):
            submission_result = syn.getSubmission(submission_id)
        assert "contains decimals which are not supported" in caplog.text

        # Step 2: Call `getSubmissionStatus` with the `Submission` object from above
        submission_status_result = syn.getSubmissionStatus(submission_result)

        # Validate that getSubmission and getSubmissionStatus are called with correct URIs
        # in `getURI` calls
        get_submission_uri.assert_called_once_with(expected_submission_id)
        get_status_uri.assert_called_once_with(expected_submission_id)

    # Validate final output is as expected
    assert (
        submission_result["id"]
        == submission_status_result["id"]
        == expected_submission_id
    )


class TestTableSnapshot:
def test__create_table_snapshot(self, syn: Synapse) -> None:
"""Testing creating table snapshots"""
Expand Down

0 comments on commit d9f3786

Please sign in to comment.