Refactor dataset configuration code as constructors #659

Merged 5 commits on Nov 27, 2023
107 changes: 24 additions & 83 deletions src/neuroconv/tools/nwb_helpers/_dataset_configuration.py
@@ -4,7 +4,6 @@
import h5py
import numpy as np
import zarr
from hdmf import Container
from hdmf.data_utils import DataChunkIterator, DataIO, GenericDataChunkIterator
from hdmf.utils import get_data_shape
from hdmf_zarr import NWBZarrIO
@@ -49,91 +48,15 @@ def _is_dataset_written_to_file(
)


def _find_location_in_memory_nwbfile(current_location: str, neurodata_object: Container) -> str:
"""
Method for determining the location of a neurodata object within an in-memory NWBFile object.

Distinct from methods from other packages, such as the NWB Inspector, which rely on such files being read from disk.
"""
parent = neurodata_object.parent
if isinstance(parent, NWBFile):
# Items in defined top-level places like acquisition, intervals, etc. do not act as 'containers'
# in that they do not set the `.parent` attribute; ask if object is in their in-memory dictionaries instead
for parent_field_name, parent_field_value in parent.fields.items():
if isinstance(parent_field_value, dict) and neurodata_object.name in parent_field_value:
return parent_field_name + "/" + neurodata_object.name + "/" + current_location
return neurodata_object.name + "/" + current_location
return _find_location_in_memory_nwbfile(
current_location=neurodata_object.name + "/" + current_location, neurodata_object=parent
)


def _infer_dtype_using_data_chunk_iterator(candidate_dataset: Union[h5py.Dataset, zarr.Array]):
"""
The DataChunkIterator has some of the best generic dtype inference available, though the logic is hard to peel out of it.

It can fail in rare cases, but that is not essential to our default configuration.
"""
try:
return DataChunkIterator(candidate_dataset).dtype
except Exception as exception:
if str(exception) != "Data type could not be determined. Please specify dtype in DataChunkIterator init.":
raise exception
else:
return np.dtype("object")


def _get_dataset_metadata(
neurodata_object: Union[TimeSeries, DynamicTable], field_name: str, backend: Literal["hdf5", "zarr"]
) -> Union[HDF5DatasetIOConfiguration, ZarrDatasetIOConfiguration, None]:
"""Fill in the Dataset model with as many values as can be automatically detected or inferred."""
DatasetIOConfigurationClass = BACKEND_TO_DATASET_CONFIGURATION[backend]

candidate_dataset = getattr(neurodata_object, field_name)

# For now, skip over datasets already wrapped in DataIO
# Could maybe eventually support modifying chunks in place
# But setting buffer shape only possible if iterator was wrapped first
if isinstance(candidate_dataset, DataIO):
return None

dtype = _infer_dtype_using_data_chunk_iterator(candidate_dataset=candidate_dataset)
full_shape = get_data_shape(data=candidate_dataset)

if isinstance(candidate_dataset, GenericDataChunkIterator):
chunk_shape = candidate_dataset.chunk_shape
buffer_shape = candidate_dataset.buffer_shape
elif dtype != "unknown":
# TODO: eventually replace this with staticmethods on hdmf.data_utils.GenericDataChunkIterator
chunk_shape = SliceableDataChunkIterator.estimate_default_chunk_shape(
chunk_mb=10.0, maxshape=full_shape, dtype=np.dtype(dtype)
)
buffer_shape = SliceableDataChunkIterator.estimate_default_buffer_shape(
buffer_gb=0.5, chunk_shape=chunk_shape, maxshape=full_shape, dtype=np.dtype(dtype)
)
else:
pass # TODO: think on this; perhaps zarr's standalone estimator?

location = _find_location_in_memory_nwbfile(current_location=field_name, neurodata_object=neurodata_object)
dataset_info = DatasetInfo(
object_id=neurodata_object.object_id,
object_name=neurodata_object.name,
location=location,
full_shape=full_shape,
dtype=dtype,
)
dataset_configuration = DatasetIOConfigurationClass(
dataset_info=dataset_info, chunk_shape=chunk_shape, buffer_shape=buffer_shape
)
return dataset_configuration


def get_default_dataset_io_configurations(
nwbfile: NWBFile,
backend: Union[None, Literal["hdf5", "zarr"]] = None, # None for auto-detect from append mode, otherwise required
) -> Generator[DatasetIOConfiguration, None, None]:
"""
Method for automatically detecting all objects in the file that could be wrapped in a DataIO.
Generate DatasetIOConfiguration objects for wrapping NWB file objects with a specific backend.

This method automatically detects all objects in an NWB file that can be wrapped in a DataIO. It supports auto-detection
of the backend if the NWB file is in append mode; otherwise, it requires a backend specification.

Parameters
----------
@@ -147,6 +70,8 @@ def get_default_dataset_io_configurations(
DatasetIOConfiguration
A summary of each detected object that can be wrapped in a DataIO.
"""
DatasetIOConfigurationClass = BACKEND_TO_DATASET_CONFIGURATION[backend]

if backend is None and nwbfile.read_io is None:
raise ValueError(
"Keyword argument `backend` (either 'hdf5' or 'zarr') must be specified if the `nwbfile` was not "
@@ -185,7 +110,15 @@
):
continue # skip

yield _get_dataset_metadata(neurodata_object=column, field_name="data", backend=backend)
# Skip over columns that are already wrapped in DataIO
if isinstance(candidate_dataset, DataIO):
continue

dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object(
neurodata_object=column, field_name="data"
)

yield dataset_io_configuration
else:
# Primarily for TimeSeries, but also any extended class that has 'data' or 'timestamps'
# The most common example of this is ndx-events Events/LabeledEvents types
@@ -201,8 +134,16 @@
):
continue # skip

# Skip over datasets that are already wrapped in DataIO
if isinstance(candidate_dataset, DataIO):
continue

# Edge case of in-memory ImageSeries with external mode; data is in fields and is empty array
if isinstance(candidate_dataset, np.ndarray) and candidate_dataset.size == 0:
continue # skip

yield _get_dataset_metadata(neurodata_object=time_series, field_name=field_name, backend=backend)
dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object(
neurodata_object=time_series, field_name=field_name
)

yield dataset_io_configuration
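# --- Illustrative usage sketch (not part of this diff) ---
# A minimal example of iterating the generator over an in-memory NWBFile. It
# assumes `get_default_dataset_io_configurations` is re-exported from
# `neuroconv.tools.nwb_helpers`; if not, import it from the
# `_dataset_configuration` module shown above.
from datetime import datetime

import numpy as np
from pynwb import NWBFile, TimeSeries

from neuroconv.tools.nwb_helpers import get_default_dataset_io_configurations

nwbfile = NWBFile(
    session_description="demo session",
    identifier="demo-id",
    session_start_time=datetime.now().astimezone(),
)
nwbfile.add_acquisition(
    TimeSeries(name="raw_signal", data=np.random.rand(10_000, 8), unit="a.u.", rate=30_000.0)
)

# Each yielded object is a backend-specific DatasetIOConfiguration built via
# the new `from_neurodata_object` constructor introduced in this PR.
for dataset_io_configuration in get_default_dataset_io_configurations(nwbfile=nwbfile, backend="hdf5"):
    print(dataset_io_configuration.dataset_info.location, dataset_io_configuration.chunk_shape)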
83 changes: 83 additions & 0 deletions src/neuroconv/tools/nwb_helpers/_models/_base_models.py
@@ -6,8 +6,50 @@
import h5py
import numcodecs
import numpy as np
import zarr
from hdmf import Container
from hdmf.container import DataIO
from hdmf.data_utils import DataChunkIterator, DataIO, GenericDataChunkIterator
from hdmf.utils import get_data_shape
from pydantic import BaseModel, Field, root_validator
from pynwb import NWBHDF5IO, NWBFile

from ...hdmf import SliceableDataChunkIterator


def _find_location_in_memory_nwbfile(current_location: str, neurodata_object: Container) -> str:
"""
Method for determining the location of a neurodata object within an in-memory NWBFile object.

Distinct from methods from other packages, such as the NWB Inspector, which rely on such files being read from disk.
"""
parent = neurodata_object.parent
if isinstance(parent, NWBFile):
# Items in defined top-level places like acquisition, intervals, etc. do not act as 'containers'
# in that they do not set the `.parent` attribute; ask if object is in their in-memory dictionaries instead
for parent_field_name, parent_field_value in parent.fields.items():
if isinstance(parent_field_value, dict) and neurodata_object.name in parent_field_value:
return parent_field_name + "/" + neurodata_object.name + "/" + current_location
return neurodata_object.name + "/" + current_location
return _find_location_in_memory_nwbfile(
current_location=neurodata_object.name + "/" + current_location, neurodata_object=parent
)
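# --- Illustrative sketch (not part of this diff) ---
# How the resulting location string looks for a TimeSeries nested inside a
# processing module; the import path is the private module added in this diff
# and is not public API.
from datetime import datetime

import numpy as np
from pynwb import NWBFile, TimeSeries

from neuroconv.tools.nwb_helpers._models._base_models import _find_location_in_memory_nwbfile

nwbfile = NWBFile(
    session_description="demo session",
    identifier="demo-id",
    session_start_time=datetime.now().astimezone(),
)
ecephys_module = nwbfile.create_processing_module(name="ecephys", description="processed ecephys data")
time_series = TimeSeries(name="filtered_signal", data=np.random.rand(100, 4), unit="a.u.", rate=30_000.0)
ecephys_module.add(time_series)

# The recursion climbs `.parent` until the NWBFile is reached, then prepends the
# top-level dictionary holding the object: "processing/ecephys/filtered_signal/data"
print(_find_location_in_memory_nwbfile(current_location="data", neurodata_object=time_series))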


def _infer_dtype_using_data_chunk_iterator(candidate_dataset: Union[h5py.Dataset, zarr.Array]):
"""
The DataChunkIterator has some of the best generic dtype inference available, though the logic is hard to peel out of it.

It can fail in rare cases, but that is not essential to our default configuration.
"""
try:
data_type = DataChunkIterator(candidate_dataset).dtype
return data_type
except Exception as exception:
if str(exception) != "Data type could not be determined. Please specify dtype in DataChunkIterator init.":
raise exception
else:
return np.dtype("object")
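# --- Illustrative sketch (not part of this diff) ---
# The private helper wraps DataChunkIterator's dtype inference and falls back
# to the generic `object` dtype only for the specific "could not be determined"
# error; the import path is the private module added in this diff.
import numpy as np

from neuroconv.tools.nwb_helpers._models._base_models import _infer_dtype_using_data_chunk_iterator

int_data = np.arange(12, dtype="int16").reshape(3, 4)
print(_infer_dtype_using_data_chunk_iterator(candidate_dataset=int_data))  # int16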


class DatasetInfo(BaseModel):
@@ -61,6 +103,22 @@ def __init__(self, **values):
values.update(dataset_name=dataset_name)
super().__init__(**values)

@classmethod
def from_neurodata_object(cls, neurodata_object: Container, field_name: str) -> "DatasetInfo":
location = _find_location_in_memory_nwbfile(current_location=field_name, neurodata_object=neurodata_object)
candidate_dataset = getattr(neurodata_object, field_name)

full_shape = get_data_shape(data=candidate_dataset)
dtype = _infer_dtype_using_data_chunk_iterator(candidate_dataset=candidate_dataset)

return cls(
object_id=neurodata_object.object_id,
object_name=neurodata_object.name,
location=location,
full_shape=full_shape,
dtype=dtype,
)
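# --- Illustrative sketch (not part of this diff) ---
# Building a DatasetInfo for the `data` field of an in-memory TimeSeries kept
# in acquisition; the import path is the private module added in this diff.
from datetime import datetime

import numpy as np
from pynwb import NWBFile, TimeSeries

from neuroconv.tools.nwb_helpers._models._base_models import DatasetInfo

nwbfile = NWBFile(
    session_description="demo session",
    identifier="demo-id",
    session_start_time=datetime.now().astimezone(),
)
time_series = TimeSeries(name="raw_signal", data=np.random.rand(100, 4), unit="a.u.", rate=30_000.0)
nwbfile.add_acquisition(time_series)

dataset_info = DatasetInfo.from_neurodata_object(neurodata_object=time_series, field_name="data")
print(dataset_info.location)    # "acquisition/raw_signal/data"
print(dataset_info.full_shape)  # (100, 4)
print(dataset_info.dtype)       # float64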


class DatasetIOConfiguration(BaseModel, ABC):
"""A data model for configuring options about an object that will become a HDF5 or Zarr Dataset in the file."""
@@ -182,6 +240,31 @@ def get_data_io_kwargs(self) -> Dict[str, Any]:
"""
raise NotImplementedError

@classmethod
def from_neurodata_object(cls, neurodata_object: Container, field_name: str) -> "DatasetIOConfiguration":
candidate_dataset = getattr(neurodata_object, field_name)

dataset_info = DatasetInfo.from_neurodata_object(neurodata_object=neurodata_object, field_name=field_name)

dtype = dataset_info.dtype
full_shape = dataset_info.full_shape

if isinstance(candidate_dataset, GenericDataChunkIterator):
chunk_shape = candidate_dataset.chunk_shape
buffer_shape = candidate_dataset.buffer_shape
elif dtype != "unknown":
# TODO: eventually replace this with staticmethods on hdmf.data_utils.GenericDataChunkIterator
chunk_shape = SliceableDataChunkIterator.estimate_default_chunk_shape(
chunk_mb=10.0, maxshape=full_shape, dtype=np.dtype(dtype)
)
buffer_shape = SliceableDataChunkIterator.estimate_default_buffer_shape(
buffer_gb=0.5, chunk_shape=chunk_shape, maxshape=full_shape, dtype=np.dtype(dtype)
)
else:
pass

return cls(dataset_info=dataset_info, chunk_shape=chunk_shape, buffer_shape=buffer_shape)
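# --- Illustrative sketch (not part of this diff) ---
# DatasetIOConfiguration is abstract, so the new constructor is exercised
# through a backend-specific subclass; this assumes HDF5DatasetIOConfiguration
# is importable from `neuroconv.tools.nwb_helpers`, as referenced elsewhere in
# this PR.
from datetime import datetime

import numpy as np
from pynwb import NWBFile, TimeSeries

from neuroconv.tools.nwb_helpers import HDF5DatasetIOConfiguration

nwbfile = NWBFile(
    session_description="demo session",
    identifier="demo-id",
    session_start_time=datetime.now().astimezone(),
)
time_series = TimeSeries(name="raw_signal", data=np.random.rand(10_000, 8), unit="a.u.", rate=30_000.0)
nwbfile.add_acquisition(time_series)

dataset_io_configuration = HDF5DatasetIOConfiguration.from_neurodata_object(
    neurodata_object=time_series, field_name="data"
)
print(dataset_io_configuration.chunk_shape)   # default chunks targeting ~10 MB per chunk
print(dataset_io_configuration.buffer_shape)  # default buffer targeting ~0.5 GB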


class BackendConfiguration(BaseModel):
"""A model for matching collections of DatasetConfigurations to a specific backend."""