Merge pull request #103 from synsense/graph-tracing
This merge adds two layers that are useful for graph tracing
sheiksadique authored Jul 10, 2023
2 parents df3a113 + 8a6c7a9 commit 274e0a3
Showing 6 changed files with 96 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sinabs/layers/__init__.py
@@ -9,3 +9,5 @@
from .reshape import FlattenTime, Repeat, SqueezeMixin, UnflattenTime
from .stateful_layer import StatefulLayer
from .to_spike import Img2SpikeLayer, Sig2SpikeLayer
from .merge import Merge
from .channel_shift import ChannelShift
23 changes: 23 additions & 0 deletions sinabs/layers/channel_shift.py
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn


class ChannelShift(nn.Module):
    def __init__(self, channel_shift: int = 0, channel_axis: int = -3) -> None:
        """Given a tensor, shift its channels by zero-padding from the left along the channel axis.

        Args:
            channel_shift (int, optional): Number of channels to shift by. Defaults to 0.
            channel_axis (int, optional): The channel axis dimension.
                NOTE: This has to be a negative index so that the dimension is counted from the right. Defaults to -3.
        """
        super().__init__()
        self.padding = []
        self.channel_shift = channel_shift
        self.channel_axis = channel_axis
        # Build the pad specification for F.pad: one (left, right) pair per trailing
        # dimension, then set the left pad of the channel axis to channel_shift.
        for axis in range(-channel_axis):
            self.padding += [0, 0]
        self.padding[-2] = channel_shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.pad(input=x, pad=self.padding, mode="constant", value=0)
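
A minimal usage sketch of the ChannelShift layer above; the shapes are illustrative and mirror the tests further down, and the import path assumes the updated sinabs.layers package:

import torch
from sinabs.layers import ChannelShift

x = torch.rand(1, 10, 5, 5)           # (batch, channels, height, width)
cs = ChannelShift(channel_shift=4)    # prepend 4 zero channels along axis -3
out = cs(x)
print(out.shape)                      # torch.Size([1, 14, 5, 5])
assert torch.all(out[:, :4] == 0)     # the newly added leading channels are zero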
28 changes: 28 additions & 0 deletions sinabs/layers/merge.py
@@ -0,0 +1,28 @@
import torch.nn as nn

class Merge(nn.Module):
    def __init__(self) -> None:
        """Module form of a merge operation.

        In the context of events/spikes, events/spikes from two different sources/rasters are added.
        If the input shapes differ, both tensors are zero-padded to a common shape before addition.
        """
        super().__init__()

    def forward(self, data1, data2):
        size1 = data1.shape
        size2 = data2.shape
        if size1 == size2:
            return data1 + data2
        # If the sizes are not the same, pad each tensor up to the element-wise maximum shape
        assert len(size1) == len(size2)
        pad1 = ()
        pad2 = ()
        # F.pad expects pads for the last dimension first, so prepend pairs while iterating
        for s1, s2 in zip(size1, size2):
            s_max = max(s1, s2)
            pad1 = (0, s_max - s1, *pad1)
            pad2 = (0, s_max - s2, *pad2)

        data1 = nn.functional.pad(input=data1, pad=pad1, mode="constant", value=0)
        data2 = nn.functional.pad(input=data2, pad=pad2, mode="constant", value=0)
        return data1 + data2
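
A brief sketch of how the Merge layer combines inputs of different shapes; the shapes follow the second test case further down and are only illustrative:

import torch
import sinabs.layers as sl

data1 = torch.ones(100, 1, 5, 6)
data2 = torch.ones(100, 10, 5, 5)
out = sl.Merge()(data1, data2)
print(out.shape)                      # torch.Size([100, 10, 5, 6])
# Both inputs are zero-padded to the element-wise maximum shape, so the total
# number of events is preserved: out.sum() == data1.sum() + data2.sum()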
1 change: 0 additions & 1 deletion tests/requirements.txt
@@ -3,5 +3,4 @@ pytest-cov
onnx
onnxruntime
torch>=1.8
torchvision
matplotlib
21 changes: 21 additions & 0 deletions tests/test_channelshift.py
@@ -0,0 +1,21 @@
import torch
from sinabs.layers.channel_shift import ChannelShift


def test_channel_shift_default():
    x = torch.rand(1, 10, 5, 5)
    cs = ChannelShift()

    out = cs(x)
    assert out.shape == x.shape


def test_channel_shift():
    num_channels = 10
    channel_shift = 14
    x = torch.rand(1, num_channels, 5, 5)
    cs = ChannelShift(channel_shift=channel_shift)

    out = cs(x)
    assert len(out.shape) == len(x.shape)
    assert out.shape[1] == num_channels + channel_shift
22 changes: 22 additions & 0 deletions tests/test_merge.py
@@ -0,0 +1,22 @@
import torch
import sinabs.layers as sl


def test_morph_same_size():
    data1 = (torch.rand((100, 1, 20, 20)) > 0.5).float()
    data2 = (torch.rand((100, 1, 20, 20)) > 0.5).float()

    merge = sl.Merge()
    out = merge(data1, data2)
    assert out.shape == (100, 1, 20, 20)


def test_morph_different_size():
    data1 = (torch.rand((100, 1, 5, 6)) > 0.5).float()
    data2 = (torch.rand((100, 10, 5, 5)) > 0.5).float()

    merge = sl.Merge()
    out = merge(data1, data2)

    assert out.shape == (100, 10, 5, 6)
    assert out.sum() == data1.sum() + data2.sum()
