Skip to content

Commit

Permalink
Merge pull request #250 from wiwski/feat/stream-zip-files
Browse files Browse the repository at this point in the history
download share directory and stream back a zip on the fly
  • Loading branch information
wiwski authored Dec 6, 2023
2 parents 5377208 + 54400e0 commit 854fe73
Show file tree
Hide file tree
Showing 12 changed files with 495 additions and 19 deletions.
69 changes: 67 additions & 2 deletions api/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from datetime import datetime
import pathlib

from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import JSONResponse

from fastapi.responses import JSONResponse, StreamingResponse
from auth import (
generate_token_for_path,
verify_path_permission,
User,
get_current_user,
verify_is_euphrosyne_backend,
Expand All @@ -18,7 +20,9 @@
RunDataNotFound,
validate_project_document_file_path,
validate_run_data_file_path,
extract_info_from_path,
)
from clients.azure.stream import stream_zip_from_azure_files
from dependencies import get_storage_azure_client

router = APIRouter(prefix="/data", tags=["data"])
Expand Down Expand Up @@ -53,6 +57,45 @@ def list_project_documents(
)


@router.get(
    "/run-data-zip",
    status_code=200,
    dependencies=[Depends(verify_path_permission)],
)
def zip_project_run_data(
    path: pathlib.Path,
    azure_client: DataAzureClient = Depends(get_storage_azure_client),
):
    """Stream a zip file containing all the run data files.

    The ``path`` query parameter must point to a run data directory
    (raw_data, processed_data, ...). Access is granted by a path token
    verified by the ``verify_path_permission`` dependency.

    Returns:
        StreamingResponse: A streaming response containing the zip file,
        built on the fly so nothing is buffered server-side.

    Raises:
        HTTPException: 422 when the path does not match the run data layout,
            404 when the run data directory does not exist.
    """
    try:
        path_info = extract_info_from_path(path)
    except IncorrectDataFilePath as error:
        raise HTTPException(
            status_code=422,
            detail=[{"loc": ["query", "path"], "msg": error.message}],
        ) from error
    try:
        # NOTE(review): if iter_project_run_files returns a lazy generator,
        # RunDataNotFound may only surface on first iteration (inside the
        # StreamingResponse), in which case this 404 never fires — confirm
        # the existence check runs eagerly.
        files = azure_client.iter_project_run_files(
            path_info["project_name"], path_info["run_name"], path_info.get("data_type")
        )
    except RunDataNotFound as error:
        # Chain the cause so tracebacks keep the original storage error.
        raise HTTPException(status_code=404, detail="Run data not found.") from error
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    return StreamingResponse(
        stream_zip_from_azure_files(files),
        media_type="application/zip",
        headers={
            "Content-Disposition": f"attachment; filename={path_info['run_name']}-{timestamp}.zip"
        },
    )


@router.get(
"/{project_name}/runs/{run_name}/{data_type}",
status_code=200,
Expand Down Expand Up @@ -144,6 +187,28 @@ def generate_project_documents_upload_shared_access_signature(
return {"url": url}


@router.get(
    "/{project_name}/token",
    status_code=200,
    dependencies=[Depends(verify_project_membership)],
)
def generate_signed_url_for_path(
    path: pathlib.Path,
    current_user: User = Depends(get_current_user),
):
    """Return an auth token for a given path.

    The token grants access to project data via a plain GET request without
    revealing the JWT access token — similar in spirit to an Azure SAS token.

    Raises:
        HTTPException: 422 when the path does not match the expected run
            data file layout.
    """
    try:
        validate_run_data_file_path(path, current_user)
    except IncorrectDataFilePath as error:
        raise HTTPException(
            status_code=422,
            detail=[{"loc": ["query", "path"], "msg": error.message}],
        ) from error
    return {"token": generate_token_for_path(str(path))}


@router.post(
"/{project_name}/init",
status_code=204,
Expand Down
39 changes: 36 additions & 3 deletions auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Any, Optional

from dotenv import load_dotenv
from fastapi import Depends, HTTPException, status
Expand All @@ -24,6 +24,7 @@

api_key_header_auth = APIKeyHeader(name="X-API-KEY", auto_error=False)
api_key_query_auth = APIKeyQuery(name="api_key")
token_query_auth = APIKeyQuery(name="token")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)


Expand Down Expand Up @@ -54,6 +55,8 @@ async def get_current_user(
return User(id=0, projects=[], is_admin=True)
raise JWT_CREDENTIALS_EXCEPTION
payload = _decode_jwt(jwt_token)
if not payload.get("user_id"):
raise JWT_CREDENTIALS_EXCEPTION
return User(
id=payload.get("user_id"),
projects=payload.get("projects"),
Expand Down Expand Up @@ -96,12 +99,42 @@ def verify_has_azure_permission(api_key: Optional[str] = Depends(api_key_query_a
raise HTTPException(status_code=403, detail="Not allowed")


def verify_path_permission(path: str, token: str | None = Depends(token_query_auth)):
    """FastAPI dependency: check that the query token grants access to *path*.

    Raises the shared JWT credentials exception when the token is missing or
    lacks a "path" claim, and a 403 when the claim does not match *path*.
    """
    if not token:
        raise JWT_CREDENTIALS_EXCEPTION
    claims = _decode_jwt(token)
    granted_path = claims.get("path")
    if not granted_path:
        raise JWT_CREDENTIALS_EXCEPTION
    if granted_path != path:
        raise HTTPException(status_code=403, detail="Token not allowed for this path")


def generate_token_for_path(path: str):
    """
    Generates a JWT token for a specific path.

    Args:
        path (str): The path for which the token is generated.

    Returns:
        str: The generated JWT token.
    """
    claims = {"path": path}
    return _generate_jwt_token(payload=claims)


def _generate_jwt_token(payload: dict[str, Any]):
    """Sign *payload* with the app secret and return the encoded JWT."""
    secret_key = os.environ["JWT_SECRET_KEY"]
    return jwt.encode(payload, secret_key, algorithm=ALGORITHM)


def _decode_jwt(jwt_token: str):
    """Decode and verify a JWT signed with the shared secret key.

    Only signature/format validity is checked here. Claim-specific checks
    belong to the callers: ``get_current_user`` validates ``user_id`` and
    ``verify_path_permission`` validates ``path``. Keeping a ``user_id``
    check here would reject the path-only tokens minted by
    ``generate_token_for_path``, breaking ``verify_path_permission``.

    Raises:
        HTTPException: the shared JWT credentials exception when decoding
            fails.
    """
    try:
        secret_key = os.environ["JWT_SECRET_KEY"]
        payload = jwt.decode(jwt_token, secret_key, algorithms=[ALGORITHM])
    except JWTError as error:
        raise JWT_CREDENTIALS_EXCEPTION from error
    return payload
80 changes: 71 additions & 9 deletions clients/azure/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ProjectFile(BaseModel):
name: str
last_modified: Optional[datetime] = None
size: int
path: Optional[str]
path: str


class AzureFileShareFile(io.BytesIO):
Expand Down Expand Up @@ -177,6 +177,27 @@ def get_run_files(
except ResourceNotFoundError as error:
raise RunDataNotFound from error

def iter_project_run_files(
    self, project_name: str, run_name: str, data_type: RunDataTypeType | None = None
):
    """
    Yield files from a run directory.

    Args:
        project_name (str): The name of the project.
        run_name (str): The name of the run.
        data_type (RunDataTypeType | None, optional): Restrict iteration to
            one data type folder (raw_data, processed_data, ...) inside the
            run. When None, the whole run directory is iterated.

    Returns:
        An iterator over the downloadable files in the run directory
        (whatever ``_iter_directory_files`` yields).
    """
    projects_path_prefix = _get_projects_path()
    dir_path = os.path.join(projects_path_prefix, project_name, "runs", run_name)
    if data_type:
        # Bug fix: the join result was previously discarded
        # (`os.path.join(data_type)`), so data_type was silently ignored
        # and the whole run directory was always streamed.
        dir_path = os.path.join(dir_path, data_type)
    return self._iter_directory_files(dir_path)

def download_run_file(
self,
filepath: str,
Expand Down Expand Up @@ -405,16 +426,46 @@ def _list_files_recursive(
else:
yield ProjectFile(**{**file, "path": path})

def _iter_directory_files(self, dir_path: str):
    """Stream a directory from the Fileshare.

    The directory existence check runs eagerly — before any iteration — so
    callers catch ``RunDataNotFound`` at call time rather than on the first
    ``next()``. (The previous all-in-one generator deferred the check, so a
    try/except around the call could never catch it.)

    Raises:
        RunDataNotFound: when *dir_path* does not exist on the share.
    """
    dir_client = ShareDirectoryClient.from_connection_string(
        conn_str=self._storage_connection_string,
        share_name=self.share_name,
        directory_path=dir_path,
    )
    if not dir_client.exists():
        raise RunDataNotFound()
    return self._download_directory_files(dir_path)

def _download_directory_files(self, dir_path: str):
    """Yield a lazy downloader for every file under *dir_path* (recursive)."""
    for file in self._list_files_recursive(dir_path):
        file_client = ShareFileClient.from_connection_string(
            conn_str=self._storage_connection_string,
            share_name=self.share_name,
            file_path=file.path,
        )
        yield file_client.download_file()


def extract_info_from_path(path: Path):
    """Extract project name, run name and data type from a run data path.

    The path is validated against the run data layout first; the components
    are then read positionally from
    ``{projects_prefix}/<project>/runs/<run>/<data_type>``. Components that
    are absent stay ``None``.
    """
    _validate_run_data_file_path_regex(path)
    prefix = _get_projects_path()
    relative_parts = Path(str(path).replace(prefix + "/", "", 1)).parts
    info: dict[str, str | None] = {
        "project_name": relative_parts[0] if len(relative_parts) > 0 else None,
        "run_name": relative_parts[2] if len(relative_parts) > 2 else None,
        "data_type": relative_parts[3] if len(relative_parts) > 3 else None,
    }
    return info


def validate_run_data_file_path(path: Path, current_user: User):
    """Validate that *path* points to run data the current user may access.

    First checks the path layout
    (``{projects_prefix}/<project>/runs/<run>/(raw_data|processed_data|HDF5)``)
    via the shared regex helper, then checks project access for the user.

    Raises:
        IncorrectDataFilePath: when the path layout is invalid.
    """
    # Delegate to the shared helper instead of duplicating the regex inline,
    # so this check stays in sync with extract_info_from_path's validation.
    _validate_run_data_file_path_regex(path)
    _validate_project_file_path(path, current_user)


Expand Down Expand Up @@ -448,3 +499,14 @@ def _generate_base_dir_path(project_name: str, run_name: str = ""):
if run_name:
base_dir_path = os.path.join(base_dir_path, "runs", run_name)
return base_dir_path


def _validate_run_data_file_path_regex(path: Path):
    """Raise IncorrectDataFilePath unless *path* matches the run data layout.

    Expected layout:
    ``{projects_path_prefix}/<project_name>/runs/<run_name>/(raw_data|processed_data|HDF5)``
    """
    # re.escape the configured prefix so regex metacharacters in the
    # filesystem path (e.g. ".") cannot loosen or break the match.
    if not re.match(
        rf"^{re.escape(_get_projects_path())}\/[\w\- ]+\/runs\/[\w\- ]+\/(raw_data|processed_data|HDF5)",  # noqa: E501
        str(path),
    ):
        # pylint: disable=line-too-long
        raise IncorrectDataFilePath(
            "path must start with {projects_path_prefix}/<project_name>/runs/<run_name>/(processed_data|raw_data|HDF5)/"  # noqa: E501
        )
58 changes: 58 additions & 0 deletions clients/azure/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Streaming zip files from Azure file shares.
This module provides functions for streaming and zipping files from Azure file shares.
"""

from typing import Any, Generator
from stat import S_IFREG

from azure.storage.fileshare._download import StorageStreamDownloader

from stream_zip import stream_zip, ZIP_AUTO


def iter_files_zip_attr(files: Generator[StorageStreamDownloader, Any, None]):
    """
    Iterate over ``StorageStreamDownloader`` objects and yield the per-file
    attribute tuples that ``stream_zip`` consumes.

    Args:
        files: A generator of `StorageStreamDownloader` objects.

    Yields:
        A tuple of:
            - File name
            - Last modified timestamp
            - File mode (S_IFREG | 0o600)
            - File size (ZIP_AUTO)
            - File contents (generator)
    """

    def chunk_iter(downloader: StorageStreamDownloader):
        # Lazy: chunks() only runs when the zip writer consumes this entry.
        yield from downloader.chunks()

    for downloader in files:
        attrs = (
            downloader.name,
            downloader.properties.last_modified,
            S_IFREG | 0o600,
            ZIP_AUTO(downloader.size),
            chunk_iter(downloader),
        )
        yield attrs


def stream_zip_from_azure_files(files: Generator[StorageStreamDownloader, Any, None]):
    """
    Streams a zip file from Azure files.

    Args:
        files: A generator of `StorageStreamDownloader` objects.

    Returns:
        A streaming zip file.
    """
    zip_entries = iter_files_zip_attr(files)
    return stream_zip(zip_entries)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ exclude = '''
)/
'''

[tool.mypy]
mypy_path = "stubs"

[[tool.mypy.overrides]]
module="azure.mgmt.*"
ignore_missing_imports = true
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pytest==7.4.3
requests==2.31.0
ruff==0.1.6
sentry-sdk[fastapi]==1.38.0
stream-zip==0.0.67
types-python-jose==3.3.4.8
types-python-slugify==8.0.0.3
types-requests==2.31.0.10
Expand Down
33 changes: 33 additions & 0 deletions stubs/stream_zip.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import datetime
from typing import Any, Callable, Generator

# Stubs for the per-file method constructors exported by the stream-zip
# package (pinned in requirements.txt); the returned values are passed as
# the 4th element of each file tuple given to stream_zip.
def NO_COMPRESSION_32(uncompressed_size, crc_32): ...
def NO_COMPRESSION_64(uncompressed_size, crc_32): ...
def ZIP_32(offset, default_get_compressobj): ...
def ZIP_64(offset, default_get_compressobj): ...
def ZIP_AUTO(uncompressed_size, level: int = ...): ...
# Each file entry is a (name, modified_at, mode, method, contents) tuple,
# matching the way clients/azure/stream.py builds its entries.
def stream_zip(
    files: Generator[
        tuple[str, datetime.datetime, int, Callable, Generator[bytes, Any, None]],
        Any,
        None,
    ]
    | tuple[str, datetime.datetime, int, Callable, Generator[bytes, Any, None]]
    | tuple[str, datetime.datetime, int, Callable, bytes],
    chunk_size: int = ...,
    get_compressobj=...,
    extended_timestamps: bool = ...,
): ...

# Exception hierarchy raised by stream_zip while building the archive;
# all concrete errors derive from ZipError.
class ZipError(Exception): ...
class ZipValueError(ZipError, ValueError): ...
class ZipIntegrityError(ZipValueError): ...
class CRC32IntegrityError(ZipIntegrityError): ...
class UncompressedSizeIntegrityError(ZipIntegrityError): ...
class ZipOverflowError(ZipValueError, OverflowError): ...
class UncompressedSizeOverflowError(ZipOverflowError): ...
class CompressedSizeOverflowError(ZipOverflowError): ...
class CentralDirectorySizeOverflowError(ZipOverflowError): ...
class OffsetOverflowError(ZipOverflowError): ...
class CentralDirectoryNumberOfEntriesOverflowError(ZipOverflowError): ...
class NameLengthOverflowError(ZipOverflowError): ...
Loading

0 comments on commit 854fe73

Please sign in to comment.