From 03dd104876561676f4666d0536a6d4f5195407f7 Mon Sep 17 00:00:00 2001 From: Gregor Lenz Date: Thu, 9 Mar 2023 16:12:39 +0100 Subject: [PATCH 1/2] derive spiking layer stats from single saved output tensor and also save input --- sinabs/synopcounter.py | 66 +++++++++++++----------------------- tests/test_synops_counter.py | 11 +++++- 2 files changed, 34 insertions(+), 43 deletions(-) diff --git a/sinabs/synopcounter.py b/sinabs/synopcounter.py index d775b3dc..3cddcf5c 100644 --- a/sinabs/synopcounter.py +++ b/sinabs/synopcounter.py @@ -1,9 +1,7 @@ import warnings -from functools import partial import torch import torch.nn as nn -from numpy import product import sinabs.layers as sl from sinabs.layers import NeuromorphicReLU @@ -14,15 +12,13 @@ def spiking_hook(self, input_, output): Calculates n_neurons (scalar), firing_rate_per_neuron (C,H,W) and average firing_rate (scalar). """ + input_ = input_[0] if isinstance(self, sl.SqueezeMixin): output = output.reshape(self.batch_size, self.num_timesteps, *output.shape[1:]) self.n_neurons = output[0, 0].numel() - self.firing_rate_per_neuron = output.mean((0, 1)) - self.acc_firing_rate_per_neuron = ( - self.acc_firing_rate_per_neuron.detach() + self.firing_rate_per_neuron - ) - self.firing_rate = output.mean() - self.acc_firing_rate = self.acc_firing_rate.detach() + self.firing_rate + self.input_ = input_ + self.output_ = output + self.acc_output = self.acc_output.detach() + output self.n_batches = self.n_batches + 1 @@ -114,10 +110,7 @@ def _setup_hooks(self): for layer in self.model.modules(): if isinstance(layer, sl.StatefulLayer): - layer.firing_rate_per_neuron = torch.tensor(0) - layer.acc_firing_rate_per_neuron = torch.tensor(0) - layer.firing_rate = torch.tensor(0) - layer.acc_firing_rate = torch.tensor(0) + layer.acc_output = torch.tensor(0) layer.n_batches = 0 handle = layer.register_forward_hook(spiking_hook) self.handles.append(handle) @@ -153,32 +146,21 @@ def get_layer_statistics(self, average: bool = False) -> dict: spike_dict["parameter"] = {} scale_facts = [] for name, module in self.model.named_modules(): - if ( - hasattr(module, "acc_firing_rate") - or hasattr(module, "firing_rate_per_neuron") - or hasattr(module, "n_neurons") - ): - spike_dict["spiking"][name] = {} - if hasattr(module, "acc_firing_rate"): - spike_dict["spiking"][name].update( - { - "firing_rate": module.acc_firing_rate / module.n_batches - if average - else module.firing_rate - } - ) - if hasattr(module, "firing_rate_per_neuron"): - spike_dict["spiking"][name].update( - { - "firing_rate_per_neuron": module.acc_firing_rate_per_neuron - / module.n_batches - if average - else module.firing_rate_per_neuron - } - ) - if hasattr(module, "n_neurons"): - spike_dict["spiking"][name].update({"n_neurons": module.n_neurons}) - + if hasattr(module, "acc_output"): + spike_dict["spiking"][name] = { + "n_neurons": module.n_neurons, + "input": module.input_, + "output": module.acc_output / module.n_batches + if average + else module.output_, + "firing_rate": module.acc_output.mean() / module.n_batches + if average + else module.output_.mean(), + "firing_rate_per_neuron": module.acc_output.mean((0, 1)) + / module.n_batches + if average + else module.output_.mean((0, 1)), + } if isinstance(module, torch.nn.AvgPool2d): # Average pooling scales down the number of counted synops due to the averaging. # We need to correct for that by accumulating the scaling factors and multiplying @@ -213,7 +195,7 @@ def get_layer_statistics(self, average: bool = False) -> dict: } return spike_dict - def get_model_statistics(self, average=False): + def get_model_statistics(self, average: bool = False) -> dict: """Outputs a dictionary with statistics that are summarised across all layers. Parameters: @@ -224,12 +206,12 @@ def get_model_statistics(self, average=False): synops = torch.tensor(0.0) n_neurons = torch.tensor(0.0) for name, module in self.model.named_modules(): - if hasattr(module, "firing_rate_per_neuron"): + if hasattr(module, "acc_output"): if module.n_batches > 0: firing_rates.append( - module.acc_firing_rate_per_neuron.ravel() / module.n_batches + module.acc_output.mean((0, 1)).ravel() / module.n_batches if average - else module.firing_rate_per_neuron.ravel() + else module.output_.mean((0, 1)).ravel() ) else: firing_rates.append(torch.tensor([0.0])) diff --git a/tests/test_synops_counter.py b/tests/test_synops_counter.py index 05f241d2..953713dc 100644 --- a/tests/test_synops_counter.py +++ b/tests/test_synops_counter.py @@ -107,13 +107,16 @@ def test_linear_synops_counter_across_batches(): input2[0, 0] = 6 analyzer = SNNAnalyzer(model) model(input1) + batch1_stats = analyzer.get_model_statistics(average=False) model(input2) + batch2_stats = analyzer.get_model_statistics(average=False) model_stats = analyzer.get_model_statistics(average=True) layer_stats = analyzer.get_layer_statistics(average=True)["parameter"][""] # (3+6)/2 spikes * 5 channels assert model_stats["synops"] == 22.5 assert layer_stats["synops"] == 22.5 + assert 2 * batch1_stats["synops"] == batch2_stats["synops"] def test_conv_synops_counter(): @@ -129,16 +132,19 @@ def test_conv_synops_counter(): assert layer_stats["synops"] == 30 -def test_conv_synops_counter_counts_across_inputs(): +def test_conv_synops_counter_counts_across_batches(): model = nn.Conv2d(1, 5, kernel_size=2) input1 = torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(2, 1, 1, 1) input2 = torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(2, 1, 1, 1) * 2 analyzer = SNNAnalyzer(model) model(input1) + batch1_stats = analyzer.get_model_statistics(average=False) model(input2) + batch2_stats = analyzer.get_model_statistics(average=False) model_stats = analyzer.get_model_statistics(average=True) layer_stats = analyzer.get_layer_statistics(average=True)["parameter"][""] + assert 2 * batch1_stats["synops"] == batch2_stats["synops"] assert model_stats["synops"] == 45 assert layer_stats["synops"] == 45 @@ -166,12 +172,15 @@ def test_spiking_layer_firing_rate_across_batches(): analyzer = sinabs.SNNAnalyzer(layer) output = layer(input1) + batch1_stats = analyzer.get_model_statistics(average=False) sinabs.reset_states(layer) output = layer(input2) + batch2_stats = analyzer.get_model_statistics(average=False) model_stats = analyzer.get_model_statistics(average=True) layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""] assert (output == input2).all() + assert 2 * batch1_stats["firing_rate"] == batch2_stats["firing_rate"] assert model_stats["firing_rate"] == 0.375 assert layer_stats["firing_rate"] == 0.375 assert layer_stats["firing_rate_per_neuron"].shape == (4, 4) From 2c81b39362981763e02552e9714ab1dc4f7cfd30 Mon Sep 17 00:00:00 2001 From: Gregor Lenz Date: Thu, 9 Mar 2023 16:17:09 +0100 Subject: [PATCH 2/2] expand saved input in spiking layers for analyzer --- sinabs/synopcounter.py | 1 + tests/test_synops_counter.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sinabs/synopcounter.py b/sinabs/synopcounter.py index 3cddcf5c..9deb7e81 100644 --- a/sinabs/synopcounter.py +++ b/sinabs/synopcounter.py @@ -14,6 +14,7 @@ def spiking_hook(self, input_, output): """ input_ = input_[0] if isinstance(self, sl.SqueezeMixin): + input_ = input_.reshape(self.batch_size, self.num_timesteps, *input_.shape[1:]) output = output.reshape(self.batch_size, self.num_timesteps, *output.shape[1:]) self.n_neurons = output[0, 0].numel() self.input_ = input_ diff --git a/tests/test_synops_counter.py b/tests/test_synops_counter.py index 953713dc..306cf2ff 100644 --- a/tests/test_synops_counter.py +++ b/tests/test_synops_counter.py @@ -226,6 +226,8 @@ def test_snn_analyzer_statistics(): model_stats = analyzer.get_model_statistics(average=True) # spiking layer checks + assert spike_layer_stats["3"]["input"].shape[0] == batch_size + assert spike_layer_stats["3"]["input"].shape[1] == num_timesteps assert ( spike_layer_stats["3"]["firing_rate"] == output.mean() ), "The output mean should be equivalent to the firing rate of the last spiking layer"