From 2c81b39362981763e02552e9714ab1dc4f7cfd30 Mon Sep 17 00:00:00 2001 From: Gregor Lenz Date: Thu, 9 Mar 2023 16:17:09 +0100 Subject: [PATCH] expand saved input in spiking layers for analyzer --- sinabs/synopcounter.py | 1 + tests/test_synops_counter.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sinabs/synopcounter.py b/sinabs/synopcounter.py index 3cddcf5c..9deb7e81 100644 --- a/sinabs/synopcounter.py +++ b/sinabs/synopcounter.py @@ -14,6 +14,7 @@ def spiking_hook(self, input_, output): """ input_ = input_[0] if isinstance(self, sl.SqueezeMixin): + input_ = input_.reshape(self.batch_size, self.num_timesteps, *input_.shape[1:]) output = output.reshape(self.batch_size, self.num_timesteps, *output.shape[1:]) self.n_neurons = output[0, 0].numel() self.input_ = input_ diff --git a/tests/test_synops_counter.py b/tests/test_synops_counter.py index 953713dc..306cf2ff 100644 --- a/tests/test_synops_counter.py +++ b/tests/test_synops_counter.py @@ -226,6 +226,8 @@ def test_snn_analyzer_statistics(): model_stats = analyzer.get_model_statistics(average=True) # spiking layer checks + assert spike_layer_stats["3"]["input"].shape[0] == batch_size + assert spike_layer_stats["3"]["input"].shape[1] == num_timesteps assert ( spike_layer_stats["3"]["firing_rate"] == output.mean() ), "The output mean should be equivalent to the firing rate of the last spiking layer"