Skip to content

Commit

Permalink
distinguish between accumulated and mini-batch stats for firing rates
Browse files Browse the repository at this point in the history
  • Loading branch information
biphasic committed Feb 23, 2023
1 parent 5f16129 commit 33c642f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
37 changes: 23 additions & 14 deletions sinabs/synopcounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ def spiking_hook(self, input_, output):
if isinstance(self, sl.SqueezeMixin):
output = output.reshape(self.batch_size, self.num_timesteps, *output.shape[1:])
self.n_neurons = output[0, 0].numel()
if self.firing_rate_per_neuron == None:
self.firing_rate_per_neuron = output.mean((0, 1))
else:
self.firing_rate_per_neuron = (
self.firing_rate_per_neuron.detach() + output.mean((0, 1))
)
self.tracked_firing_rate = self.tracked_firing_rate.detach() + output.mean()
self.firing_rate_per_neuron = output.mean((0, 1))
self.acc_firing_rate_per_neuron = (
self.acc_firing_rate_per_neuron.detach() + self.firing_rate_per_neuron
)
self.firing_rate = output.mean()
self.acc_firing_rate = self.acc_firing_rate.detach() + self.firing_rate
self.n_batches = self.n_batches + 1


Expand Down Expand Up @@ -115,8 +114,10 @@ def _setup_hooks(self):

for layer in self.model.modules():
if isinstance(layer, sl.StatefulLayer):
layer.firing_rate_per_neuron = None
layer.tracked_firing_rate = torch.tensor(0)
layer.firing_rate_per_neuron = torch.tensor(0)
layer.acc_firing_rate_per_neuron = torch.tensor(0)
layer.firing_rate = torch.tensor(0)
layer.acc_firing_rate = torch.tensor(0)
layer.n_batches = 0
handle = layer.register_forward_hook(spiking_hook)
self.handles.append(handle)
Expand Down Expand Up @@ -153,20 +154,26 @@ def get_layer_statistics(self, average: bool = False) -> dict:
scale_facts = []
for name, module in self.model.named_modules():
if (
hasattr(module, "tracked_firing_rate")
hasattr(module, "acc_firing_rate")
or hasattr(module, "firing_rate_per_neuron")
or hasattr(module, "n_neurons")
):
spike_dict["spiking"][name] = {}
if hasattr(module, "tracked_firing_rate"):
if hasattr(module, "acc_firing_rate"):
spike_dict["spiking"][name].update(
{"firing_rate": module.tracked_firing_rate / module.n_batches}
{
"firing_rate": module.acc_firing_rate / module.n_batches
if average
else module.firing_rate
}
)
if hasattr(module, "firing_rate_per_neuron"):
spike_dict["spiking"][name].update(
{
"firing_rate_per_neuron": module.firing_rate_per_neuron
"firing_rate_per_neuron": module.acc_firing_rate_per_neuron
/ module.n_batches
if average
else module.firing_rate_per_neuron
}
)
if hasattr(module, "n_neurons"):
Expand Down Expand Up @@ -219,7 +226,9 @@ def get_model_statistics(self, average=False):
if hasattr(module, "firing_rate_per_neuron"):
if module.n_batches > 0:
firing_rates.append(
module.firing_rate_per_neuron.ravel() / module.n_batches
module.acc_firing_rate_per_neuron.ravel() / module.n_batches
if average
else module.firing_rate_per_neuron.ravel()
)
else:
firing_rates.append(torch.tensor([0.0]))
Expand Down
18 changes: 9 additions & 9 deletions tests/test_synops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def test_spiking_layer_firing_rate():

analyzer = sinabs.SNNAnalyzer(layer)
output = layer(input_)
model_stats = analyzer.get_model_statistics()
layer_stats = analyzer.get_layer_statistics()["spiking"][""]
model_stats = analyzer.get_model_statistics(average=True)
layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""]

assert (output == input_).all()
assert model_stats["firing_rate"] == 0.25
Expand All @@ -168,8 +168,8 @@ def test_spiking_layer_firing_rate_across_batches():
output = layer(input1)
sinabs.reset_states(layer)
output = layer(input2)
model_stats = analyzer.get_model_statistics()
layer_stats = analyzer.get_layer_statistics()["spiking"][""]
model_stats = analyzer.get_model_statistics(average=True)
layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""]

assert (output == input2).all()
assert model_stats["firing_rate"] == 0.375
Expand All @@ -188,8 +188,8 @@ def test_analyzer_reset():
sinabs.reset_states(layer)
analyzer.reset()
output = layer(input_)
model_stats = analyzer.get_model_statistics()
layer_stats = analyzer.get_layer_statistics()["spiking"][""]
model_stats = analyzer.get_model_statistics(average=True)
layer_stats = analyzer.get_layer_statistics(average=True)["spiking"][""]

assert (output == input_).all()
assert model_stats["firing_rate"] == 0.5
Expand All @@ -212,9 +212,9 @@ def test_snn_analyzer_statistics():
input_ = torch.rand((batch_size, num_timesteps, 1, 16, 16)) * 100
input_flattended = input_.flatten(0, 1)
output = model(input_flattended)
spike_layer_stats = analyzer.get_layer_statistics()["spiking"]
param_layer_stats = analyzer.get_layer_statistics()["parameter"]
model_stats = analyzer.get_model_statistics()
spike_layer_stats = analyzer.get_layer_statistics(average=True)["spiking"]
param_layer_stats = analyzer.get_layer_statistics(average=True)["parameter"]
model_stats = analyzer.get_model_statistics(average=True)

# spiking layer checks
assert (
Expand Down

0 comments on commit 33c642f

Please sign in to comment.