diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 4f259eb5..618fa283 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -29,12 +29,14 @@ jobs: git config --global user.email "CI@example.com" git config --global user.name "CI Almighty" pip install wheel # needed for scanimage - - name: Test minimal installation - run: pip install . - - name: Test full installation - run: pip install .[full] - - name: Install testing requirements (-e needed for codecov report) - run: pip install -e .[test] + + - name: Install roiextractors with minimal requirements + run: pip install .[test] + - name: Run minimal tests + run: pytest tests/test_internals -n auto --dist loadscope + + - name: Test full installation (-e needed for codecov report) + run: pip install -e .[full] - name: Get ophys_testing_data current head hash id: ophys diff --git a/requirements-minimal.txt b/requirements-minimal.txt index dc5dbf11..59234aae 100644 --- a/requirements-minimal.txt +++ b/requirements-minimal.txt @@ -1,6 +1,5 @@ h5py>=2.10.0 pynwb>=2.0.1 -spikeextractors>=0.9.0 tqdm>=4.48.2 lazy_ops>=0.2.0 dill>=0.3.2 diff --git a/requirements-testing.txt b/requirements-testing.txt index 67abac62..070ae743 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -1,3 +1,4 @@ pytest pytest-cov parameterized==0.8.1 +spikeextractors>=0.9.10 diff --git a/setup.py b/setup.py index 97e4ecb7..9b9bbf8e 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ from pathlib import Path from setuptools import setup, find_packages -from copy import copy from shutil import copy as copy_file @@ -11,9 +10,8 @@ install_requires = f.readlines() with open(root / "requirements-full.txt") as f: full_dependencies = f.readlines() -testing_dependencies = copy(full_dependencies) with open(root / "requirements-testing.txt") as f: - testing_dependencies.extend(f.readlines()) + testing_dependencies = f.readlines() extras_require = dict(full=full_dependencies, test=testing_dependencies) # Create a local copy for the gin test configuration file based on the master file `base_gin_test_config.json` diff --git a/src/roiextractors/example_datasets/toy_example.py b/src/roiextractors/example_datasets/toy_example.py index c534e737..cc135ff5 100644 --- a/src/roiextractors/example_datasets/toy_example.py +++ b/src/roiextractors/example_datasets/toy_example.py @@ -1,5 +1,4 @@ import numpy as np -import spikeextractors as se from ..extractors.numpyextractors import ( NumpyImagingExtractor, @@ -110,6 +109,8 @@ def toy_example( The output segmentation extractor """ + import spikeextractors as se + # generate ROIs num_rois = int(num_rois) roi_pixels, im, means = _generate_rois( diff --git a/src/roiextractors/extraction_tools.py b/src/roiextractors/extraction_tools.py index 283e2c71..54084e95 100644 --- a/src/roiextractors/extraction_tools.py +++ b/src/roiextractors/extraction_tools.py @@ -1,14 +1,13 @@ from functools import wraps from pathlib import Path from typing import Union, Tuple -from dataclasses import dataclass, field +from dataclasses import dataclass import lazy_ops import scipy import numpy as np from numpy.typing import ArrayLike, DTypeLike from tqdm import tqdm -from spikeextractors.extraction_tools import cast_start_end_frame try: import h5py @@ -22,6 +21,13 @@ else: from scipy.io.matlab.mio5_params import mat_struct + HAVE_Scipy = True +except AttributeError: + if hasattr(scipy, "io") and hasattr(scipy.io.matlab, "mat_struct"): + from scipy.io import mat_struct + else: + from scipy.io.matlab.mio5_params import mat_struct + HAVE_Scipy = True except ImportError: HAVE_Scipy = False @@ -306,6 +312,25 @@ def corrected_args(imaging, frame_idxs, channel=0): return corrected_args +def _cast_start_end_frame(start_frame, end_frame): + if isinstance(start_frame, float): + start_frame = int(start_frame) + elif isinstance(start_frame, (int, np.integer, type(None))): + start_frame = start_frame + else: + raise ValueError("start_frame must be an int, float (not infinity), or None") + if isinstance(end_frame, float) and np.isfinite(end_frame): + end_frame = int(end_frame) + elif isinstance(end_frame, (int, np.integer, type(None))): + end_frame = end_frame + # else end_frame is infinity (accepted for get_unit_spike_train) + if start_frame is not None: + start_frame = int(start_frame) + if end_frame is not None and np.isfinite(end_frame): + end_frame = int(end_frame) + return start_frame, end_frame + + def check_get_videos_args(func): @wraps(func) def corrected_args(imaging, start_frame=None, end_frame=None, channel=0): @@ -325,7 +350,7 @@ def corrected_args(imaging, start_frame=None, end_frame=None, channel=0): end_frame = imaging.get_num_frames() assert end_frame - start_frame > 0, "'start_frame' must be less than 'end_frame'!" - start_frame, end_frame = cast_start_end_frame(start_frame, end_frame) + start_frame, end_frame = _cast_start_end_frame(start_frame, end_frame) channel = int(channel) get_videos_correct_arg = func(imaging, start_frame=start_frame, end_frame=end_frame, channel=channel) diff --git a/src/roiextractors/imagingextractor.py b/src/roiextractors/imagingextractor.py index db862a15..c8c29ccb 100644 --- a/src/roiextractors/imagingextractor.py +++ b/src/roiextractors/imagingextractor.py @@ -1,27 +1,20 @@ """Base class definitions for all ImagingExtractors.""" from abc import ABC, abstractmethod from typing import Union, Optional, Tuple -import numpy as np from copy import deepcopy -from spikeextractors.baseextractor import BaseExtractor +import numpy as np -from .extraction_tools import ( - ArrayType, - PathType, - DtypeType, - FloatType, - check_get_videos_args, -) +from .extraction_tools import ArrayType, PathType, DtypeType, FloatType -class ImagingExtractor(ABC, BaseExtractor): +class ImagingExtractor(ABC): """Abstract class that contains all the meta-data and input data from the imaging data.""" - def __init__(self) -> None: - BaseExtractor.__init__(self) - assert self.installed, self.installation_mesg - self._memmapped = False + def __init__(self, *args, **kwargs) -> None: + self._args = args + self._kwargs = kwargs + self._times = None @abstractmethod def get_image_size(self) -> Tuple[int, int]: @@ -119,12 +112,12 @@ def set_times(self, times: ArrayType) -> None: assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!" self._times = np.array(times).astype("float64") - def copy_times(self, extractor: BaseExtractor) -> None: + def copy_times(self, extractor) -> None: """This function copies times from another extractor. Parameters ---------- - extractor: BaseExtractor + extractor The extractor from which the epochs will be copied """ if extractor._times is not None: diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 61d86150..ef4bb758 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -2,13 +2,12 @@ from typing import Union import numpy as np -from spikeextractors.baseextractor import BaseExtractor from .extraction_tools import ArrayType, IntType, FloatType from .extraction_tools import _pixel_mask_extractor -class SegmentationExtractor(ABC, BaseExtractor): +class SegmentationExtractor(ABC): """ An abstract class that contains all the meta-data and output data from the ROI segmentation operation when applied to the pre-processed data. @@ -18,13 +17,9 @@ class SegmentationExtractor(ABC, BaseExtractor): format specific classes that inherit from this. """ - installed = True - installation_mesg = "" - def __init__(self): - assert self.installed, self.installation_mesg - BaseExtractor.__init__(self) self._sampling_frequency = None + self._times = None self._channel_names = ["OpticalChannel"] self._num_planes = 1 self._roi_response_raw = None