Merge pull request #87 from synsense/86-log-inputs-to-spiking-layers-to-be-able-to-calculate-sparsity-loss

derive spiking layer stats from single saved output tensor and also s…
Nogay Küpelioğlu authored Mar 9, 2023
2 parents 9f09379 + 2c81b39 commit 199d730
Showing 2 changed files with 37 additions and 43 deletions.
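The branch name spells out the motivation: with the raw input to every spiking layer logged, a sparsity penalty can be computed from those tensors. A minimal sketch of such a loss, assuming a small IAFSqueeze network (the model, sizes, and the 0.01 weighting below are illustrative placeholders, not part of this commit):

    import torch
    import torch.nn as nn
    import sinabs
    import sinabs.layers as sl

    # Placeholder network; any sinabs model with spiking layers works the same way.
    model = nn.Sequential(nn.Linear(8, 4), sl.IAFSqueeze(batch_size=2))
    analyzer = sinabs.SNNAnalyzer(model)
    model(torch.rand(2 * 10, 8))  # (batch * time, features)

    # "input" is the tensor this commit starts logging per spiking layer.
    stats = analyzer.get_layer_statistics(average=False)["spiking"]
    sparsity_loss = sum(s["input"].abs().mean() for s in stats.values())
    loss = 0.01 * sparsity_loss  # arbitrary weighting, added to the task loss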
67 changes: 25 additions & 42 deletions sinabs/synopcounter.py
@@ -1,9 +1,7 @@
 import warnings
-from functools import partial
 
 import torch
 import torch.nn as nn
-from numpy import product
 
 import sinabs.layers as sl
 from sinabs.layers import NeuromorphicReLU
@@ -14,15 +12,14 @@ def spiking_hook(self, input_, output):
     Calculates n_neurons (scalar), firing_rate_per_neuron (C,H,W) and average firing_rate (scalar).
     """
+    input_ = input_[0]
     if isinstance(self, sl.SqueezeMixin):
+        input_ = input_.reshape(self.batch_size, self.num_timesteps, *input_.shape[1:])
         output = output.reshape(self.batch_size, self.num_timesteps, *output.shape[1:])
     self.n_neurons = output[0, 0].numel()
-    self.firing_rate_per_neuron = output.mean((0, 1))
-    self.acc_firing_rate_per_neuron = (
-        self.acc_firing_rate_per_neuron.detach() + self.firing_rate_per_neuron
-    )
-    self.firing_rate = output.mean()
-    self.acc_firing_rate = self.acc_firing_rate.detach() + self.firing_rate
+    self.input_ = input_
+    self.output_ = output
+    self.acc_output = self.acc_output.detach() + output
     self.n_batches = self.n_batches + 1

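The four running statistics that used to be updated in this hook are now derived from the single accumulated output tensor, which is what the commit title refers to. A small self-contained check of that equivalence, with made-up shapes (not code from this repository):

    import torch

    # Two batches of fake output spikes, shape (batch, time, C, H, W).
    out1 = torch.randint(0, 2, (4, 10, 2, 3, 3)).float()
    out2 = torch.randint(0, 2, (4, 10, 2, 3, 3)).float()
    acc_output = torch.tensor(0) + out1 + out2
    n_batches = 2

    # The old attributes fall out of the single accumulator:
    firing_rate = acc_output.mean() / n_batches                    # scalar
    firing_rate_per_neuron = acc_output.mean((0, 1)) / n_batches   # (C, H, W)
    assert torch.allclose(firing_rate, (out1.mean() + out2.mean()) / 2)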
@@ -114,10 +111,7 @@ def _setup_hooks(self):
 
         for layer in self.model.modules():
             if isinstance(layer, sl.StatefulLayer):
-                layer.firing_rate_per_neuron = torch.tensor(0)
-                layer.acc_firing_rate_per_neuron = torch.tensor(0)
-                layer.firing_rate = torch.tensor(0)
-                layer.acc_firing_rate = torch.tensor(0)
+                layer.acc_output = torch.tensor(0)
                 layer.n_batches = 0
                 handle = layer.register_forward_hook(spiking_hook)
                 self.handles.append(handle)
@@ -153,32 +147,21 @@ def get_layer_statistics(self, average: bool = False) -> dict:
         spike_dict["parameter"] = {}
         scale_facts = []
         for name, module in self.model.named_modules():
-            if (
-                hasattr(module, "acc_firing_rate")
-                or hasattr(module, "firing_rate_per_neuron")
-                or hasattr(module, "n_neurons")
-            ):
-                spike_dict["spiking"][name] = {}
-                if hasattr(module, "acc_firing_rate"):
-                    spike_dict["spiking"][name].update(
-                        {
-                            "firing_rate": module.acc_firing_rate / module.n_batches
-                            if average
-                            else module.firing_rate
-                        }
-                    )
-                if hasattr(module, "firing_rate_per_neuron"):
-                    spike_dict["spiking"][name].update(
-                        {
-                            "firing_rate_per_neuron": module.acc_firing_rate_per_neuron
-                            / module.n_batches
-                            if average
-                            else module.firing_rate_per_neuron
-                        }
-                    )
-                if hasattr(module, "n_neurons"):
-                    spike_dict["spiking"][name].update({"n_neurons": module.n_neurons})
-
+            if hasattr(module, "acc_output"):
+                spike_dict["spiking"][name] = {
+                    "n_neurons": module.n_neurons,
+                    "input": module.input_,
+                    "output": module.acc_output / module.n_batches
+                    if average
+                    else module.output_,
+                    "firing_rate": module.acc_output.mean() / module.n_batches
+                    if average
+                    else module.output_.mean(),
+                    "firing_rate_per_neuron": module.acc_output.mean((0, 1))
+                    / module.n_batches
+                    if average
+                    else module.output_.mean((0, 1)),
+                }
             if isinstance(module, torch.nn.AvgPool2d):
                 # Average pooling scales down the number of counted synops due to the averaging.
                 # We need to correct for that by accumulating the scaling factors and multiplying
@@ -213,7 +196,7 @@ def get_layer_statistics(self, average: bool = False) -> dict:
             }
         return spike_dict
 
-    def get_model_statistics(self, average=False):
+    def get_model_statistics(self, average: bool = False) -> dict:
         """Outputs a dictionary with statistics that are summarised across all layers.
 
         Parameters:
@@ -224,12 +207,12 @@
         synops = torch.tensor(0.0)
         n_neurons = torch.tensor(0.0)
         for name, module in self.model.named_modules():
-            if hasattr(module, "firing_rate_per_neuron"):
+            if hasattr(module, "acc_output"):
                 if module.n_batches > 0:
                     firing_rates.append(
-                        module.acc_firing_rate_per_neuron.ravel() / module.n_batches
+                        module.acc_output.mean((0, 1)).ravel() / module.n_batches
                         if average
-                        else module.firing_rate_per_neuron.ravel()
+                        else module.output_.mean((0, 1)).ravel()
                     )
                 else:
                     firing_rates.append(torch.tensor([0.0]))
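Taken together, the analyzer's two query methods now derive everything from acc_output, input_ and output_. A usage sketch with placeholder layer sizes (the network below is illustrative, not from this diff):

    import torch
    import torch.nn as nn
    import sinabs
    import sinabs.layers as sl

    model = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3), sl.IAFSqueeze(batch_size=4))
    analyzer = sinabs.SNNAnalyzer(model)
    model(torch.rand(4 * 10, 1, 8, 8))  # (batch * time, C, H, W)

    layer_stats = analyzer.get_layer_statistics(average=True)
    model_stats = analyzer.get_model_statistics(average=True)
    print(layer_stats["spiking"]["1"]["firing_rate"], model_stats["firing_rate"])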
13 changes: 12 additions & 1 deletion tests/test_synops_counter.py
@@ -107,13 +107,16 @@ def test_linear_synops_counter_across_batches():
     input2[0, 0] = 6
     analyzer = SNNAnalyzer(model)
     model(input1)
+    batch1_stats = analyzer.get_model_statistics(average=False)
     model(input2)
+    batch2_stats = analyzer.get_model_statistics(average=False)
     model_stats = analyzer.get_model_statistics(average=True)
     layer_stats = analyzer.get_layer_statistics(average=True)["parameter"][""]
 
     # (3+6)/2 spikes * 5 channels
     assert model_stats["synops"] == 22.5
     assert layer_stats["synops"] == 22.5
+    assert 2 * batch1_stats["synops"] == batch2_stats["synops"]
 
 
 def test_conv_synops_counter():
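The assertion added above pins down linear scaling: synops grow proportionally with the number of input spikes. Assuming the model is a linear layer fanning each input spike out to 5 outputs (an assumption consistent with the "* 5 channels" comment; the layer's construction is folded out of view), the arithmetic behind 22.5 is:

    batch1_synops = 3 * 5  # 3 input spikes, fan-out 5 -> 15
    batch2_synops = 6 * 5  # 6 input spikes, fan-out 5 -> 30
    assert (batch1_synops + batch2_synops) / 2 == 22.5
    assert 2 * batch1_synops == batch2_synops  # the new per-batch check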
@@ -129,16 +132,19 @@ def test_conv_synops_counter():
     assert layer_stats["synops"] == 30
 
 
-def test_conv_synops_counter_counts_across_inputs():
+def test_conv_synops_counter_counts_across_batches():
     model = nn.Conv2d(1, 5, kernel_size=2)
     input1 = torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(2, 1, 1, 1)
     input2 = torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(2, 1, 1, 1) * 2
     analyzer = SNNAnalyzer(model)
     model(input1)
+    batch1_stats = analyzer.get_model_statistics(average=False)
     model(input2)
+    batch2_stats = analyzer.get_model_statistics(average=False)
     model_stats = analyzer.get_model_statistics(average=True)
     layer_stats = analyzer.get_layer_statistics(average=True)["parameter"][""]
 
+    assert 2 * batch1_stats["synops"] == batch2_stats["synops"]
     assert model_stats["synops"] == 45
     assert layer_stats["synops"] == 45
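The expected 45 follows from fan-out accounting, assuming the counter weights each input spike by the number of 2x2 kernel positions covering it, times the 5 output channels (my reconstruction of the expected value, not code from the repository):

    import torch

    # Number of valid 2x2 kernel positions covering each pixel of a 3x3 map:
    coverage = torch.tensor([[1., 2., 1.],
                             [2., 4., 2.],
                             [1., 2., 1.]])
    out_channels = 5
    synops1 = (torch.eye(3) * coverage).sum() * out_channels      # 30, the value in test_conv_synops_counter
    synops2 = (torch.eye(3) * 2 * coverage).sum() * out_channels  # 60 for the doubled input
    assert (synops1 + synops2) / 2 == 45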

@@ -166,12 +172,15 @@ def test_spiking_layer_firing_rate_across_batches():
 
     analyzer = sinabs.SNNAnalyzer(layer)
     output = layer(input1)
+    batch1_stats = analyzer.get_model_statistics(average=False)
     sinabs.reset_states(layer)
     output = layer(input2)
+    batch2_stats = analyzer.get_model_statistics(average=False)
     model_stats = analyzer.get_model_statistics(average=True)
     layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""]
 
     assert (output == input2).all()
+    assert 2 * batch1_stats["firing_rate"] == batch2_stats["firing_rate"]
     assert model_stats["firing_rate"] == 0.375
     assert layer_stats["firing_rate"] == 0.375
     assert layer_stats["firing_rate_per_neuron"].shape == (4, 4)
@@ -217,6 +226,8 @@ def test_snn_analyzer_statistics():
     model_stats = analyzer.get_model_statistics(average=True)
 
     # spiking layer checks
+    assert spike_layer_stats["3"]["input"].shape[0] == batch_size
+    assert spike_layer_stats["3"]["input"].shape[1] == num_timesteps
     assert (
         spike_layer_stats["3"]["firing_rate"] == output.mean()
     ), "The output mean should be equivalent to the firing rate of the last spiking layer"
