diff --git a/si_kilosort25/models.py b/si_kilosort25/models.py index 45c730b..589985d 100644 --- a/si_kilosort25/models.py +++ b/si_kilosort25/models.py @@ -46,6 +46,12 @@ class HighpassSpatialFilter(BaseModel): highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter") +class MotionCorrection(BaseModel): + compute: bool = Field(default=True, description="Whether to compute motion correction") + apply: bool = Field(default=False, description="Whether to apply motion correction") + preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction") + + class PreprocessingContext(BaseModel): preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing") highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter") @@ -53,6 +59,7 @@ class PreprocessingContext(BaseModel): detect_bad_channels: DetectBadChannels = Field(default=DetectBadChannels(), description="Detect bad channels") common_reference: CommonReference = Field(default=CommonReference(), description="Common reference") highpass_spatial_filter: HighpassSpatialFilter = Field(default=HighpassSpatialFilter(), description="Highpass spatial filter") + motion_correction: MotionCorrection = Field(default=MotionCorrection(), description="Motion correction") remove_out_channels: bool = Field(default=False, description="Flag to remove out channels") remove_bad_channels: bool = Field(default=False, description="Flag to remove bad channels") max_bad_channel_fraction_to_remove: float = Field(default=1.1, description="Maximum fraction of bad channels to remove") @@ -112,6 +119,7 @@ class PipelineContext(BaseModel): lazy_read_input: bool = Field(default=True, description='Lazy read input file') stub_test: bool = Field(default=False, description='Stub test') recording_context: RecordingContext = Field(description='Recording context') + run_preprocessing: bool = Field(default=True, description='Run preprocessing') preprocessing_context: PreprocessingContext = Field(default=PreprocessingContext(), description='Preprocessing context') sorting_context: SortingContext = Field(default=SortingContext(), description='Sorting context') # postprocessing_context: PostprocessingContext = Field(default=PostprocessingContext(), description='Postprocessing context') diff --git a/si_kilosort25/processor_pipeline.py b/si_kilosort25/processor_pipeline.py index bc18694..fbdcb78 100644 --- a/si_kilosort25/processor_pipeline.py +++ b/si_kilosort25/processor_pipeline.py @@ -44,7 +44,7 @@ def run(context: PipelineContext): recording = recording.frame_slice(start_frame=0, end_frame=n_frames) ############### FOR TESTING -- REMOVE LATER ############ - print(recording) + logger.info(recording) # from spikeinterface.sorters import Kilosort2_5Sorter # Kilosort2_5Sorter.set_kilosort2_5_path(kilosort2_5_path="/mnt/shared_storage/Github/Kilosort") @@ -56,13 +56,16 @@ def run(context: PipelineContext): 'chunk_duration': '1s', 'progress_bar': False } + preprocessing_params = context.preprocessing_context.model_dump() + run_preprocessing = context.run_preprocessing logger.info('Running pipeline') _, sorting, _ = si_pipeline.run_pipeline( recording=recording, scratch_folder="./scratch/", results_folder="./results/", job_kwargs=job_kwargs, - preprocessing_params=context.preprocessing_context.model_dump(), + run_preprocessing=run_preprocessing, + preprocessing_params=preprocessing_params, spikesorting_params=context.sorting_context.model_dump(), # postprocessing_params=context.postprocessing_params, # run_preprocessing=context.run_preprocessing, diff --git a/si_kilosort25/sample_context_1.yaml b/si_kilosort25/sample_context_1.yaml index 99fbc96..c835b38 100644 --- a/si_kilosort25/sample_context_1.yaml +++ b/si_kilosort25/sample_context_1.yaml @@ -1,6 +1,7 @@ -input: https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/3c7/8e6/3c78e6c9-d196-4bea-a7ce-494a315789be +input: https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/eb0/202/eb020241-616b-47ce-8d52-76151fe9e90d output: ./output/sorting.nwb -lazy_read_input: true +lazy_read_input: false stub_test: false recording_context: electrical_series_path: /acquisition/ElectricalSeriesRaw +run_preprocessing: false