Skip to content

Commit

Permalink
[NSETM-2281] Fix calculation of dynamic offset (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFicarelli authored Mar 5, 2024
1 parent d5362f0 commit e1a8bdb
Show file tree
Hide file tree
Showing 51 changed files with 405 additions and 397 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish-sdist.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.11

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run-tox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down
9 changes: 9 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
=========

Version 0.7.0
-------------

New Features
~~~~~~~~~~~~

- Allow to specify ``trial_steps_label`` to calculate the dynamic offset of trial steps [NSETM-2281]


Version 0.6.0
-------------

Expand Down
19 changes: 14 additions & 5 deletions src/blueetl/config/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import deepcopy
from itertools import chain
from pathlib import Path
from typing import NamedTuple, Optional, Union
from typing import NamedTuple, Union

from blueetl.config.analysis_model import (
FeaturesConfig,
Expand All @@ -20,13 +20,23 @@
L = logging.getLogger(__name__)


def _resolve_paths(global_config: MultiAnalysisConfig, base_path: Optional[Path] = None) -> None:
def _resolve_paths(global_config: MultiAnalysisConfig, base_path: Path) -> None:
"""Resolve any relative path."""
base_path = base_path or Path()
global_config.output = base_path / global_config.output
global_config.simulation_campaign = base_path / global_config.simulation_campaign


def _resolve_trial_steps(global_config: MultiAnalysisConfig):
"""Set trial_steps_config.base_path to the same value as global_config.output.
In this way, the custom function can use it as the base path to save any figure.
"""
for config in global_config.analysis.values():
for trial_steps_config in config.extraction.trial_steps.values():
trial_steps_config.base_path = str(global_config.output)


def _resolve_windows(global_config: MultiAnalysisConfig) -> None:
"""Calculate the hash of any referenced windows in each single analysis configuration.
Expand Down Expand Up @@ -140,13 +150,12 @@ def _resolve_analysis_configs(global_config: MultiAnalysisConfig) -> None:
config.features = _resolve_features(config.features)


def init_multi_analysis_configuration(
global_config: dict, base_path: Optional[Path] = None
) -> MultiAnalysisConfig:
def init_multi_analysis_configuration(global_config: dict, base_path: Path) -> MultiAnalysisConfig:
"""Return a config object from a config dict."""
validate_config(global_config, schema=read_schema("analysis_config"))
config = MultiAnalysisConfig(**global_config)
_resolve_paths(config, base_path=base_path)
_resolve_trial_steps(config)
_resolve_windows(config)
_resolve_analysis_configs(config)
return config
13 changes: 10 additions & 3 deletions src/blueetl/config/analysis_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ class WindowConfig(BaseModel):
@model_validator(mode="after")
def validate_values(self):
"""Validate the values after loading them."""
if self.trial_steps_label:
raise ValueError("trial_steps_label cannot be used yet, see NSETM-2281")
if self.trial_steps_list and (self.n_trials or self.trial_steps_value):
raise ValueError("trial_steps_list cannot be set with n_trials or trial_steps_value")
if self.n_trials > 1 and not self.trial_steps_value:
Expand All @@ -84,13 +82,22 @@ class TrialStepsConfig(BaseModel):
**BaseModel.model_config,
"extra": "allow",
}
_forbidden_extra_fields: set[str] = {
"initial_offset",
}
function: str
initial_offset: float = 0.0
bounds: tuple[float, float]
population: Optional[str] = None
node_set: Optional[str] = None
limit: Optional[int] = None

@model_validator(mode="after")
def forbid_fields(self):
"""Verify that the forbidden extra fields have not been specified."""
if found := self._forbidden_extra_fields.intersection(self.model_extra):
raise ValueError(f"Forbidden extra fields: {found}")
return self


class NeuronClassConfig(BaseModel):
"""NeuronClassConfig Model."""
Expand Down
2 changes: 0 additions & 2 deletions src/blueetl/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
DURATION = "duration"
WINDOW_TYPE = "window_type"
COUNT = "count"
TRIAL_STEPS_LABEL = "trial_steps_label"
TRIAL_STEPS_VALUE = "trial_steps_value"
TIMES = "times"
BIN = "bin"
VALUE = "value"
Expand Down
85 changes: 54 additions & 31 deletions src/blueetl/external/bnac/calculate_trial_step.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
"""Trial steps functions adapted from BlueNetworkActivityComparison/bnac/onset.py."""

import logging
from pathlib import Path

import numpy as np
from scipy.ndimage import gaussian_filter

L = logging.getLogger(__name__)


def _get_bounds(params):
lower_bound, upper_bound = params["bounds"]
assert lower_bound <= 0
assert upper_bound >= 0
return lower_bound, upper_bound


def _histogram_from_spikes(spikes, params):
onset_test_window_length = params["post_window"][1] - params["pre_window"][0]
lower_bound, upper_bound = _get_bounds(params)
onset_test_window_length = upper_bound - lower_bound
histogram, _ = np.histogram(
spikes,
range=[params["pre_window"][0], params["post_window"][1]],
range=(lower_bound, upper_bound),
bins=int(onset_test_window_length * params["histo_bins_per_ms"]),
)
return histogram


def _onset_from_histogram(histogram, params):
onset_dict = {}
lower_bound, _ = _get_bounds(params)
smoothed_histogram = gaussian_filter(histogram, sigma=params["smoothing_width"])

onset_pre_window_length = params["pre_window"][1] - params["pre_window"][0]
onset_zeroed_post_start = params["post_window"][0] - params["pre_window"][0]
pre_smoothed_histogram = smoothed_histogram[
: onset_pre_window_length * params["histo_bins_per_ms"]
]
post_smoothed_histogram = smoothed_histogram[
onset_zeroed_post_start * params["histo_bins_per_ms"] :
]

onset_dict["pre_mean"] = np.mean(pre_smoothed_histogram)
onset_dict["pre_std"] = np.std(pre_smoothed_histogram)
onset_dict["post_max"] = np.max(post_smoothed_histogram)
pre_window_length = -lower_bound
pre_window_bins = int(pre_window_length * params["histo_bins_per_ms"])
pre_smoothed_histogram = smoothed_histogram[:pre_window_bins]
post_smoothed_histogram = smoothed_histogram[pre_window_bins:]

onset_dict = {
"pre_mean": np.mean(pre_smoothed_histogram),
"pre_std": np.std(pre_smoothed_histogram),
"post_max": np.max(post_smoothed_histogram),
}
onset_dict["pre_mean_post_max_ratio"] = onset_dict["pre_mean"] / onset_dict["post_max"]

where_above_thresh = np.where(
post_smoothed_histogram
> (onset_dict["pre_mean"] + params["threshold_std_multiple"] * onset_dict["pre_std"])
)[0]
threshold = onset_dict["pre_mean"] + params["threshold_std_multiple"] * onset_dict["pre_std"]
where_above_thresh = np.where(post_smoothed_histogram > threshold)[0]
index_above_std = 0
if len(where_above_thresh) > 0:
index_above_std = where_above_thresh[0]
Expand All @@ -45,13 +54,10 @@ def _onset_from_histogram(histogram, params):
cortical_onset_index = index_above_std

onset_dict["cortical_onset"] = (
float(cortical_onset_index) / float(params["histo_bins_per_ms"])
+ float(params["post_window"][0])
+ params["ms_post_offset"]
float(cortical_onset_index) / float(params["histo_bins_per_ms"]) + params["ms_post_offset"]
)
if params["fig_paths"]:
if params.get("figures_path"):
_plot(smoothed_histogram, params, onset_dict)
onset_dict["trial_steps_value"] = onset_dict["cortical_onset"]
return onset_dict


Expand All @@ -63,27 +69,44 @@ def _plot(smoothed_histogram, params, onset_dict):
import matplotlib.pyplot as plt
import seaborn as sns

lower_bound, upper_bound = _get_bounds(params)
plt.figure()
sns.set()
sns.set_style("ticks")
x_vals = list(np.arange(params["pre_window"][0], params["post_window"][1], 0.2))
x_vals = list(np.arange(lower_bound, upper_bound, 0.2))
plt.plot(x_vals, smoothed_histogram)
plt.scatter(
onset_dict["cortical_onset"],
[onset_dict["pre_mean"] + params["threshold_std_multiple"] * onset_dict["pre_std"]],
)
plt.gca().set_xlim([params["pre_window"][0], params["post_window"][1]])
plt.gca().set_xlim([lower_bound, upper_bound])
plt.gca().set_xlabel("Time (ms)")
plt.gca().set_ylabel("Number of spikes")
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["top"].set_visible(False)
for fig_path in params["fig_paths"]:
plt.savefig(fig_path)
filepath = Path(params["base_path"], params["figures_path"], "plot.pdf")
L.info("Figures path: %s", filepath)
filepath.parent.mkdir(exist_ok=True)
plt.savefig(filepath)
plt.close()


def onset_from_spikes(spikes, params):
"""Calculate trial steps from spikes."""
def onset_from_spikes(spikes_list, params):
"""Calculate the cortical onset from a list of spikes, one for each trial.
Args:
spikes_list: list of spikes as numpy arrays.
params: dictionary of parameters from the trial steps configuration.
Returns:
float representing the dynamic offset to be added to the initial offset of each trial step.
"""
L.info(
"onset_from_spikes: processing %s arrays of spikes using params=%r",
len(spikes_list),
params,
)
spikes = np.concatenate(spikes_list)
histogram = _histogram_from_spikes(spikes, params)
onset_dict = _onset_from_histogram(histogram, params)
return onset_dict
return onset_dict["cortical_onset"]
117 changes: 0 additions & 117 deletions src/blueetl/extract/trial_steps.py

This file was deleted.

Loading

0 comments on commit e1a8bdb

Please sign in to comment.