Skip to content

Commit

Permalink
chore: port roi pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
ariG23498 committed Feb 3, 2024
1 parent f8a8b2e commit 9fe4fa9
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 71 deletions.
26 changes: 11 additions & 15 deletions keras_cv/layers/object_detection/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def multilevel_crop_and_resize(
# Concat tensor of [batch_size, height_l * width_l, num_filters] for
# each level.
features_all.append(
ops.reshape(
features[f"P{level}"], [batch_size, -1, num_filters]
)
ops.reshape(features[f"P{level}"], [batch_size, -1, num_filters])
)
features_r2 = ops.reshape(
ops.concatenate(features_all, 1), [-1, num_filters]
Expand All @@ -244,8 +242,12 @@ def multilevel_crop_and_resize(
for i in range(len(feature_widths) - 1):
level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]
level_dim_offsets = ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets
height_dim_sizes = ops.ones_like(feature_widths, dtype="int32") * feature_widths
level_dim_offsets = (
ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets
)
height_dim_sizes = (
ops.ones_like(feature_widths, dtype="int32") * feature_widths
)

# Assigns boxes to the right level.
box_width = boxes[:, :, 3] - boxes[:, :, 1]
Expand Down Expand Up @@ -290,14 +292,12 @@ def multilevel_crop_and_resize(
ops.concatenate(
[
ops.expand_dims(
[[ops.cast(max_feature_height, "float32")]]
/ level_strides
[[ops.cast(max_feature_height, "float32")]] / level_strides
- 1,
axis=-1,
),
ops.expand_dims(
[[ops.cast(max_feature_width, "float32")]]
/ level_strides
[[ops.cast(max_feature_width, "float32")]] / level_strides
- 1,
axis=-1,
),
Expand Down Expand Up @@ -340,8 +340,7 @@ def multilevel_crop_and_resize(
)
y_indices_offset = ops.tile(
ops.reshape(
y_indices
* ops.expand_dims(ops.take(height_dim_sizes, levels), -1),
y_indices * ops.expand_dims(ops.take(height_dim_sizes, levels), -1),
[batch_size, num_boxes, output_size * 2, 1],
),
[1, 1, 1, output_size * 2],
Expand All @@ -351,10 +350,7 @@ def multilevel_crop_and_resize(
[1, 1, output_size * 2, 1],
)
indices = ops.reshape(
batch_size_offset
+ levels_offset
+ y_indices_offset
+ x_indices_offset,
batch_size_offset + levels_offset + y_indices_offset + x_indices_offset,
[-1],
)

Expand Down
26 changes: 12 additions & 14 deletions keras_cv/layers/object_detection/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import assert_tf_keras
from keras_cv.backend import keras
from keras_cv.backend import ops


@keras_cv_export("keras_cv.layers.ROIPooler")
Expand Down Expand Up @@ -59,7 +58,6 @@ def __init__(
image_shape,
**kwargs,
):
assert_tf_keras("keras_cv.layers.ROIPooler")
if not isinstance(target_size, (tuple, list)):
raise ValueError(
"Expected `target_size` to be tuple or list, got "
Expand Down Expand Up @@ -101,7 +99,7 @@ def call(self, feature_map, rois):
target="rel_yxyx",
image_shape=self.image_shape,
)
pooled_feature_map = tf.vectorized_map(
pooled_feature_map = ops.vectorized_map(
self._pool_single_sample, (feature_map, rois)
)
return pooled_feature_map
Expand Down Expand Up @@ -132,28 +130,28 @@ def _pool_single_sample(self, args):
for j in range(self.target_width):
height_start = y_start + i * h_step
height_end = height_start + h_step
height_start = tf.cast(height_start, tf.int32)
height_end = tf.cast(height_end, tf.int32)
height_start = ops.cast(height_start, "int32")
height_end = ops.cast(height_end, "int32")
# If the feature_map shape is smaller than the roi, h_step would be 0;
# in this case the result will be feature_map[0, 0, ...].
height_end = height_start + tf.maximum(
height_end = height_start + ops.maximum(
1, height_end - height_start
)
width_start = x_start + j * w_step
width_end = width_start + w_step
width_start = tf.cast(width_start, tf.int32)
width_end = tf.cast(width_end, tf.int32)
width_end = width_start + tf.maximum(
width_start = ops.cast(width_start, "int32")
width_end = ops.cast(width_end, "int32")
width_end = width_start + ops.maximum(
1, width_end - width_start
)
# [h_step, w_step, C]
region = feature_map[
height_start:height_end, width_start:width_end, :
]
# target_height * target_width * [C]
regions.append(tf.reduce_max(region, axis=[0, 1]))
regions = tf.reshape(
tf.stack(regions),
regions.append(ops.max(region, axis=[0, 1]))
regions = ops.reshape(
ops.stack(regions),
[self.target_height, self.target_width, channel],
)
return regions
Expand Down
82 changes: 40 additions & 42 deletions keras_cv/layers/object_detection/roi_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import tensorflow as tf
import numpy as np

from keras_cv.layers.object_detection.roi_pool import ROIPooler
from keras_cv.tests.test_case import TestCase


@pytest.mark.tf_keras_only
class ROIPoolTest(TestCase):
def test_no_quantize(self):
roi_pooler = ROIPooler(
"rel_yxyx", target_size=[2, 2], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(64), [8, 8, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(64), [8, 8, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 1.0, 1.0]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 1.0, 1.0]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# the maximum value would be at bottom-right at each block, roi sharded
# into 2x2 blocks
Expand All @@ -42,19 +40,19 @@ def test_no_quantize(self):
# | 48, 49, 50, 51 | 52, 53, 54, 55 |
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([27, 31, 59, 63]), [1, 2, 2, 1]
expected_feature_map = np.reshape(
np.array([27, 31, 59, 63]), [1, 2, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

def test_roi_quantize_y(self):
roi_pooler = ROIPooler(
"yxyx", target_size=[2, 2], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(64), [8, 8, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(64), [8, 8, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 224, 220]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 224, 220]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# the maximum value would be at bottom-right at each block, roi sharded
# into 2x2 blocks
Expand All @@ -68,19 +66,19 @@ def test_roi_quantize_y(self):
# | 48, 49, 50 | 51, 52, 53, 54 | 55 (removed)
# | 56, 57, 58(max) | 59, 60, 61, 62(max) | 63 (removed)
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([26, 30, 58, 62]), [1, 2, 2, 1]
expected_feature_map = np.reshape(
np.array([26, 30, 58, 62]), [1, 2, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

def test_roi_quantize_x(self):
roi_pooler = ROIPooler(
"yxyx", target_size=[2, 2], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(64), [8, 8, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(64), [8, 8, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 220, 224]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 220, 224]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# the maximum value would be at bottom-right at each block, roi sharded
# into 2x2 blocks
Expand All @@ -93,19 +91,19 @@ def test_roi_quantize_x(self):
# | 40, 41, 42, 43 | 44, 45, 46, 47 |
# | 48, 49, 50, 51(max) | 52, 53, 54, 55(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([19, 23, 51, 55]), [1, 2, 2, 1]
expected_feature_map = np.reshape(
np.array([19, 23, 51, 55]), [1, 2, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

def test_roi_quantize_h(self):
roi_pooler = ROIPooler(
"yxyx", target_size=[3, 2], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(64), [8, 8, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(64), [8, 8, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 224, 224]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 224, 224]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# the maximum value would be at bottom-right at each block, roi sharded
# into 3x2 blocks
Expand All @@ -120,19 +118,19 @@ def test_roi_quantize_h(self):
# | 48, 49, 50, 51 | 52, 53, 54, 55 |
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([11, 15, 35, 39, 59, 63]), [1, 3, 2, 1]
expected_feature_map = np.reshape(
np.array([11, 15, 35, 39, 59, 63]), [1, 3, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

def test_roi_quantize_w(self):
roi_pooler = ROIPooler(
"yxyx", target_size=[2, 3], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(64), [8, 8, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(64), [8, 8, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 224, 224]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 224, 224]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# the maximum value would be at bottom-right at each block, roi sharded
# into 2x3 blocks
Expand All @@ -146,19 +144,19 @@ def test_roi_quantize_w(self):
# | 48, 49 | 50, 51, 52 | 53, 54, 55 |
# | 56, 57(max) | 58, 59, 60(max) | 61, 62, 63(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([25, 28, 31, 57, 60, 63]), [1, 2, 3, 1]
expected_feature_map = np.reshape(
np.array([25, 28, 31, 57, 60, 63]), [1, 2, 3, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

def test_roi_feature_map_height_smaller_than_roi(self):
roi_pooler = ROIPooler(
"yxyx", target_size=[6, 2], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(16), [4, 4, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(16), [4, 4, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 224, 224]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 224, 224]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# | 0, 1(max) | 2, 3(max) |
# ------------------repeated----------------------
Expand All @@ -167,28 +165,28 @@ def test_roi_feature_map_height_smaller_than_roi(self):
# | 8, 9(max) | 10, 11(max) |
# ------------------repeated----------------------
# | 12, 13(max) | 14, 15(max) |
expected_feature_map = tf.reshape(
tf.constant([1, 3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]), [1, 6, 2, 1]
expected_feature_map = np.reshape(
np.array([1, 3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]), [1, 6, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

def test_roi_feature_map_width_smaller_than_roi(self):
roi_pooler = ROIPooler(
"yxyx", target_size=[2, 6], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(16), [4, 4, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(16), [4, 4, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 224, 224]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 224, 224]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# | 0 | 1 | 2 | 3 |
# | 4(max) | 5(max) | 6(max) | 7(max) |
# --------------------------------------------
# | 8 | 9 | 10 | 11 |
# | 12(max) | 13(max) | 14(max) | 15(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([4, 4, 5, 6, 6, 7, 12, 12, 13, 14, 14, 15]),
expected_feature_map = np.reshape(
np.array([4, 4, 5, 6, 6, 7, 12, 12, 13, 14, 14, 15]),
[1, 2, 6, 1],
)
self.assertAllClose(expected_feature_map, pooled_feature_map)
Expand All @@ -197,13 +195,13 @@ def test_roi_empty(self):
roi_pooler = ROIPooler(
"yxyx", target_size=[2, 2], image_shape=[224, 224, 3]
)
feature_map = tf.expand_dims(
tf.reshape(tf.range(1, 65), [8, 8, 1]), axis=0
feature_map = np.expand_dims(
np.reshape(np.arange(1, 65), [8, 8, 1]), axis=0
)
rois = tf.reshape(tf.constant([0.0, 0.0, 0.0, 0.0]), [1, 1, 4])
rois = np.reshape(np.array([0.0, 0.0, 0.0, 0.0]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# all outputs should be equal to the top-left pixel
self.assertAllClose(tf.ones([1, 2, 2, 1]), pooled_feature_map)
self.assertAllClose(np.ones([1, 2, 2, 1]), pooled_feature_map)

def test_invalid_image_shape(self):
with self.assertRaisesRegex(ValueError, "dynamic shape"):
Expand Down

0 comments on commit 9fe4fa9

Please sign in to comment.