Skip to content

Commit

Permalink
Add support for zip files
Browse files Browse the repository at this point in the history
  • Loading branch information
apdavison committed Aug 29, 2023
1 parent d5525f2 commit 698db2a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 48 deletions.
42 changes: 29 additions & 13 deletions api/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import hashlib
from urllib.request import urlopen, urlretrieve, HTTPError
from urllib.parse import urlparse, urlunparse
import zipfile
from fastapi import HTTPException, status
import neo.io
import quantities as pq
Expand Down Expand Up @@ -58,20 +59,21 @@ def list_files_to_download(resolved_url, cache_dir, io_cls=None):
root_path, ext = os.path.splitext(main_file)
io_mode = getattr(io_cls, "rawmode", None)
if io_mode == "one-dir":
# In general, we don't know the names of the individual files
# and have no way to get a directory listing from a URL
# so we raise an exception
if io_cls.__name__ in ("PhyIO"):
# for the exceptions, resolved_url must represent a directory
raise NotImplementedError # todo: for these ios, the file names are known
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Cannot download files from a URL representing a directory. "
"Please provide the URL of a zip or tar archive of the directory."
if not resolved_url.endswith(".zip"):
# In general, we don't know the names of the individual files
# and have no way to get a directory listing from a URL
# so we raise an exception
if io_cls.__name__ in ("PhyIO"):
# for the exceptions, resolved_url must represent a directory
raise NotImplementedError # todo: for these ios, the file names are known
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Cannot download files from a URL representing a directory. "
"Please provide the URL of a zip or tar archive of the directory."
)
)
)
elif io_mode == "multi-file":
# Here the resolved_url represents a single file, with or without the file extension.
# By taking the base/root path and adding various extensions we get a list of files to download
Expand Down Expand Up @@ -153,9 +155,23 @@ def download_neo_data(url, io_cls=None):
main_path = files_to_download[0][1]
else:
main_path = os.path.join(cache_dir, main_file)
if main_path.endswith(".zip"):
main_path = get_archive_dir(main_path, cache_dir)
return main_path


def get_archive_dir(archive_path, cache_dir):
    """
    Unpack a zip archive into `cache_dir` and return the path of the directory
    that contains the unpacked files.

    If every entry in the archive lives under a single top-level directory
    (the common "zipped folder" layout), that directory is extracted directly
    into `cache_dir` and its path is returned.  Otherwise (flat or multi-root
    archives) the contents are extracted into a fresh directory named after
    the archive file itself, so loose files cannot collide within the cache.

    :param archive_path: path of the zip file on the local filesystem
    :param cache_dir: directory into which the archive should be unpacked
    :return: path of the directory containing the unpacked files
    :raises ValueError: if the archive contains absolute paths or ".."
        components (protection against zip path traversal)
    """
    with zipfile.ZipFile(archive_path) as zf:
        # ignore bare directory markers like "dir/" when collecting names
        names = [info.filename for info in zf.infolist() if info.filename.strip("/")]
        # refuse entries that would escape cache_dir when extracted
        for name in names:
            if os.path.isabs(name) or ".." in name.split("/"):
                raise ValueError(f"Unsafe path in zip archive: {name}")
        top_level = {name.split("/", 1)[0] for name in names}
        if len(top_level) == 1 and all("/" in name for name in names):
            # single root directory: unpack alongside it, directly in cache_dir
            main_path = os.path.join(cache_dir, top_level.pop())
            extract_dir = cache_dir
        else:
            # flat or multi-root archive: unpack into a directory named
            # after the archive file (extension removed)
            archive_name = os.path.splitext(os.path.basename(archive_path))[0]
            main_path = os.path.join(cache_dir, archive_name)
            extract_dir = main_path
        if not os.path.exists(main_path):
            # already-unpacked archives are served from the cache unchanged
            zf.extractall(path=extract_dir)
    return main_path


extra_kwargs = {
"NestIO": {
Expand Down
22 changes: 20 additions & 2 deletions api/test/test_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
"""

import os.path
import shutil
import tempfile
from urllib.request import urlretrieve
from neo.io import BrainVisionIO
from ..data_handler import get_base_url_and_path, get_cache_path, list_files_to_download
from ..data_handler import (
get_base_url_and_path,
get_cache_path,
list_files_to_download,
get_archive_dir,
)


def test_get_base_url_and_path():
Expand Down Expand Up @@ -37,12 +45,22 @@ def test_get_cache_path():
)
assert filename == "File_brainvision_1.vhdr"


def test_list_files_to_download():
    """list_files_to_download should expand a BrainVision .vhdr URL into the
    full set of companion files (.vhdr, .eeg, .vmrk) that must be downloaded,
    each paired with its cache path and a required flag."""
    url = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/brainvision/File_brainvision_1.vhdr"
    files_to_download = list_files_to_download(url, "the_cache_dir", BrainVisionIO)
    expected = [
        (url, "the_cache_dir/File_brainvision_1.vhdr", True),
        (url.replace(".vhdr", ".eeg"), "the_cache_dir/File_brainvision_1.eeg", True),
        (url.replace(".vhdr", ".vmrk"), "the_cache_dir/File_brainvision_1.vmrk", True),
    ]
    assert files_to_download == expected


def test_download_neo_data_zip():
    """Downloading a zip archive and unpacking it with get_archive_dir should
    return the path of the single top-level directory the archive contains.

    NOTE(review): this test downloads from the network (EBRAINS data proxy),
    so it will fail offline — presumably acceptable for this suite, which
    already fetches GIN data; confirm.
    """
    cache_dir = tempfile.mkdtemp()
    try:
        file_url = "https://data-proxy.ebrains.eu/api/v1/buckets/myspace/neo-viewer-test-data/ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip"
        archive_path, headers = urlretrieve(
            file_url,
            os.path.join(cache_dir, "ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip"),
        )
        main_path = get_archive_dir(archive_path, cache_dir)
        assert main_path == os.path.join(cache_dir, "original_data")
    finally:
        # clean up even when the download or assertion fails,
        # so temp dirs don't accumulate across test runs
        shutil.rmtree(cache_dir)
79 changes: 46 additions & 33 deletions api/test/test_example_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

test_client = TestClient(app)

base_data_url = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/"
gin_data_url = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/"

test_data = {
test_data_gin = {
200: {
"AsciiSpikeTrainIO": ["asciispiketrain/File_ascii_spiketrain_1.txt"],
"AxographIO": [
Expand Down Expand Up @@ -230,7 +230,8 @@
"neuralynx/Cheetah_v4.0.2/original_data",
"neuralynx/Cheetah_v5.4.0/original_data",
"neuralynx/Cheetah_v5.5.1/original_data",
"neuralynx/Cheetah_v5.6.3/original_data",
# "neuralynx/Cheetah_v5.6.3/original_data",
"https://data-proxy.ebrains.eu/api/v1/buckets/myspace/neo-viewer-test-data/ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip",
"neuralynx/Cheetah_v5.7.4/original_data",
"neuralynx/Cheetah_v6.3.2/incomplete_blocks",
],
Expand Down Expand Up @@ -263,47 +264,59 @@
},
}

test_data_other = {
200: {
"NeuralynxIO": [
"https://data-proxy.ebrains.eu/api/v1/buckets/myspace/neo-viewer-test-data/ephy_testing_data_neuralynx_Cheetah_v5.6.3_original_data.zip",
]
}
}

expected_success = [
(io_cls, test_file)
for io_cls, test_files in test_data[200].items()
(io_cls, f"{gin_data_url}{test_file}")
for io_cls, test_files in test_data_gin[200].items()
for test_file in test_files
] + [
(io_cls, test_file_url)
for io_cls, test_files in test_data_other[200].items()
for test_file_url in test_files
]

expected_400_failure_block = [
(io_cls, test_file)
for io_cls, test_files in test_data[400]["block"].items()
(io_cls, f"{gin_data_url}{test_file}")
for io_cls, test_files in test_data_gin[400]["block"].items()
for test_file in test_files
]

expected_400_failure_segment = [
(io_cls, test_file)
for io_cls, test_files in test_data[400]["segment"].items()
(io_cls, f"{gin_data_url}{test_file}")
for io_cls, test_files in test_data_gin[400]["segment"].items()
for test_file in test_files
]

expected_400_failure_signal = [
(io_cls, test_file)
for io_cls, test_files in test_data[400]["signal"].items()
(io_cls, f"{gin_data_url}{test_file}")
for io_cls, test_files in test_data_gin[400]["signal"].items()
for test_file in test_files
]

expected_415_failure = [
(io_cls, test_file)
for io_cls, test_files in test_data[415].items()
(io_cls, f"{gin_data_url}{test_file}")
for io_cls, test_files in test_data_gin[415].items()
for test_file in test_files
]

expected_500_failure = [
(io_cls, test_file)
for io_cls, test_files in test_data[500].items()
(io_cls, f"{gin_data_url}{test_file}")
for io_cls, test_files in test_data_gin[500].items()
for test_file in test_files
]


@pytest.mark.parametrize("io_cls,test_file", expected_success)
def test_datasets_expected_success(io_cls, test_file):
@pytest.mark.parametrize("io_cls,test_file_url", expected_success)
def test_datasets_expected_success(io_cls, test_file_url):
encode = urllib.parse.urlencode
params = {"url": f"{base_data_url}{test_file}", "type": io_cls}
params = {"url": test_file_url, "type": io_cls}
response = test_client.get(f"/api/blockdata/?{encode(params)}")
assert response.status_code == 200

Expand All @@ -323,10 +336,10 @@ def test_datasets_expected_success(io_cls, test_file):
# todo: test irregularlysampledsignals - do we have any cases in the example data?


@pytest.mark.parametrize("io_cls,test_file", expected_400_failure_block)
def test_datasets_expected_400_failure_blockdata(io_cls, test_file):
@pytest.mark.parametrize("io_cls,test_file_url", expected_400_failure_block)
def test_datasets_expected_400_failure_blockdata(io_cls, test_file_url):
encode = urllib.parse.urlencode
params = {"url": f"{base_data_url}{test_file}", "type": io_cls}
params = {"url": test_file_url, "type": io_cls}
response = test_client.get(f"/api/blockdata/?{encode(params)}")

if response.status_code != 400:
Expand All @@ -336,10 +349,10 @@ def test_datasets_expected_400_failure_blockdata(io_cls, test_file):
pytest.xfail(response.json()["detail"])


@pytest.mark.parametrize("io_cls,test_file", expected_400_failure_segment)
def test_datasets_expected_400_failure_segmentdata(io_cls, test_file):
@pytest.mark.parametrize("io_cls,test_file_url", expected_400_failure_segment)
def test_datasets_expected_400_failure_segmentdata(io_cls, test_file_url):
encode = urllib.parse.urlencode
params = {"url": f"{base_data_url}{test_file}", "type": io_cls}
params = {"url": test_file_url, "type": io_cls}
response = test_client.get(f"/api/blockdata/?{encode(params)}")
assert response.status_code == 200

Expand All @@ -351,10 +364,10 @@ def test_datasets_expected_400_failure_segmentdata(io_cls, test_file):
pytest.xfail(response2.json()["detail"])


@pytest.mark.parametrize("io_cls,test_file", expected_400_failure_signal)
def test_datasets_expected_400_failure_analogsignaldata(io_cls, test_file):
@pytest.mark.parametrize("io_cls,test_file_url", expected_400_failure_signal)
def test_datasets_expected_400_failure_analogsignaldata(io_cls, test_file_url):
encode = urllib.parse.urlencode
params = {"url": f"{base_data_url}{test_file}", "type": io_cls}
params = {"url": test_file_url, "type": io_cls}
response = test_client.get(f"/api/blockdata/?{encode(params)}")
assert response.status_code == 200

Expand All @@ -372,21 +385,21 @@ def test_datasets_expected_400_failure_analogsignaldata(io_cls, test_file):
pytest.xfail(response3.json()["detail"])


@pytest.mark.parametrize("io_cls,test_file", expected_415_failure)
def test_datasets_expected_415_failure(io_cls, test_file):
@pytest.mark.parametrize("io_cls,test_file_url", expected_415_failure)
def test_datasets_expected_415_failure(io_cls, test_file_url):
encode = urllib.parse.urlencode
params = {"url": f"{base_data_url}{test_file}", "type": io_cls}
params = {"url": test_file_url, "type": io_cls}
response = test_client.get(f"/api/blockdata/?{encode(params)}")
if response.status_code != 415:
raise Exception("error")
if response.status_code != 200:
pytest.xfail(response.json()["detail"])


@pytest.mark.parametrize("io_cls,test_file", expected_500_failure)
def test_datasets_expected_500_failure(io_cls, test_file):
@pytest.mark.parametrize("io_cls,test_file_url", expected_500_failure)
def test_datasets_expected_500_failure(io_cls, test_file_url):
encode = urllib.parse.urlencode
params = {"url": f"{base_data_url}{test_file}", "type": io_cls}
params = {"url": test_file_url, "type": io_cls}
response = test_client.get(f"/api/blockdata/?{encode(params)}")
assert response.status_code == 200

Expand Down

0 comments on commit 698db2a

Please sign in to comment.