diff --git a/sinabs/layers/__init__.py b/sinabs/layers/__init__.py index b47c4bfc..406a4a38 100644 --- a/sinabs/layers/__init__.py +++ b/sinabs/layers/__init__.py @@ -9,3 +9,5 @@ from .reshape import FlattenTime, Repeat, SqueezeMixin, UnflattenTime from .stateful_layer import StatefulLayer from .to_spike import Img2SpikeLayer, Sig2SpikeLayer +from .merge import Merge +from .channel_shift import ChannelShift \ No newline at end of file diff --git a/sinabs/layers/channel_shift.py b/sinabs/layers/channel_shift.py new file mode 100644 index 00000000..34021b35 --- /dev/null +++ b/sinabs/layers/channel_shift.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn + + +class ChannelShift(nn.Module): + def __init__(self, channel_shift: int = 0, channel_axis=-3) -> None: + """Given a tensor, shift the channel from the left, ie zero pad from the left. + + Args: + channel_shift (int, optional): Number of channels to shift by. Defaults to 0. + channel_axis (int, optional): The channel axis dimension + NOTE: This has to be a negative dimension such that it counts the dimension from the right. Defaults to -3. + """ + super().__init__() + self.padding = [] + self.channel_shift = channel_shift + self.channel_axis = channel_axis + for axis in range(-channel_axis): + self.padding += [0, 0] + self.padding[-2] = channel_shift + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.pad(input=x, pad=self.padding, mode="constant", value=0) diff --git a/sinabs/layers/merge.py b/sinabs/layers/merge.py new file mode 100644 index 00000000..cd631469 --- /dev/null +++ b/sinabs/layers/merge.py @@ -0,0 +1,28 @@ +import torch.nn as nn + +class Merge(nn.Module): + def __init__(self) -> None: + """ + Module form for a merge operation. + In the context of events/spikes, events/spikes from two different sources/rasters will be added. + """ + super().__init__() + + def forward(self, data1, data2): + size1 = data1.shape + size2 = data2.shape + if size1 == size2: + return data1 + data2 + # If the sizes are not the same, find the larger size and pad the data accordingly + assert len(size1) == len(size2) + pad1 = () + pad2 = () + # Find the larger sizes + for s1, s2 in zip(size1, size2): + s_max = max(s1, s2) + pad1 = (0, s_max-s1, *pad1) + pad2 = (0, s_max-s2, *pad2) + + data1 = nn.functional.pad(input=data1, pad=pad1, mode="constant", value=0) + data2 = nn.functional.pad(input=data2, pad=pad2, mode="constant", value=0) + return data1 + data2 diff --git a/tests/requirements.txt b/tests/requirements.txt index 49a294b1..ceaca48d 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,5 +3,4 @@ pytest-cov onnx onnxruntime torch>=1.8 -torchvision matplotlib diff --git a/tests/test_channelshift.py b/tests/test_channelshift.py new file mode 100644 index 00000000..bc2e0321 --- /dev/null +++ b/tests/test_channelshift.py @@ -0,0 +1,21 @@ +import torch +from sinabs.layers.channel_shift import ChannelShift + + +def test_channel_shift_default(): + x = torch.rand(1, 10, 5, 5) + cs = ChannelShift() + + out = cs(x) + assert out.shape == x.shape + + +def test_channel_shift(): + num_channels = 10 + channel_shift = 14 + x = torch.rand(1, num_channels, 5, 5) + cs = ChannelShift(channel_shift=channel_shift) + + out = cs(x) + assert len(out.shape) == len(x.shape) + assert out.shape[1] == num_channels + channel_shift diff --git a/tests/test_merge.py b/tests/test_merge.py new file mode 100644 index 00000000..1b3911e5 --- /dev/null +++ b/tests/test_merge.py @@ -0,0 +1,22 @@ +import torch +import sinabs.layers as sl + + +def test_morph_same_size(): + data1 = (torch.rand((100, 1, 20, 20)) > 0.5).float() + data2 = (torch.rand((100, 1, 20, 20)) > 0.5).float() + + merge = sl.Merge() + out = merge(data1, data2) + assert out.shape == (100, 1, 20, 20) + + +def test_morph_different_size(): + data1 = (torch.rand((100, 1, 5, 6)) > 0.5).float() + data2 = (torch.rand((100, 10, 5, 5)) > 0.5).float() + + merge = sl.Merge() + out = merge(data1, data2) + + assert out.shape == (100, 10, 5, 6) + assert out.sum() == data1.sum() + data2.sum()