Skip to content

Commit

Permalink
expand saved input in spiking layers for analyzer
Browse files Browse the repository at this point in the history
  • Loading branch information
biphasic committed Mar 9, 2023
1 parent 03dd104 commit 2c81b39
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions sinabs/synopcounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def spiking_hook(self, input_, output):
"""
input_ = input_[0]
if isinstance(self, sl.SqueezeMixin):
input_ = input_.reshape(self.batch_size, self.num_timesteps, *input_.shape[1:])
output = output.reshape(self.batch_size, self.num_timesteps, *output.shape[1:])
self.n_neurons = output[0, 0].numel()
self.input_ = input_
Expand Down
2 changes: 2 additions & 0 deletions tests/test_synops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def test_snn_analyzer_statistics():
model_stats = analyzer.get_model_statistics(average=True)

# spiking layer checks
assert spike_layer_stats["3"]["input"].shape[0] == batch_size
assert spike_layer_stats["3"]["input"].shape[1] == num_timesteps
assert (
spike_layer_stats["3"]["firing_rate"] == output.mean()
), "The output mean should be equivalent to the firing rate of the last spiking layer"
Expand Down

0 comments on commit 2c81b39

Please sign in to comment.