diff --git a/pylossless/dash/topo_viz.py b/pylossless/dash/topo_viz.py index 50b4991..ca0c9cc 100644 --- a/pylossless/dash/topo_viz.py +++ b/pylossless/dash/topo_viz.py @@ -33,6 +33,13 @@ yaxis.update({"scaleanchor": "x", "scaleratio": 1}) +def pick_montage(montage, ch_names): + """Pick a subset of channels from a montage.""" + digs = montage.remove_fiducials().dig + digs = [dig for dig, ch_name in zip(digs, montage.ch_names) if ch_name in ch_names] + return mne.channels.DigMontage(dig=digs, ch_names=ch_names) + + class TopoPlot: # TODO: Fix/finish doc comments for this class. """Representation of a classic EEG topographic map as a plotly figure.""" @@ -47,7 +54,7 @@ def __init__( res=64, width=None, height=None, - cmap="RdBu_r", + cmap=None, show_sensors=True, colorbar=False, ): @@ -162,9 +169,11 @@ def set_data(self, data): self.info = create_info(names, sfreq=256, ch_types="eeg") with warnings.catch_warnings(): warnings.simplefilter("ignore") + # To update self.info with channels positions RawArray( np.zeros((len(names), 1)), self.info, copy=None, verbose=False ).set_montage(self.montage) + assert np.all(np.array(names) == np.array(self.info.ch_names)) self.set_head_pos_contours() # TODO: Finish/fix docstring @@ -278,12 +287,24 @@ def plot_topo(self, **kwargs): ------- A plotly.graph_objects.Figure object. """ + from .utils import _setup_vmin_vmax + if self.__data is None: return + data = np.array(list(self.__data.values())) + norm = min(np.array(data)) >= 0 + vmin, vmax = _setup_vmin_vmax(data, None, None, norm) + if self.cmap is None: + cmap = "Reds" if norm else "RdBu_r" + else: + cmap = self.cmap + heatmap_trace = go.Heatmap( showscale=self.colorbar, - colorscale=self.cmap, + colorscale=cmap, + zmin=vmin, + zmax=vmax, **self.get_heatmap_data(**kwargs) ) @@ -341,9 +362,10 @@ def __init__( See mne.channels.make_standard_montage(), and mne.channels.get_builtin_montages() for more information on making montage objects in MNE. - data : mne.preprocessing.ICA | None - The data to use for the topoplots. Can be an instance of - mne.preprocessing.ICA. + data : list | None + The data to use for the topoplots. Should be a list of + dictionaries, one per topomap. The dictionaries should + have the channel names as keys. figure : plotly.graph_objects.Figure | None Figure to use (if not None) for plotting. color : str @@ -635,13 +657,11 @@ def initialize_layout(self, slider_val=None, show_sensors=True): # The indexing with ch_names is to ensure the order # of the channels are compatible between plot_data and the montage - ch_names = [ - ch_name - for ch_name in self.montage.ch_names - if ch_name in self.data.topo_values.columns + montage = pick_montage(self.montage, self.data.topo_values.columns) + ch_names = montage.ch_names + plot_data = [ + OrderedDict(self.data.topo_values.loc[title, ch_names]) for title in titles ] - plot_data = self.data.topo_values.loc[titles, ch_names] - plot_data = list(plot_data.T.to_dict().values()) if len(plot_data) < self.nb_sel_topo: nb_missing_topo = self.nb_sel_topo - len(plot_data) @@ -650,7 +670,7 @@ def initialize_layout(self, slider_val=None, show_sensors=True): self.figure = GridTopoPlot( rows=self.rows, cols=self.cols, - montage=self.montage, + montage=montage, data=plot_data, color=colors, res=self.res, @@ -807,16 +827,14 @@ def init_vars(self, montage, ica, ic_labels): return None data = TopoData( - [ - dict(zip(montage.ch_names, component)) - for component in ica.get_components().T - ] + [dict(zip(ica.ch_names, component)) for component in ica.get_components().T] ) + data.topo_values.index = ica._ica_names + if ic_labels: self.head_contours_color = { comp: ic_label_cmap[label] for comp, label in ic_labels.items() } - data.topo_values.index = list(ic_labels.keys()) return data def load_recording(self, montage, ica, ic_labels): diff --git a/pylossless/dash/utils.py b/pylossless/dash/utils.py new file mode 100644 index 0000000..7d57e6b --- /dev/null +++ b/pylossless/dash/utils.py @@ -0,0 +1,24 @@ +import numpy as np + + +def _setup_vmin_vmax(data, vmin, vmax, norm=False): + """Handle vmin and vmax parameters for visualizing topomaps. + + This is a simplified copy of mne.viz.utils._setup_vmin_vmax. + https://github.com/mne-tools/mne-python/blob/main/mne/viz/utils.py + + Notes + ----- + For the normal use-case (when `vmin` and `vmax` are None), the parameter + `norm` drives the computation. When norm=False, data is supposed to come + from a mag and the output tuple (vmin, vmax) is symmetric range + (-x, x) where x is the max(abs(data)). When norm=True (a.k.a. data is the + L2 norm of a gradiometer pair) the output tuple corresponds to (0, x). + + in the MNE version vmin and vmax can be callables that drive the operation, + but for the sake of simplicity this was not copied over. + """ + if vmax is None and vmin is None: + vmax = np.abs(data).max() + vmin = 0.0 if norm else -vmax + return vmin, vmax