Skip to content

Commit

Permalink
refactor: flake8 corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
gtdang committed Oct 11, 2024
1 parent 9717ae7 commit 308ffe8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
7 changes: 4 additions & 3 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,10 @@ def test_spikes_raster_trial_idx(self, base_simulation_spikes):
fig = net.cell_response.plot_spikes_raster(trial_idx=index_arg,
show=False)
# Check that collections contain data
assert (all([collection.get_positions() != [-1]
for collection in fig.axes[0].collections]),
"No data plotted in raster plot")
assert all(
[collection.get_positions() != [-1]
for collection in fig.axes[0].collections]
), "No data plotted in raster plot"

def test_spikes_raster_colors(self, base_simulation_spikes):
net, _ = base_simulation_spikes
Expand Down
7 changes: 3 additions & 4 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,8 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None,


def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
cell_types = ['L2_basket', 'L2_pyramidal',
'L5_basket','L5_pyramidal'],
cell_types=['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal'],
colors=None,
):
"""Plot the aggregate spiking activity according to cell type.
Expand Down Expand Up @@ -552,7 +552,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
f"cell types. {len(colors)} colors provided "
f"for {len(cell_types)} cell types.")


# Extract desired trials
spike_times = np.concatenate(
np.array(cell_response._spike_times, dtype=object)[trial_idx])
Expand All @@ -573,7 +572,7 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
for cell_type in cell_types:
cell_type_gids = np.unique(spike_gids[spike_types == cell_type])
cell_type_times, cell_type_ypos = [], []
color = next(color_iter)
color = next(color_iter)

for gid in cell_type_gids:
gid_time = spike_times[spike_gids == gid]
Expand Down

0 comments on commit 308ffe8

Please sign in to comment.