diff --git a/sinabs/synopcounter.py b/sinabs/synopcounter.py index 0e45b22f..df5ab419 100644 --- a/sinabs/synopcounter.py +++ b/sinabs/synopcounter.py @@ -17,13 +17,12 @@ def spiking_hook(self, input_, output): if isinstance(self, sl.SqueezeMixin): output = output.reshape(self.batch_size, self.num_timesteps, *output.shape[1:]) self.n_neurons = output[0, 0].numel() - if self.firing_rate_per_neuron == None: - self.firing_rate_per_neuron = output.mean((0, 1)) - else: - self.firing_rate_per_neuron = ( - self.firing_rate_per_neuron.detach() + output.mean((0, 1)) - ) - self.tracked_firing_rate = self.tracked_firing_rate.detach() + output.mean() + self.firing_rate_per_neuron = output.mean((0, 1)) + self.acc_firing_rate_per_neuron = ( + self.acc_firing_rate_per_neuron.detach() + self.firing_rate_per_neuron + ) + self.firing_rate = output.mean() + self.acc_firing_rate = self.acc_firing_rate.detach() + self.firing_rate self.n_batches = self.n_batches + 1 @@ -115,8 +114,10 @@ def _setup_hooks(self): for layer in self.model.modules(): if isinstance(layer, sl.StatefulLayer): - layer.firing_rate_per_neuron = None - layer.tracked_firing_rate = torch.tensor(0) + layer.firing_rate_per_neuron = torch.tensor(0) + layer.acc_firing_rate_per_neuron = torch.tensor(0) + layer.firing_rate = torch.tensor(0) + layer.acc_firing_rate = torch.tensor(0) layer.n_batches = 0 handle = layer.register_forward_hook(spiking_hook) self.handles.append(handle) @@ -153,20 +154,26 @@ def get_layer_statistics(self, average: bool = False) -> dict: scale_facts = [] for name, module in self.model.named_modules(): if ( - hasattr(module, "tracked_firing_rate") + hasattr(module, "acc_firing_rate") or hasattr(module, "firing_rate_per_neuron") or hasattr(module, "n_neurons") ): spike_dict["spiking"][name] = {} - if hasattr(module, "tracked_firing_rate"): + if hasattr(module, "acc_firing_rate"): spike_dict["spiking"][name].update( - {"firing_rate": module.tracked_firing_rate / module.n_batches} + { + "firing_rate": module.acc_firing_rate / module.n_batches + if average + else module.firing_rate + } ) if hasattr(module, "firing_rate_per_neuron"): spike_dict["spiking"][name].update( { - "firing_rate_per_neuron": module.firing_rate_per_neuron + "firing_rate_per_neuron": module.acc_firing_rate_per_neuron / module.n_batches + if average + else module.firing_rate_per_neuron } ) if hasattr(module, "n_neurons"): @@ -219,7 +226,9 @@ def get_model_statistics(self, average=False): if hasattr(module, "firing_rate_per_neuron"): if module.n_batches > 0: firing_rates.append( - module.firing_rate_per_neuron.ravel() / module.n_batches + module.acc_firing_rate_per_neuron.ravel() / module.n_batches + if average + else module.firing_rate_per_neuron.ravel() ) else: firing_rates.append(torch.tensor([0.0])) diff --git a/tests/test_synops_counter.py b/tests/test_synops_counter.py index 97264fb8..05f241d2 100644 --- a/tests/test_synops_counter.py +++ b/tests/test_synops_counter.py @@ -149,8 +149,8 @@ def test_spiking_layer_firing_rate(): analyzer = sinabs.SNNAnalyzer(layer) output = layer(input_) - model_stats = analyzer.get_model_statistics() - layer_stats = analyzer.get_layer_statistics()["spiking"][""] + model_stats = analyzer.get_model_statistics(average=True) + layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""] assert (output == input_).all() assert model_stats["firing_rate"] == 0.25 @@ -168,8 +168,8 @@ def test_spiking_layer_firing_rate_across_batches(): output = layer(input1) sinabs.reset_states(layer) output = layer(input2) - model_stats = analyzer.get_model_statistics() - layer_stats = analyzer.get_layer_statistics()["spiking"][""] + model_stats = analyzer.get_model_statistics(average=True) + layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""] assert (output == input2).all() assert model_stats["firing_rate"] == 0.375 @@ -188,8 +188,8 @@ def test_analyzer_reset(): sinabs.reset_states(layer) analyzer.reset() output = layer(input_) - model_stats = analyzer.get_model_statistics() - layer_stats = analyzer.get_layer_statistics()["spiking"][""] + model_stats = analyzer.get_model_statistics(average=True) + layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""] assert (output == input_).all() assert model_stats["firing_rate"] == 0.5 @@ -212,9 +212,9 @@ def test_snn_analyzer_statistics(): input_ = torch.rand((batch_size, num_timesteps, 1, 16, 16)) * 100 input_flattended = input_.flatten(0, 1) output = model(input_flattended) - spike_layer_stats = analyzer.get_layer_statistics()["spiking"] - param_layer_stats = analyzer.get_layer_statistics()["parameter"] - model_stats = analyzer.get_model_statistics() + spike_layer_stats = analyzer.get_layer_statistics(average=True)["spiking"] + param_layer_stats = analyzer.get_layer_statistics(average=True)["parameter"] + model_stats = analyzer.get_model_statistics(average=True) # spiking layer checks assert (