Skip to content

Commit

Permalink
Add/update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jjnesbitt committed Nov 4, 2024
1 parent fc0f1c6 commit 3c7ad71
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 34 deletions.
37 changes: 34 additions & 3 deletions dandiapi/api/tests/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,6 @@ def test_asset_total_size(

assert Asset.total_size() == asset_blob.size + zarr_archive.size

# TODO: add testing for embargoed zarr added, whenever embargoed zarrs
# supported, ATM they are not and tested by test_zarr_rest_create_embargoed_dandiset


@pytest.mark.django_db
def test_asset_full_metadata(draft_asset_factory):
Expand Down Expand Up @@ -1048,6 +1045,40 @@ def test_asset_create_existing_path(api_client, user, draft_version, asset_blob,
assert resp.status_code == 409


# Must use transaction=True as the tested function uses a transaction on_commit hook
@pytest.mark.django_db(transaction=True)
def test_asset_create_on_open_dandiset_embargoed_asset_blob(
api_client, user, draft_version, embargoed_asset_blob, mocker
):
mocked = mocker.patch('dandiapi.api.services.asset.remove_asset_blob_embargoed_tag_task.delay')

assert embargoed_asset_blob.embargoed

assign_perm('owner', user, draft_version.dandiset)
api_client.force_authenticate(user)

path = 'test/create/asset.txt'
metadata = {
'encodingFormat': 'application/x-nwb',
'path': path,
}

resp = api_client.post(
f'/api/dandisets/{draft_version.dandiset.identifier}'
f'/versions/{draft_version.version}/assets/',
{'metadata': metadata, 'blob_id': embargoed_asset_blob.blob_id},
format='json',
)
assert resp.status_code == 200

# Check that asset blob is no longer embargoed
embargoed_asset_blob.refresh_from_db()
assert not embargoed_asset_blob.embargoed

# Check that tag removal function called
mocked.assert_called_once()


@pytest.mark.django_db
def test_asset_rest_rename(api_client, user, draft_version, asset_blob):
assign_perm('owner', user, draft_version.dandiset)
Expand Down
122 changes: 92 additions & 30 deletions dandiapi/api/tests/test_unembargo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from dandiapi.api.models.version import Version
from dandiapi.api.services.embargo import (
AssetBlobEmbargoedError,
_remove_dandiset_asset_blob_embargo_tags,
_delete_zarr_object_tags,
_remove_dandiset_embargo_tags,
remove_asset_blob_embargoed_tag,
unembargo_dandiset,
)
Expand All @@ -19,21 +20,15 @@
DandisetActiveUploadsError,
)
from dandiapi.api.services.exceptions import DandiError
from dandiapi.api.storage import get_boto_client
from dandiapi.api.tasks import unembargo_dandiset_task
from dandiapi.zarr.models import ZarrArchive, ZarrArchiveStatus, zarr_s3_path
from dandiapi.zarr.tasks import ingest_zarr_archive

if TYPE_CHECKING:
from dandiapi.api.models.asset import AssetBlob


@pytest.mark.django_db
def test_remove_asset_blob_embargoed_tag_fails_on_embargod(embargoed_asset_blob, asset_blob):
with pytest.raises(AssetBlobEmbargoedError):
remove_asset_blob_embargoed_tag(embargoed_asset_blob)

# Test that error not raised on non-embargoed asset blob
remove_asset_blob_embargoed_tag(asset_blob)


@pytest.mark.django_db
def test_kickoff_dandiset_unembargo_dandiset_not_embargoed(
api_client, user, dandiset_factory, draft_version_factory
Expand Down Expand Up @@ -125,15 +120,13 @@ def test_unembargo_dandiset_uploads_exist(draft_version_factory, upload_factory,


@pytest.mark.django_db
def test_remove_dandiset_asset_blob_embargo_tags_chunks(
def test_remove_dandiset_embargo_tags_chunks(
draft_version_factory,
asset_factory,
embargoed_asset_blob_factory,
mocker,
):
delete_asset_blob_tags_mock = mocker.patch(
'dandiapi.api.services.embargo._delete_asset_blob_tags'
)
delete_asset_blob_tags_mock = mocker.patch('dandiapi.api.services.embargo._delete_object_tags')
chunk_size = mocker.patch('dandiapi.api.services.embargo.ASSET_BLOB_TAG_REMOVAL_CHUNK_SIZE', 2)

draft_version: Version = draft_version_factory(
Expand All @@ -144,39 +137,99 @@ def test_remove_dandiset_asset_blob_embargo_tags_chunks(
asset = asset_factory(blob=embargoed_asset_blob_factory())
draft_version.assets.add(asset)

_remove_dandiset_asset_blob_embargo_tags(dandiset=ds)
_remove_dandiset_embargo_tags(dandiset=ds)

# Assert that _delete_asset_blob_tags was called chunk_size +1 times, to ensure that it works
# Assert that _delete_object_tags was called chunk_size +1 times, to ensure that it works
# correctly across chunks
assert len(delete_asset_blob_tags_mock.mock_calls) == chunk_size + 1


@pytest.mark.django_db
def test_delete_asset_blob_tags_fails(
def test_remove_dandiset_embargo_tags_fails_remove_tags(
draft_version_factory,
asset_factory,
embargoed_asset_blob_factory,
mocker,
):
mocker.patch('dandiapi.api.services.embargo._delete_asset_blob_tags', side_effect=ValueError)
# Patch function to raise error when called
mocker.patch('dandiapi.api.services.embargo._delete_object_tags', side_effect=ValueError)

# Create dandiset/version and add assets
draft_version: Version = draft_version_factory(
dandiset__embargo_status=Dandiset.EmbargoStatus.UNEMBARGOING
)
ds: Dandiset = draft_version.dandiset
asset = asset_factory(blob=embargoed_asset_blob_factory())
draft_version.assets.add(asset)
for _ in range(2):
asset = asset_factory(blob=embargoed_asset_blob_factory())
draft_version.assets.add(asset)

# Remove tags
with pytest.raises(AssetTagRemovalError):
_remove_dandiset_embargo_tags(dandiset=ds)


@pytest.mark.django_db
def test_remove_asset_blob_embargoed_tag_fails_on_embargod(embargoed_asset_blob, asset_blob):
with pytest.raises(AssetBlobEmbargoedError):
remove_asset_blob_embargoed_tag(embargoed_asset_blob)

# Test that error not raised on non-embargoed asset blob
remove_asset_blob_embargoed_tag(asset_blob)


@pytest.mark.django_db
def test_remove_asset_blob_embargoed_tag(asset_blob, mocker):
mocked_func = mocker.patch('dandiapi.api.services.embargo._delete_object_tags')
remove_asset_blob_embargoed_tag(asset_blob)
mocked_func.assert_called_once()


@pytest.mark.django_db
def test_delete_zarr_object_tags_fails_remove_tags(zarr_archive, zarr_file_factory, mocker):
mocked = mocker.patch(
'dandiapi.api.services.embargo._delete_object_tags', side_effect=ValueError
)
files = [zarr_file_factory(zarr_archive) for _ in range(2)]

# Check that if an exception within `_delete_asset_blob_tags` is raised, it's propagated upwards
# as an AssetTagRemovalError
with pytest.raises(AssetTagRemovalError):
_remove_dandiset_asset_blob_embargo_tags(dandiset=ds)
_delete_zarr_object_tags(client=get_boto_client(), zarr=zarr_archive.zarr_id)

# Check that each file was called 4 times total. Once initially, and 3 retries
assert mocked.call_count == 4 * len(files)
for file in files:
calls = [
c
for c in mocked.mock_calls
if c.kwargs['blob']
== zarr_s3_path(zarr_id=zarr_archive.zarr_id, zarr_path=str(file.path))
]
assert len(calls) == 4


@pytest.mark.django_db
def test_delete_zarr_object_tags(zarr_archive, zarr_file_factory, mocker):
mocked_delete_object_tags = mocker.patch('dandiapi.api.services.embargo._delete_object_tags')

# Create files
files = [zarr_file_factory(zarr_archive) for _ in range(10)]

# This should call the mocked function for each file
_delete_zarr_object_tags(client=get_boto_client(), zarr=zarr_archive.zarr_id)

assert mocked_delete_object_tags.call_count == len(files)

called_blobs = sorted([call.kwargs['blob'] for call in mocked_delete_object_tags.mock_calls])
file_bucket_paths = sorted([zarr_archive.s3_path(str(file.path)) for file in files])
assert called_blobs == file_bucket_paths


@pytest.mark.django_db
def test_unembargo_dandiset(
draft_version_factory,
asset_factory,
embargoed_asset_blob_factory,
embargoed_zarr_archive_factory,
zarr_file_factory,
mocker,
mailoutbox,
user_factory,
Expand All @@ -190,20 +243,29 @@ def test_unembargo_dandiset(
assign_perm('owner', user, ds)

embargoed_blob: AssetBlob = embargoed_asset_blob_factory()
asset = asset_factory(blob=embargoed_blob)
draft_version.assets.add(asset)
assert embargoed_blob.embargoed
draft_version.assets.add(asset_factory(blob=embargoed_blob))

zarr_archive: ZarrArchive = embargoed_zarr_archive_factory(
dandiset=ds, status=ZarrArchiveStatus.UPLOADED
)
for _ in range(5):
zarr_file_factory(zarr_archive)
ingest_zarr_archive(zarr_id=zarr_archive.zarr_id)
zarr_archive.refresh_from_db()
draft_version.assets.add(asset_factory(zarr=zarr_archive, blob=None))

assert all(asset.is_embargoed for asset in draft_version.assets.all())

# Patch this function to check if it's been called, since we can't test the tagging directly
patched = mocker.patch('dandiapi.api.services.embargo._delete_asset_blob_tags')
patched = mocker.patch('dandiapi.api.services.embargo._delete_object_tags')

unembargo_dandiset(ds, owners[0])
patched.assert_called_once()

embargoed_blob.refresh_from_db()
assert patched.call_count == 1 + zarr_archive.file_count
assert not any(asset.is_embargoed for asset in draft_version.assets.all())

ds.refresh_from_db()
draft_version.refresh_from_db()
assert not embargoed_blob.embargoed
assert ds.embargo_status == Dandiset.EmbargoStatus.OPEN
assert (
draft_version.metadata['access'][0]['status']
Expand Down
3 changes: 2 additions & 1 deletion dandiapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
UploadFactory,
UserFactory,
)
from dandiapi.zarr.tests.factories import ZarrArchiveFactory
from dandiapi.zarr.tests.factories import EmbargoedZarrArchiveFactory, ZarrArchiveFactory
from dandiapi.zarr.tests.utils import upload_zarr_file

if TYPE_CHECKING:
Expand All @@ -47,6 +47,7 @@

# zarr app
register(ZarrArchiveFactory)
register(EmbargoedZarrArchiveFactory, _name='embargoed_zarr_archive')


# Register zarr file/directory factories
Expand Down
4 changes: 4 additions & 0 deletions dandiapi/zarr/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class Meta:
zarr_id = factory.Faker('uuid4')
name = factory.Faker('catch_phrase')
dandiset = factory.SubFactory(DandisetFactory)


class EmbargoedZarrArchiveFactory(ZarrArchiveFactory):
embargoed = True
31 changes: 31 additions & 0 deletions dandiapi/zarr/tests/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,37 @@ def test_zarr_rest_get(authenticated_api_client, storage, zarr_archive_factory,
}


@pytest.mark.django_db
def test_zarr_rest_get_embargoed(authenticated_api_client, user, embargoed_zarr_archive):
assert user not in embargoed_zarr_archive.dandiset.owners

resp = authenticated_api_client.get(f'/api/zarr/{embargoed_zarr_archive.zarr_id}/')
assert resp.status_code == 404

embargoed_zarr_archive.dandiset.set_owners([user])
resp = authenticated_api_client.get(f'/api/zarr/{embargoed_zarr_archive.zarr_id}/')
assert resp.status_code == 200


@pytest.mark.django_db
def test_zarr_rest_list_embargoed(authenticated_api_client, user, dandiset, zarr_archive_factory):
# Create some embargoed and some open zarrs
open_zarrs = [zarr_archive_factory() for _ in range(3)]
embargoed_zarrs = [zarr_archive_factory(embargoed=True, dandiset=dandiset) for _ in range(3)]

# Assert only open zarrs are returned
zarrs = authenticated_api_client.get('/api/zarr/').json()['results']
assert sorted(z['zarr_id'] for z in zarrs) == sorted(z.zarr_id for z in open_zarrs)

# Assert that all zarrs returned when user has access to embargoed zarrs
dandiset.set_owners([user])
zarrs = authenticated_api_client.get('/api/zarr/').json()['results']
assert len(zarrs) == len(open_zarrs + embargoed_zarrs)
assert sorted(z['zarr_id'] for z in zarrs) == sorted(
z.zarr_id for z in (open_zarrs + embargoed_zarrs)
)


@pytest.mark.django_db
def test_zarr_rest_list_filter(authenticated_api_client, dandiset_factory, zarr_archive_factory):
# Create dandisets and zarrs
Expand Down

0 comments on commit 3c7ad71

Please sign in to comment.