Skip to content

Commit

Permalink
added channel shift layer
Browse files Browse the repository at this point in the history
  • Loading branch information
sheiksadique committed Jul 3, 2023
1 parent 89cfe36 commit fbf408e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
23 changes: 23 additions & 0 deletions sinabs/layers/channel_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn


class ChannelShift(nn.Module):
def __init__(self, channel_shift: int = 0, channel_axis=-3) -> None:
"""Given a tensor, shift the channel from the left, ie zero pad from the left.
Args:
channel_shift (int, optional): Number of channels to shift by. Defaults to 0.
channel_axis (int, optional): The channel axis dimension
NOTE: This has to be a negative dimension such that it counts the dimension from the right. Defaults to -3.
"""
super().__init__()
self.padding = []
self.channel_shift = channel_shift
self.channel_axis = channel_axis
for axis in range(-channel_axis):
self.padding += [0, 0]
self.padding[-2] = channel_shift

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.pad(input=x, pad=self.padding, mode="constant", value=0)
21 changes: 21 additions & 0 deletions tests/test_channelshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from sinabs.layers.channel_shift import ChannelShift


def test_channel_shift_default():
x = torch.rand(1, 10, 5, 5)
cs = ChannelShift()

out = cs(x)
assert out.shape == x.shape


def test_channel_shift():
num_channels = 10
channel_shift = 14
x = torch.rand(1, num_channels, 5, 5)
cs = ChannelShift(channel_shift=channel_shift)

out = cs(x)
assert len(out.shape) == len(x.shape)
assert out.shape[1] == num_channels + channel_shift

0 comments on commit fbf408e

Please sign in to comment.