Skip to content

Commit

Permalink
run formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Jan 27, 2024
1 parent e1f3d30 commit 2636d9b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
18 changes: 6 additions & 12 deletions keras_cv/layers/video_swin_transformer_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def window_partition(x, window_size):
Returns:
windows: (batch_size*num_windows, window_size*window_size, channel)
""" # noqa: E501
""" # noqa: E501

input_shape = ops.shape(x)
batch_size, depth, height, width, channel = (
Expand Down Expand Up @@ -74,7 +74,7 @@ def window_reverse(windows, window_size, batch_size, depth, height, width):
Returns:
x: (batch_size, depth, height, width, channel)
""" # noqa: E501
""" # noqa: E501
x = ops.reshape(
windows,
[
Expand Down Expand Up @@ -106,7 +106,7 @@ def get_window_size(x_size, window_size, shift_size=None):
Returns:
x: window_size, shift_size
""" # noqa: E501
""" # noqa: E501

use_window_size = list(window_size)

Expand Down Expand Up @@ -413,7 +413,7 @@ def get_relative_position_index(
y_y = relative_coords[:, :, 2] + window_width - 1
relative_coords = ops.stack([z_z, x_x, y_y], axis=-1)
return ops.sum(relative_coords, axis=-1)

def build(self, input_shape):
input_dim = input_shape[-1]
head_dim = input_dim // self.num_heads
Expand Down Expand Up @@ -560,7 +560,6 @@ def __init__(
), "shift_size must in 0-window_size"

def build(self, input_shape):

input_dim = input_shape[-1]
self.mlp_hidden_dim = int(input_dim * self.mlp_ratio)
self.window_size, self.shift_size = get_window_size(
Expand Down Expand Up @@ -596,10 +595,7 @@ def build(self, input_shape):
drop_rate=self.drop_rate,
)


def call(
self, x, mask_matrix=None, training=None
):
def call(self, x, mask_matrix=None, training=None):
shortcut = x
input_shape = ops.shape(x)
batch_size, depth, height, width, _ = (
Expand Down Expand Up @@ -822,9 +818,7 @@ def call(self, x, training=None):
)

for block in self.blocks:
x = block(
x, self.attn_mask, training=training
)
x = block(x, self.attn_mask, training=training)

x = ops.reshape(x, [batch_size, depth, height, width, -1])

Expand Down
22 changes: 11 additions & 11 deletions keras_cv/layers/video_swin_transformer_layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from keras_cv.backend import ops
from keras_cv.layers.video_swin_transformer_layers import PatchEmbedding3D
from keras_cv.layers.video_swin_transformer_layers import WindowAttention3D
from keras_cv.layers.video_swin_transformer_layers import SwinTransformerBlock3D
from keras_cv.tests.test_case import TestCase

class TestPatchEmbedding3D(TestCase):

class TestPatchEmbedding3D(TestCase):
def test_patch_embedding_compute_output_shape(self):
patch_embedding_model = PatchEmbedding3D(patch_size=(2, 4, 4), embed_dim=96, norm_layer=None)
patch_embedding_model = PatchEmbedding3D(
patch_size=(2, 4, 4), embed_dim=96, norm_layer=None
)
input_shape = (None, 16, 32, 32, 3)
output_shape = patch_embedding_model.compute_output_shape(input_shape)
expected_output_shape = (None, 8, 8, 8, 96)
self.assertEqual(output_shape, expected_output_shape)

def test_patch_embedding_get_config(self):
patch_embedding_model = PatchEmbedding3D(patch_size=(4, 4, 4), embed_dim=96)
patch_embedding_model = PatchEmbedding3D(
patch_size=(4, 4, 4), embed_dim=96
)
config = patch_embedding_model.get_config()
assert isinstance(config, dict)
assert config["patch_size"] == (4, 4, 4)
assert config["embed_dim"] == 96


class TestWindowAttention3D(TestCase):

@pytest.fixture
def window_attention_model(self):
return WindowAttention3D(
Expand All @@ -50,7 +51,7 @@ def window_attention_model(self):
attn_drop_rate=0.1,
proj_drop_rate=0.1,
)

def test_window_attention_output_shape(self, window_attention_model):
input_shape = (4, 10, 256)
input_array = ops.ones(input_shape)
Expand All @@ -64,8 +65,7 @@ def test_window_attention_get_config(self, window_attention_model):
assert isinstance(config, dict)
assert config["window_size"] == (2, 4, 4)
assert config["num_heads"] == 8
assert config["qkv_bias"] == True
assert config["qk_scale"] == None
assert config["qkv_bias"] is True
assert config["qk_scale"] is None
assert config["attn_drop_rate"] == 0.1
assert config["proj_drop_rate"] == 0.1

0 comments on commit 2636d9b

Please sign in to comment.