diff --git a/keras_cv/layers/object_detection/roi_pool.py b/keras_cv/layers/object_detection/roi_pool.py index b0d4f73271..34b7c0fd08 100644 --- a/keras_cv/layers/object_detection/roi_pool.py +++ b/keras_cv/layers/object_detection/roi_pool.py @@ -113,8 +113,8 @@ def _pool_single_sample(self, args): pooled_feature_map: [target_size, C] float Tensor """ feature_map, rois = args - num_rois = rois.get_shape().as_list()[0] - height, width, channel = feature_map.get_shape().as_list() + num_rois = ops.shape(rois)[0] + height, width, channel = ops.shape(feature_map) # TODO (consider vectorize it for better performance) for n in range(num_rois): # [4] diff --git a/keras_cv/layers/object_detection/roi_sampler.py b/keras_cv/layers/object_detection/roi_sampler.py index fe63e31ba9..56d774dba5 100644 --- a/keras_cv/layers/object_detection/roi_sampler.py +++ b/keras_cv/layers/object_detection/roi_sampler.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf -from tensorflow import keras - from keras_cv import bounding_box -from keras_cv.backend import assert_tf_keras +from keras_cv.backend import keras +from keras_cv.backend import ops from keras_cv.bounding_box import iou from keras_cv.layers.object_detection import box_matcher from keras_cv.layers.object_detection import sampling @@ -69,7 +67,6 @@ def __init__( append_gt_boxes: bool = True, **kwargs, ): - assert_tf_keras("keras_cv.layers._ROISampler") super().__init__(**kwargs) self.bounding_box_format = bounding_box_format self.roi_matcher = roi_matcher @@ -84,9 +81,9 @@ def __init__( def call( self, - rois: tf.Tensor, - gt_boxes: tf.Tensor, - gt_classes: tf.Tensor, + rois, + gt_boxes, + gt_classes, ): """ Args: @@ -102,11 +99,11 @@ def call( """ if self.append_gt_boxes: # num_rois += num_gt - rois = tf.concat([rois, gt_boxes], axis=1) - num_rois = rois.get_shape().as_list()[1] + rois = ops.concatenate([rois, gt_boxes], axis=1) + num_rois = ops.shape(rois)[1] if num_rois is None: raise ValueError( - f"`rois` must have static shape, got {rois.get_shape()}" + f"`rois` must have static shape, got {ops.shape(rois)}" ) if num_rois < self.num_sampled_rois: raise ValueError( @@ -126,27 +123,27 @@ def call( # [batch_size, num_rois] | [batch_size, num_rois] matched_gt_cols, matched_vals = self.roi_matcher(similarity_mat) # [batch_size, num_rois] - positive_matches = tf.math.equal(matched_vals, 1) - negative_matches = tf.math.equal(matched_vals, -1) + positive_matches = ops.equal(matched_vals, 1) + negative_matches = ops.equal(matched_vals, -1) self._positives.update_state( - tf.reduce_sum(tf.cast(positive_matches, tf.float32), axis=-1) + ops.sum(ops.cast(positive_matches, "float32"), axis=-1) ) self._negatives.update_state( - tf.reduce_sum(tf.cast(negative_matches, tf.float32), axis=-1) + ops.sum(ops.cast(negative_matches, "float32"), axis=-1) ) # [batch_size, num_rois, 1] - background_mask = tf.expand_dims( - tf.logical_not(positive_matches), axis=-1 + background_mask = ops.expand_dims( + ops.logical_not(positive_matches), axis=-1 ) # [batch_size, num_rois, 1] matched_gt_classes = target_gather._target_gather( gt_classes, matched_gt_cols ) # also set all background matches to `background_class` - matched_gt_classes = tf.where( + matched_gt_classes = ops.where( background_mask, - tf.cast( - self.background_class * tf.ones_like(matched_gt_classes), + ops.cast( + self.background_class * ops.ones_like(matched_gt_classes), gt_classes.dtype, ), matched_gt_classes, @@ -163,9 +160,9 @@ def call( variance=[0.1, 0.1, 0.2, 0.2], ) # also set all background matches to 0 coordinates - encoded_matched_gt_boxes = tf.where( + encoded_matched_gt_boxes = ops.where( background_mask, - tf.zeros_like(matched_gt_boxes), + ops.zeros_like(matched_gt_boxes), encoded_matched_gt_boxes, ) # [batch_size, num_rois] @@ -176,7 +173,7 @@ def call( self.positive_fraction, ) # [batch_size, num_sampled_rois] in the range of [0, num_rois) - sampled_indicators, sampled_indices = tf.math.top_k( + sampled_indicators, sampled_indices = ops.top_k( sampled_indicators, k=self.num_sampled_rois, sorted=True ) # [batch_size, num_sampled_rois, 4] @@ -192,12 +189,12 @@ def call( # [batch_size, num_sampled_rois, 1] # all negative samples will be ignored in regression sampled_box_weights = target_gather._target_gather( - tf.cast(positive_matches[..., tf.newaxis], gt_boxes.dtype), + ops.cast(positive_matches[..., None], gt_boxes.dtype), sampled_indices, ) # [batch_size, num_sampled_rois, 1] - sampled_indicators = sampled_indicators[..., tf.newaxis] - sampled_class_weights = tf.cast(sampled_indicators, gt_classes.dtype) + sampled_indicators = sampled_indicators[..., None] + sampled_class_weights = ops.cast(sampled_indicators, gt_classes.dtype) return ( sampled_rois, sampled_gt_boxes, diff --git a/keras_cv/layers/object_detection/roi_sampler_test.py b/keras_cv/layers/object_detection/roi_sampler_test.py index a0ab5c92c2..d209a96297 100644 --- a/keras_cv/layers/object_detection/roi_sampler_test.py +++ b/keras_cv/layers/object_detection/roi_sampler_test.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import tensorflow as tf + +import numpy as np from keras_cv.layers.object_detection.box_matcher import BoxMatcher from keras_cv.layers.object_detection.roi_sampler import _ROISampler from keras_cv.tests.test_case import TestCase -@pytest.mark.tf_keras_only class ROISamplerTest(TestCase): def test_roi_sampler(self): box_matcher = BoxMatcher(thresholds=[0.3], match_values=[-1, 1]) @@ -31,7 +30,7 @@ def test_roi_sampler(self): num_sampled_rois=2, append_gt_boxes=False, ) - rois = tf.constant( + rois = np.array( [ [0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5], @@ -39,32 +38,28 @@ def test_roi_sampler(self): [7.5, 7.5, 12.5, 12.5], ] ) - rois = rois[tf.newaxis, ...] + rois = rois[np.newaxis, ...] # the 3rd box will generate 0 IOUs and not sampled. - gt_boxes = tf.constant( + gt_boxes = np.array( [[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]] ) - gt_boxes = gt_boxes[tf.newaxis, ...] - gt_classes = tf.constant([[2, 10, -1]], dtype=tf.int32) - gt_classes = gt_classes[..., tf.newaxis] + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] _, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( rois, gt_boxes, gt_classes ) # given we only choose 1 positive sample, and `append_label` is False, # only the 2nd ROI is chosen. - expected_gt_boxes = tf.constant( - [[0.0, 0.0, 0, 0.0], [0.0, 0.0, 0, 0.0]] - ) - expected_gt_boxes = expected_gt_boxes[tf.newaxis, ...] + expected_gt_boxes = np.array([[0.0, 0.0, 0, 0.0], [0.0, 0.0, 0, 0.0]]) + expected_gt_boxes = expected_gt_boxes[np.newaxis, ...] # only the 2nd ROI is chosen, and the negative ROI is mapped to 0. - expected_gt_classes = tf.constant([[10], [0]], dtype=tf.int32) - expected_gt_classes = expected_gt_classes[tf.newaxis, ...] - self.assertAllClose( - tf.reduce_max(expected_gt_boxes), tf.reduce_max(sampled_gt_boxes) - ) + expected_gt_classes = np.array([[10], [0]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] + self.assertAllClose(np.max(expected_gt_boxes), np.max(sampled_gt_boxes)) self.assertAllClose( - tf.reduce_min(expected_gt_classes), - tf.reduce_min(sampled_gt_classes), + np.min(expected_gt_classes), + np.min(sampled_gt_classes), ) def test_roi_sampler_small_threshold(self): @@ -76,7 +71,7 @@ def test_roi_sampler_small_threshold(self): num_sampled_rois=2, append_gt_boxes=False, ) - rois = tf.constant( + rois = np.array( [ [0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5], @@ -84,14 +79,14 @@ def test_roi_sampler_small_threshold(self): [7.5, 7.5, 12.5, 12.5], ] ) - rois = rois[tf.newaxis, ...] + rois = rois[np.newaxis, ...] # the 3rd box will generate 0 IOUs and not sampled. - gt_boxes = tf.constant( + gt_boxes = np.array( [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] ) - gt_boxes = gt_boxes[tf.newaxis, ...] - gt_classes = tf.constant([[2, 10, -1]], dtype=tf.int32) - gt_classes = gt_classes[..., tf.newaxis] + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] sampled_rois, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( rois, gt_boxes, gt_classes ) @@ -99,25 +94,23 @@ def test_roi_sampler_small_threshold(self): # only the 2nd ROI is chosen. No negative samples exist given we # select positive_threshold to be 0.1. (the minimum IOU is 1/7) # given num_sampled_rois=2, it selects the 1st ROI as well. - expected_rois = tf.constant([[5, 5, 10, 10], [0.0, 0.0, 5.0, 5.0]]) - expected_rois = expected_rois[tf.newaxis, ...] + expected_rois = np.array([[5, 5, 10, 10], [0.0, 0.0, 5.0, 5.0]]) + expected_rois = expected_rois[np.newaxis, ...] # all ROIs are matched to the 2nd gt box. # the boxes are encoded by dimensions, so the result is # tx, ty = (5.1 - 5.0) / 5 = 0.02, tx, ty = (5.1 - 2.5) / 5 = 0.52 # then divide by 0.1 as box variance. expected_gt_boxes = ( - tf.constant([[0.02, 0.02, 0.0, 0.0], [0.52, 0.52, 0.0, 0.0]]) / 0.1 + np.array([[0.02, 0.02, 0.0, 0.0], [0.52, 0.52, 0.0, 0.0]]) / 0.1 ) - expected_gt_boxes = expected_gt_boxes[tf.newaxis, ...] + expected_gt_boxes = expected_gt_boxes[np.newaxis, ...] # only the 2nd ROI is chosen, and the negative ROI is mapped to 0. - expected_gt_classes = tf.constant([[10], [10]], dtype=tf.int32) - expected_gt_classes = expected_gt_classes[tf.newaxis, ...] - self.assertAllClose( - tf.reduce_max(expected_rois, 1), tf.reduce_max(sampled_rois, 1) - ) + expected_gt_classes = np.array([[10], [10]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] + self.assertAllClose(np.max(expected_rois, 1), np.max(sampled_rois, 1)) self.assertAllClose( - tf.reduce_max(expected_gt_boxes, 1), - tf.reduce_max(sampled_gt_boxes, 1), + np.max(expected_gt_boxes, 1), + np.max(sampled_gt_boxes, 1), ) self.assertAllClose(expected_gt_classes, sampled_gt_classes) @@ -132,7 +125,7 @@ def test_roi_sampler_large_threshold(self): num_sampled_rois=2, append_gt_boxes=False, ) - rois = tf.constant( + rois = np.array( [ [0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5], @@ -140,22 +133,22 @@ def test_roi_sampler_large_threshold(self): [7.5, 7.5, 12.5, 12.5], ] ) - rois = rois[tf.newaxis, ...] + rois = rois[np.newaxis, ...] # the 3rd box will generate 0 IOUs and not sampled. - gt_boxes = tf.constant( + gt_boxes = np.array( [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] ) - gt_boxes = gt_boxes[tf.newaxis, ...] - gt_classes = tf.constant([[2, 10, -1]], dtype=tf.int32) - gt_classes = gt_classes[..., tf.newaxis] + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] _, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( rois, gt_boxes, gt_classes ) # all ROIs are negative matches, so they are mapped to 0. - expected_gt_boxes = tf.zeros([1, 2, 4], dtype=tf.float32) + expected_gt_boxes = np.zeros([1, 2, 4], dtype=np.float32) # only the 2nd ROI is chosen, and the negative ROI is mapped to 0. - expected_gt_classes = tf.constant([[0], [0]], dtype=tf.int32) - expected_gt_classes = expected_gt_classes[tf.newaxis, ...] + expected_gt_classes = np.array([[0], [0]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] # self.assertAllClose(expected_rois, sampled_rois) self.assertAllClose(expected_gt_boxes, sampled_gt_boxes) self.assertAllClose(expected_gt_classes, sampled_gt_classes) @@ -172,7 +165,7 @@ def test_roi_sampler_large_threshold_custom_bg_class(self): num_sampled_rois=2, append_gt_boxes=False, ) - rois = tf.constant( + rois = np.array( [ [0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5], @@ -180,23 +173,23 @@ def test_roi_sampler_large_threshold_custom_bg_class(self): [7.5, 7.5, 12.5, 12.5], ] ) - rois = rois[tf.newaxis, ...] + rois = rois[np.newaxis, ...] # the 3rd box will generate 0 IOUs and not sampled. - gt_boxes = tf.constant( + gt_boxes = np.array( [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] ) - gt_boxes = gt_boxes[tf.newaxis, ...] - gt_classes = tf.constant([[2, 10, -1]], dtype=tf.int32) - gt_classes = gt_classes[..., tf.newaxis] + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] _, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( rois, gt_boxes, gt_classes ) # all ROIs are negative matches, so they are mapped to 0. - expected_gt_boxes = tf.zeros([1, 2, 4], dtype=tf.float32) + expected_gt_boxes = np.zeros([1, 2, 4], dtype=np.float32) # only the 2nd ROI is chosen, and the negative ROI is mapped to -1 from # customization. - expected_gt_classes = tf.constant([[-1], [-1]], dtype=tf.int32) - expected_gt_classes = expected_gt_classes[tf.newaxis, ...] + expected_gt_classes = np.array([[-1], [-1]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] # self.assertAllClose(expected_rois, sampled_rois) self.assertAllClose(expected_gt_boxes, sampled_gt_boxes) self.assertAllClose(expected_gt_classes, sampled_gt_classes) @@ -212,7 +205,7 @@ def test_roi_sampler_large_threshold_append_gt_boxes(self): num_sampled_rois=2, append_gt_boxes=True, ) - rois = tf.constant( + rois = np.array( [ [0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5], @@ -220,24 +213,24 @@ def test_roi_sampler_large_threshold_append_gt_boxes(self): [7.5, 7.5, 12.5, 12.5], ] ) - rois = rois[tf.newaxis, ...] + rois = rois[np.newaxis, ...] # the 3rd box will generate 0 IOUs and not sampled. - gt_boxes = tf.constant( + gt_boxes = np.array( [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] ) - gt_boxes = gt_boxes[tf.newaxis, ...] - gt_classes = tf.constant([[2, 10, -1]], dtype=tf.int32) - gt_classes = gt_classes[..., tf.newaxis] + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] _, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( rois, gt_boxes, gt_classes ) # the selected gt boxes should be [0, 0, 0, 0], and [10, 10, 15, 15] # but the 2nd will be encoded to 0. - self.assertAllClose(tf.reduce_min(sampled_gt_boxes), 0) - self.assertAllClose(tf.reduce_max(sampled_gt_boxes), 0) + self.assertAllClose(np.min(sampled_gt_boxes), 0) + self.assertAllClose(np.max(sampled_gt_boxes), 0) # the selected gt classes should be [0, 2 or 10] - self.assertAllLessEqual(tf.reduce_max(sampled_gt_classes), 10) - self.assertAllGreaterEqual(tf.reduce_min(sampled_gt_classes), 0) + self.assertAllLessEqual(np.max(sampled_gt_classes), 10) + self.assertAllGreaterEqual(np.min(sampled_gt_classes), 0) def test_roi_sampler_large_num_sampled_rois(self): box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) @@ -248,7 +241,7 @@ def test_roi_sampler_large_num_sampled_rois(self): num_sampled_rois=200, append_gt_boxes=True, ) - rois = tf.constant( + rois = np.array( [ [0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5], @@ -256,14 +249,14 @@ def test_roi_sampler_large_num_sampled_rois(self): [7.5, 7.5, 12.5, 12.5], ] ) - rois = rois[tf.newaxis, ...] + rois = rois[np.newaxis, ...] # the 3rd box will generate 0 IOUs and not sampled. - gt_boxes = tf.constant( + gt_boxes = np.array( [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] ) - gt_boxes = gt_boxes[tf.newaxis, ...] - gt_classes = tf.constant([[2, 10, -1]], dtype=tf.int32) - gt_classes = gt_classes[..., tf.newaxis] + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] with self.assertRaisesRegex(ValueError, "must be less than"): _, _, _ = roi_sampler(rois, gt_boxes, gt_classes)