diff --git a/keras_cv/layers/video_swin_transformer_layers.py b/keras_cv/layers/video_swin_transformer_layers.py
index a50cf893f4..92370dc193 100644
--- a/keras_cv/layers/video_swin_transformer_layers.py
+++ b/keras_cv/layers/video_swin_transformer_layers.py
@@ -31,7 +31,7 @@ def window_partition(x, window_size):
 
     Returns:
         windows: (batch_size*num_windows, window_size*window_size, channel)
-    """ # noqa: E501
+    """  # noqa: E501
 
     input_shape = ops.shape(x)
     batch_size, depth, height, width, channel = (
@@ -74,7 +74,7 @@ def window_reverse(windows, window_size, batch_size, depth, height, width):
 
     Returns:
         x: (batch_size, depth, height, width, channel)
-    """ # noqa: E501
+    """  # noqa: E501
     x = ops.reshape(
         windows,
         [
@@ -106,7 +106,7 @@ def get_window_size(x_size, window_size, shift_size=None):
 
     Returns:
         x: window_size, shift_size
-    """ # noqa: E501
+    """  # noqa: E501
 
     use_window_size = list(window_size)
 
@@ -413,7 +413,7 @@ def get_relative_position_index(
         y_y = relative_coords[:, :, 2] + window_width - 1
         relative_coords = ops.stack([z_z, x_x, y_y], axis=-1)
         return ops.sum(relative_coords, axis=-1)
-    
+
     def build(self, input_shape):
         input_dim = input_shape[-1]
         head_dim = input_dim // self.num_heads
@@ -560,7 +560,6 @@ def __init__(
         ), "shift_size must in 0-window_size"
 
     def build(self, input_shape):
-        
         input_dim = input_shape[-1]
         self.mlp_hidden_dim = int(input_dim * self.mlp_ratio)
         self.window_size, self.shift_size = get_window_size(
@@ -596,10 +595,7 @@ def build(self, input_shape):
             drop_rate=self.drop_rate,
         )
 
-
-    def call(
-        self, x, mask_matrix=None, training=None
-    ):
+    def call(self, x, mask_matrix=None, training=None):
         shortcut = x
         input_shape = ops.shape(x)
         batch_size, depth, height, width, _ = (
@@ -822,9 +818,7 @@ def call(self, x, training=None):
         )
 
         for block in self.blocks:
-            x = block(
-                x, self.attn_mask, training=training
-            )
+            x = block(x, self.attn_mask, training=training)
 
         x = ops.reshape(x, [batch_size, depth, height, width, -1])
 
diff --git a/keras_cv/layers/video_swin_transformer_layers_test.py b/keras_cv/layers/video_swin_transformer_layers_test.py
index 4f9bafacf7..1ec1e50429 100644
--- a/keras_cv/layers/video_swin_transformer_layers_test.py
+++ b/keras_cv/layers/video_swin_transformer_layers_test.py
@@ -12,26 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy as np
 import pytest
 
 from keras_cv.backend import ops
 from keras_cv.layers.video_swin_transformer_layers import PatchEmbedding3D
 from keras_cv.layers.video_swin_transformer_layers import WindowAttention3D
-from keras_cv.layers.video_swin_transformer_layers import SwinTransformerBlock3D
 from keras_cv.tests.test_case import TestCase
 
 
-class TestPatchEmbedding3D(TestCase): 
+class TestPatchEmbedding3D(TestCase):
     def test_patch_embedding_compute_output_shape(self):
-        patch_embedding_model = PatchEmbedding3D(patch_size=(2, 4, 4), embed_dim=96, norm_layer=None)
+        patch_embedding_model = PatchEmbedding3D(
+            patch_size=(2, 4, 4), embed_dim=96, norm_layer=None
+        )
         input_shape = (None, 16, 32, 32, 3)
         output_shape = patch_embedding_model.compute_output_shape(input_shape)
         expected_output_shape = (None, 8, 8, 8, 96)
         self.assertEqual(output_shape, expected_output_shape)
-    
+
     def test_patch_embedding_get_config(self):
-        patch_embedding_model = PatchEmbedding3D(patch_size=(4, 4, 4), embed_dim=96)
+        patch_embedding_model = PatchEmbedding3D(
+            patch_size=(4, 4, 4), embed_dim=96
+        )
         config = patch_embedding_model.get_config()
         assert isinstance(config, dict)
         assert config["patch_size"] == (4, 4, 4)
@@ -39,7 +41,6 @@ def test_patch_embedding_get_config(self):
 
 
 class TestWindowAttention3D(TestCase):
-
     @pytest.fixture
     def window_attention_model(self):
        return WindowAttention3D(
@@ -50,7 +51,7 @@ def window_attention_model(self):
             attn_drop_rate=0.1,
             proj_drop_rate=0.1,
         )
-    
+
     def test_window_attention_output_shape(self, window_attention_model):
         input_shape = (4, 10, 256)
         input_array = ops.ones(input_shape)
@@ -64,8 +65,7 @@ def test_window_attention_get_config(self, window_attention_model):
         assert isinstance(config, dict)
         assert config["window_size"] == (2, 4, 4)
         assert config["num_heads"] == 8
-        assert config["qkv_bias"] == True
-        assert config["qk_scale"] == None
+        assert config["qkv_bias"] is True
+        assert config["qk_scale"] is None
         assert config["attn_drop_rate"] == 0.1
         assert config["proj_drop_rate"] == 0.1
-    
\ No newline at end of file
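For reviewers who want to exercise the branch locally, here is a minimal smoke-test sketch. It assumes this PR's branch is installed and only uses the names visible in the diff above (`PatchEmbedding3D`, `WindowAttention3D`, and the module-level `window_partition`/`window_reverse` helpers); the expected shapes mirror the unit tests, and the round-trip check reflects the intent of the two window helpers rather than an assertion made anywhere in the patch itself.

```python
# Hypothetical smoke test -- assumes this branch of keras_cv is installed
# and the module path matches the files touched in this PR.
from keras_cv.backend import ops
from keras_cv.layers.video_swin_transformer_layers import (
    PatchEmbedding3D,
    WindowAttention3D,
    window_partition,
    window_reverse,
)

# Patch embedding: (batch, depth, height, width, channels) -> embedded
# patches. Shapes mirror test_patch_embedding_compute_output_shape.
video = ops.ones((1, 16, 32, 32, 3))
embedded = PatchEmbedding3D(patch_size=(2, 4, 4), embed_dim=96)(video)
print(ops.shape(embedded))  # expected: (1, 8, 8, 8, 96)

# window_partition / window_reverse should round-trip when every spatial
# dim divides evenly by the window size (here 8 % 2 == 8 % 4 == 0).
window_size = (2, 4, 4)
windows = window_partition(embedded, window_size)
restored = window_reverse(windows, window_size, 1, 8, 8, 8)
print(ops.sum(ops.abs(ops.subtract(restored, embedded))))  # expected: 0.0

# Window attention over flattened window tokens, matching the
# (batch*num_windows, tokens_per_window, channels) layout from the tests.
attn = WindowAttention3D(
    window_size=(2, 4, 4),
    num_heads=8,
    qkv_bias=True,
    qk_scale=None,
    attn_drop_rate=0.1,
    proj_drop_rate=0.1,
)
out = attn(ops.ones((4, 10, 256)))
print(ops.shape(out))  # expected: (4, 10, 256)
```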