[Fiber photometry project] Add fiber photometry conversion #13
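
This draft PR adds the initial NWB conversion for the fiber photometry project: a session conversion script exposing a session_to_nwb entry point, a FiberPhotometryNWBConverter that selects the fiber photometry interface from the raw file extension, and Doric interfaces for .doric (HDF5) and .csv exports, with optional DeepLabCut and behavior video streams.

A minimal sketch of the intended entry point (file names taken from the example in the script's __main__ block; stub_test converts only the first 100 samples for a quick check):

    from pathlib import Path

    session_to_nwb(
        raw_fiber_photometry_file_path=Path("J069_ACh_20230809_HJJ_0002.doric"),
        nwbfile_path=Path("J069_ACh_20230809_HJJ_0002.nwb"),
        stub_test=True,
    )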

Draft · wants to merge 9 commits into base: main

1 change: 1 addition & 0 deletions src/constantinople_lab_to_nwb/fiber_photometry/__init__.py
@@ -0,0 +1 @@
from .fiber_photometry_nwbconverter import FiberPhotometryNWBConverter
@@ -0,0 +1,142 @@
import re
from datetime import datetime
from pathlib import Path
from typing import Union, Optional

from dateutil import tz
from neuroconv.utils import load_dict_from_file, dict_deep_update

from constantinople_lab_to_nwb.fiber_photometry import FiberPhotometryNWBConverter


def session_to_nwb(
raw_fiber_photometry_file_path: Union[str, Path],
nwbfile_path: Union[str, Path],
dlc_file_path: Optional[Union[str, Path]] = None,
video_file_path: Optional[Union[str, Path]] = None,
stub_test: bool = False,
overwrite: bool = False,
verbose: bool = False,
):
"""Converts a fiber photometry session to NWB.

Parameters
----------
raw_fiber_photometry_file_path : Union[str, Path]
Path to the raw fiber photometry file.
nwbfile_path : Union[str, Path]
Path to the NWB file.
    dlc_file_path : Union[str, Path], optional
        Path to the DLC file, by default None.
    video_file_path : Union[str, Path], optional
        Path to the behavior video file, by default None.
    stub_test : bool, optional
        Whether to run a stub ("preview") conversion on truncated data, by default False.
    overwrite : bool, optional
        Whether to overwrite the NWB file if it already exists, by default False.
    verbose : bool, optional
        Controls verbosity, by default False.
"""
source_data = dict()
conversion_options = dict()

raw_fiber_photometry_file_path = Path(raw_fiber_photometry_file_path)
raw_fiber_photometry_file_name = raw_fiber_photometry_file_path.stem
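    # File names follow "<subject_id>_<session>", e.g. "J069_ACh_20230809_HJJ_0002" yields
    # subject_id "J069" and session_id "ACh-20230809-HJJ-0002".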
subject_id, session_id = raw_fiber_photometry_file_name.split("_", maxsplit=1)
session_id = session_id.replace("_", "-")

# Add fiber photometry data
file_suffix = raw_fiber_photometry_file_path.suffix
if file_suffix == ".doric":
raw_stream_name = "/DataAcquisition/FPConsole/Signals/Series0001/AnalogIn"
elif file_suffix == ".csv":
raw_stream_name = "Raw"
else:
raise ValueError(
f"File '{raw_fiber_photometry_file_path}' extension should be either .doric or .csv and not '{file_suffix}'."
)

source_data.update(
dict(
FiberPhotometry=dict(
file_path=raw_fiber_photometry_file_path,
stream_name=raw_stream_name,
)
)
)
conversion_options.update(
dict(
FiberPhotometry=dict(
stub_test=stub_test,
fiber_photometry_series_name="fiber_photometry_response_series_green",
)
)
)

if dlc_file_path is not None:
source_data.update(dict(DeepLabCut=dict(file_path=dlc_file_path)))

if video_file_path is not None:
source_data.update(dict(Video=dict(file_paths=[video_file_path])))

converter = FiberPhotometryNWBConverter(source_data=source_data, verbose=verbose)

    # Update metadata with the session ID and the session start time.
metadata = converter.get_metadata()
metadata["NWBFile"].update(session_id=session_id)

    # Parse the 8-digit session date (YYYYMMDD) from the raw file name.
    date_pattern = r"(?P<date>\d{8})"

match = re.search(date_pattern, raw_fiber_photometry_file_name)
if match:
date_str = match.group("date")
date_obj = datetime.strptime(date_str, "%Y%m%d")
session_start_time = date_obj
tzinfo = tz.gettz("America/New_York")
metadata["NWBFile"].update(session_start_time=session_start_time.replace(tzinfo=tzinfo))

    # Update the default metadata with the editable metadata from the corresponding YAML file.
editable_metadata_path = Path(__file__).parent / "metadata" / "fiber_photometry_metadata.yaml"
editable_metadata = load_dict_from_file(editable_metadata_path)
metadata = dict_deep_update(metadata, editable_metadata)

# Run conversion
converter.run_conversion(
nwbfile_path=nwbfile_path,
metadata=metadata,
conversion_options=conversion_options,
overwrite=overwrite,
)


if __name__ == "__main__":
# Parameters for conversion
# Fiber photometry file path
doric_fiber_photometry_file_path = Path(
"/Volumes/T9/Constantinople/Preprocessed_data/J069/Raw/J069_ACh_20230809_HJJ_0002.doric"
)
# DLC file path (optional)
dlc_file_path = Path(
"/Volumes/T9/Constantinople/DeepLabCut/J069/J069-2023-08-09_rig104cam01_0002compDLC_resnet50_GRAB_DA_DMS_RIG104DoricCamera_J029May12shuffle1_500000.h5"
)
# Behavior video file path (optional)
behavior_video_file_path = Path(
"/Volumes/T9/Constantinople/Compressed Videos/J069/J069-2023-08-09_rig104cam01_0002comp.mp4"
)
# NWB file path
nwbfile_path = Path("/Volumes/T9/Constantinople/nwbfiles/J069_ACh_20230809_HJJ_0002.nwb")
    nwbfile_path.parent.mkdir(parents=True, exist_ok=True)

stub_test = False
overwrite = True

session_to_nwb(
raw_fiber_photometry_file_path=doric_fiber_photometry_file_path,
nwbfile_path=nwbfile_path,
dlc_file_path=dlc_file_path,
video_file_path=behavior_video_file_path,
stub_test=stub_test,
overwrite=overwrite,
verbose=True,
)
28 changes: 28 additions & 0 deletions src/constantinople_lab_to_nwb/fiber_photometry/fiber_photometry_nwbconverter.py
@@ -0,0 +1,28 @@
from pathlib import Path

from neuroconv import NWBConverter
from neuroconv.datainterfaces import DeepLabCutInterface, VideoInterface

from constantinople_lab_to_nwb.fiber_photometry.interfaces import (
DoricFiberPhotometryInterface,
DoricCsvFiberPhotometryInterface,
)


class FiberPhotometryNWBConverter(NWBConverter):
"""Primary conversion class for converting the Fiber photometry dataset from the Constantinople Lab."""

data_interface_classes = dict(
DeepLabCut=DeepLabCutInterface,
Video=VideoInterface,
)

    def __init__(self, source_data: dict[str, dict], verbose: bool = True):
        """Validate source_data against source_schema and initialize all data interfaces."""
        # Copy the class-level mapping so that selecting the fiber photometry interface for one
        # instance does not leak into other instances of the converter.
        self.data_interface_classes = dict(self.data_interface_classes)
        fiber_photometry_source_data = source_data["FiberPhotometry"]
        fiber_photometry_file_path = Path(fiber_photometry_source_data["file_path"])
        if fiber_photometry_file_path.suffix == ".doric":
            self.data_interface_classes["FiberPhotometry"] = DoricFiberPhotometryInterface
        elif fiber_photometry_file_path.suffix == ".csv":
            self.data_interface_classes["FiberPhotometry"] = DoricCsvFiberPhotometryInterface
        else:
            raise ValueError(f"File '{fiber_photometry_file_path}' extension should be either .doric or .csv.")
        super().__init__(source_data=source_data, verbose=verbose)
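
The converter resolves the FiberPhotometry interface from the raw file suffix, so callers pass a single FiberPhotometry entry regardless of format. A minimal sketch using the .doric stream name from the conversion script:

    source_data = dict(
        FiberPhotometry=dict(
            file_path="J069_ACh_20230809_HJJ_0002.doric",
            stream_name="/DataAcquisition/FPConsole/Signals/Series0001/AnalogIn",
        ),
    )
    converter = FiberPhotometryNWBConverter(source_data=source_data, verbose=False)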
@@ -0,0 +1,2 @@
ndx-fiber-photometry==0.1.0
neuroconv[deeplabcut,video]
1 change: 1 addition & 0 deletions src/constantinople_lab_to_nwb/fiber_photometry/interfaces/__init__.py
@@ -0,0 +1 @@
from .doric_fiber_photometry_interface import DoricFiberPhotometryInterface, DoricCsvFiberPhotometryInterface
180 changes: 180 additions & 0 deletions src/constantinople_lab_to_nwb/fiber_photometry/interfaces/doric_fiber_photometry_interface.py
@@ -0,0 +1,180 @@
from pathlib import Path
from typing import Union, Literal

import h5py
import numpy as np
import pandas as pd
from ndx_fiber_photometry import FiberPhotometryResponseSeries
from neuroconv import BaseTemporalAlignmentInterface
from neuroconv.tools import get_module
from pynwb import NWBFile

from constantinople_lab_to_nwb.fiber_photometry.utils import add_fiber_photometry_table, add_fiber_photometry_devices


class DoricFiberPhotometryInterface(BaseTemporalAlignmentInterface):
"""Behavior interface for fiber photometry conversion"""

def __init__(
self,
file_path: Union[str, Path],
stream_name: str,
verbose: bool = True,
):
self._timestamps = None
self._time_column_name = "Time"
super().__init__(file_path=file_path, stream_name=stream_name, verbose=verbose)

    def load(self, stream_name: str):
        file_path = Path(self.source_data["file_path"])
        if file_path.suffix != ".doric":
            raise ValueError(f"File '{file_path}' is not a .doric file.")

        # Keep the HDF5 file open so the returned group (and its datasets) remains readable.
        channel_group = h5py.File(file_path, mode="r")[stream_name]
        if self._time_column_name not in channel_group.keys():
            raise ValueError(f"Time not found in '{stream_name}'.")
        return channel_group

def get_original_timestamps(self) -> np.ndarray:
channel_group = self.load(stream_name=self.source_data["stream_name"])
return channel_group[self._time_column_name][:]

def get_timestamps(self, stub_test: bool = False) -> np.ndarray:
timestamps = self._timestamps if self._timestamps is not None else self.get_original_timestamps()
if stub_test:
return timestamps[:100]
return timestamps

def set_aligned_timestamps(self, aligned_timestamps: np.ndarray) -> None:
self._timestamps = np.array(aligned_timestamps)

def _get_traces(self, stream_name: str, stream_indices: list, stub_test: bool = False):
traces_to_add = []
data = self.load(stream_name=stream_name)
channel_names = list(data.keys())
for stream_index in stream_indices:
trace = data[channel_names[stream_index]]
trace = trace[:100] if stub_test else trace[:]
traces_to_add.append(trace)

        # Stack the selected channels into a (num_samples, num_channels) array.
        traces = np.vstack(traces_to_add).T
        return traces

def add_to_nwbfile(
self,
nwbfile: NWBFile,
metadata: dict,
fiber_photometry_series_name: str,
parent_container: Literal["acquisition", "processing/ophys"] = "acquisition",
stub_test: bool = False,
) -> None:

        # Add the devices (fibers, excitation sources, photodetectors, filters, indicators) first,
        # so the fiber photometry table rows below can reference them.
        add_fiber_photometry_devices(nwbfile=nwbfile, metadata=metadata)

fiber_photometry_metadata = metadata["Ophys"]["FiberPhotometry"]
traces_metadata = fiber_photometry_metadata["FiberPhotometryResponseSeries"]
trace_metadata = next(
(trace for trace in traces_metadata if trace["name"] == fiber_photometry_series_name),
None,
)
if trace_metadata is None:
raise ValueError(f"Trace metadata for '{fiber_photometry_series_name}' not found.")

add_fiber_photometry_table(nwbfile=nwbfile, metadata=metadata)
fiber_photometry_table = nwbfile.lab_meta_data["FiberPhotometry"].fiber_photometry_table

row_indices = trace_metadata["fiber_photometry_table_region"]
device_fields = [
"optical_fiber",
"excitation_source",
"photodetector",
"dichroic_mirror",
"indicator",
"excitation_filter",
"emission_filter",
]
        # Each row of the fiber photometry table links one fiber to its devices and location.
        for row_index in row_indices:
row_metadata = fiber_photometry_metadata["FiberPhotometryTable"]["rows"][row_index]
row_data = {field: nwbfile.devices[row_metadata[field]] for field in device_fields if field in row_metadata}
row_data["location"] = row_metadata["location"]
if "coordinates" in row_metadata:
row_data["coordinates"] = row_metadata["coordinates"]
if "commanded_voltage_series" in row_metadata:
row_data["commanded_voltage_series"] = nwbfile.acquisition[row_metadata["commanded_voltage_series"]]
fiber_photometry_table.add_row(**row_data)

stream_name = trace_metadata["stream_name"]
stream_indices = trace_metadata["stream_indices"]

traces = self._get_traces(stream_name=stream_name, stream_indices=stream_indices, stub_test=stub_test)

fiber_photometry_table_region = fiber_photometry_table.create_fiber_photometry_table_region(
description=trace_metadata["fiber_photometry_table_region_description"],
region=trace_metadata["fiber_photometry_table_region"],
)

# Get the timing information
timestamps = self.get_timestamps(stub_test=stub_test)

fiber_photometry_response_series = FiberPhotometryResponseSeries(
name=trace_metadata["name"],
description=trace_metadata["description"],
data=traces,
unit=trace_metadata["unit"],
fiber_photometry_table_region=fiber_photometry_table_region,
timestamps=timestamps,
)

if parent_container == "acquisition":
nwbfile.add_acquisition(fiber_photometry_response_series)
elif parent_container == "processing/ophys":
ophys_module = get_module(
nwbfile,
name="ophys",
description="Contains the processed fiber photometry data.",
)
ophys_module.add(fiber_photometry_response_series)


class DoricCsvFiberPhotometryInterface(DoricFiberPhotometryInterface):
    """Data interface for Doric fiber photometry data exported to .csv files."""

    def __init__(
        self,
        file_path: Union[str, Path],
        stream_name: str,
        verbose: bool = True,
    ):
        super().__init__(file_path=file_path, stream_name=stream_name, verbose=verbose)
        self._time_column_name = "Time(s)"

    def get_original_timestamps(self) -> np.ndarray:
        df = self.load(stream_name=self.source_data["stream_name"])
        return df[self._time_column_name].values

    def load(self, stream_name: str):
        file_path = Path(self.source_data["file_path"])
        if file_path.suffix != ".csv":
            raise ValueError(f"File '{file_path}' is not a .csv file.")

        # Doric CSV exports have a two-row header; only the second row holds the channel names.
        df = pd.read_csv(file_path, header=[0, 1])
        df = df.droplevel(0, axis=1)
        if self._time_column_name not in df.columns:
            raise ValueError(f"Time column not found in '{file_path}'.")
        # Keep only the columns belonging to the requested stream, plus the time column.
        filtered_columns = [col for col in df.columns if stream_name in col]
        filtered_columns.append(self._time_column_name)
        df = df[filtered_columns]
        return df

def _get_traces(self, stream_name: str, stream_indices: list, stub_test: bool = False):
traces_to_add = []
data = self.load(stream_name=stream_name)
channel_names = [col for col in data.columns if col != self._time_column_name]
for stream_index in stream_indices:
trace = data[channel_names[stream_index]]
trace = trace[:100] if stub_test else trace
traces_to_add.append(trace)

traces = np.vstack(traces_to_add).T
return traces
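
Both interfaces inherit from BaseTemporalAlignmentInterface, so timestamps can be realigned before writing. A minimal sketch (the file name and the 1.23 s offset are placeholders):

    interface = DoricCsvFiberPhotometryInterface(file_path="session.csv", stream_name="Raw")
    aligned_timestamps = interface.get_original_timestamps() + 1.23
    interface.set_aligned_timestamps(aligned_timestamps=aligned_timestamps)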