diff --git a/keras_cv/layers/object_detection/roi_pool.py b/keras_cv/layers/object_detection/roi_pool.py index e9774e3e15..3105b1d4be 100644 --- a/keras_cv/layers/object_detection/roi_pool.py +++ b/keras_cv/layers/object_detection/roi_pool.py @@ -112,11 +112,12 @@ def _pool_single_sample(self, args): feature_map: [H, W, C] float Tensor rois: [N, 4] float Tensor Returns: - pooled_feature_map: [target_size, C] float Tensor + pooled_feature_map: [N, target_height, target_width, C] float Tensor """ feature_map, rois = args num_rois = rois.get_shape().as_list()[0] height, width, channel = feature_map.get_shape().as_list() + regions = [] # TODO (consider vectorize it for better performance) for n in range(num_rois): # [4] @@ -127,7 +128,7 @@ def _pool_single_sample(self, args): region_width = width * (roi[3] - roi[1]) h_step = region_height / self.target_height w_step = region_width / self.target_width - regions = [] + region_steps = [] for i in range(self.target_height): for j in range(self.target_width): height_start = y_start + i * h_step @@ -147,16 +148,18 @@ def _pool_single_sample(self, args): 1, width_end - width_start ) # [h_step, w_step, C] - region = feature_map[ + region_step = feature_map[ height_start:height_end, width_start:width_end, : ] # target_height * target_width * [C] - regions.append(tf.reduce_max(region, axis=[0, 1])) - regions = tf.reshape( - tf.stack(regions), - [self.target_height, self.target_width, channel], + region_steps.append(tf.reduce_max(region_step, axis=[0, 1])) + regions.append( + tf.reshape( + tf.stack(region_steps), + [self.target_height, self.target_width, channel], + ) ) - return regions + return tf.stack(regions) def get_config(self): config = { diff --git a/keras_cv/layers/object_detection/roi_pool_test.py b/keras_cv/layers/object_detection/roi_pool_test.py index c6401beebc..e605c3e5a7 100644 --- a/keras_cv/layers/object_detection/roi_pool_test.py +++ b/keras_cv/layers/object_detection/roi_pool_test.py @@ -43,7 +43,7 @@ def 
test_no_quantize(self): # | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) | # -------------------------------------------- expected_feature_map = tf.reshape( - tf.constant([27, 31, 59, 63]), [1, 2, 2, 1] + tf.constant([27, 31, 59, 63]), [1, 1, 2, 2, 1] ) self.assertAllClose(expected_feature_map, pooled_feature_map) @@ -69,7 +69,7 @@ def test_roi_quantize_y(self): # | 56, 57, 58(max) | 59, 60, 61, 62(max) | 63 (removed) # -------------------------------------------- expected_feature_map = tf.reshape( - tf.constant([26, 30, 58, 62]), [1, 2, 2, 1] + tf.constant([26, 30, 58, 62]), [1, 1, 2, 2, 1] ) self.assertAllClose(expected_feature_map, pooled_feature_map) @@ -94,7 +94,7 @@ def test_roi_quantize_x(self): # | 48, 49, 50, 51(max) | 52, 53, 54, 55(max) | # -------------------------------------------- expected_feature_map = tf.reshape( - tf.constant([19, 23, 51, 55]), [1, 2, 2, 1] + tf.constant([19, 23, 51, 55]), [1, 1, 2, 2, 1] ) self.assertAllClose(expected_feature_map, pooled_feature_map) @@ -121,7 +121,7 @@ def test_roi_quantize_h(self): # | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) | # -------------------------------------------- expected_feature_map = tf.reshape( - tf.constant([11, 15, 35, 39, 59, 63]), [1, 3, 2, 1] + tf.constant([11, 15, 35, 39, 59, 63]), [1, 1, 3, 2, 1] ) self.assertAllClose(expected_feature_map, pooled_feature_map) @@ -147,7 +147,7 @@ def test_roi_quantize_w(self): # | 56, 57(max) | 58, 59, 60(max) | 61, 62, 63(max) | # -------------------------------------------- expected_feature_map = tf.reshape( - tf.constant([25, 28, 31, 57, 60, 63]), [1, 2, 3, 1] + tf.constant([25, 28, 31, 57, 60, 63]), [1, 1, 2, 3, 1] ) self.assertAllClose(expected_feature_map, pooled_feature_map) @@ -168,7 +168,8 @@ def test_roi_feature_map_height_smaller_than_roi(self): # ------------------repeated---------------------- # | 12, 13(max) | 14, 15(max) | expected_feature_map = tf.reshape( - tf.constant([1, 3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]), [1, 6, 2, 1] + tf.constant([1, 
3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]), + [1, 1, 6, 2, 1], ) self.assertAllClose(expected_feature_map, pooled_feature_map) @@ -189,7 +190,7 @@ def test_roi_feature_map_width_smaller_than_roi(self): # -------------------------------------------- expected_feature_map = tf.reshape( tf.constant([4, 4, 5, 6, 6, 7, 12, 12, 13, 14, 14, 15]), - [1, 2, 6, 1], + [1, 1, 2, 6, 1], ) self.assertAllClose(expected_feature_map, pooled_feature_map) @@ -203,10 +204,43 @@ def test_roi_empty(self): rois = tf.reshape(tf.constant([0.0, 0.0, 0.0, 0.0]), [1, 1, 4]) pooled_feature_map = roi_pooler(feature_map, rois) # all outputs should be top-left pixel - self.assertAllClose(tf.ones([1, 2, 2, 1]), pooled_feature_map) + self.assertAllClose(tf.ones([1, 1, 2, 2, 1]), pooled_feature_map) def test_invalid_image_shape(self): with self.assertRaisesRegex(ValueError, "dynamic shape"): _ = ROIPooler( "rel_yxyx", target_size=[2, 2], image_shape=[None, 224, 3] ) + + def test_multiple_rois(self): + feature_map = tf.expand_dims( + tf.reshape(tf.range(0, 64), [8, 8, 1]), axis=0 + ) + + roi_pooler = ROIPooler( + bounding_box_format="yxyx", + target_size=[2, 2], + image_shape=[224, 224, 3], + ) + rois = tf.constant( + [[[0.0, 0.0, 112.0, 112.0], [0.0, 112.0, 224.0, 224.0]]], + ) + + pooled_feature_map = roi_pooler(feature_map, rois) + # per roi, the max of each 2x2-pooled block is its bottom-right value; + # expected: roi 0 = rows 0-3 x cols 0-3, roi 1 = rows 0-7 x cols 4-7 + # | 0, 1, 2, 3 | 4, 5, 6, 7 | + # | 8, 9(max), 10, 11(max) | 12, 13, 14, 15 | + # | 16, 17, 18, 19 | 20, 21, 22, 23 | + # | 24, 25(max), 26, 27(max) | 28, 29(max), 30, 31(max) | + # -------------------------------------------- + # | 32, 33, 34, 35 | 36, 37, 38, 39 | + # | 40, 41, 42, 43 | 44, 45, 46, 47 | + # | 48, 49, 50, 51 | 52, 53, 54, 55 | + # | 56, 57, 58, 59 | 60, 61(max), 62, 63(max) | + # -------------------------------------------- + + expected_feature_map = tf.reshape( + tf.constant([9, 11, 25, 27, 29, 31, 61, 63]), [1, 2, 2, 2, 1] + ) + self.assertAllClose(expected_feature_map, 
pooled_feature_map)