Skip to content

Commit

Permalink
chore: fix pool and port sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
ariG23498 committed Feb 3, 2024
1 parent 9fe4fa9 commit c0de067
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 99 deletions.
4 changes: 2 additions & 2 deletions keras_cv/layers/object_detection/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def _pool_single_sample(self, args):
pooled_feature_map: [target_size, C] float Tensor
"""
feature_map, rois = args
num_rois = rois.get_shape().as_list()[0]
height, width, channel = feature_map.get_shape().as_list()
num_rois = ops.shape(rois)[0]
height, width, channel = ops.shape(feature_map)
# TODO (consider vectorize it for better performance)
for n in range(num_rois):
# [4]
Expand Down
49 changes: 23 additions & 26 deletions keras_cv/layers/object_detection/roi_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv.backend import assert_tf_keras
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.bounding_box import iou
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import sampling
Expand Down Expand Up @@ -69,7 +67,6 @@ def __init__(
append_gt_boxes: bool = True,
**kwargs,
):
assert_tf_keras("keras_cv.layers._ROISampler")
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.roi_matcher = roi_matcher
Expand All @@ -84,9 +81,9 @@ def __init__(

def call(
self,
rois: tf.Tensor,
gt_boxes: tf.Tensor,
gt_classes: tf.Tensor,
rois,
gt_boxes,
gt_classes,
):
"""
Args:
Expand All @@ -102,11 +99,11 @@ def call(
"""
if self.append_gt_boxes:
# num_rois += num_gt
rois = tf.concat([rois, gt_boxes], axis=1)
num_rois = rois.get_shape().as_list()[1]
rois = ops.concatenate([rois, gt_boxes], axis=1)
num_rois = ops.shape(rois)[1]
if num_rois is None:
raise ValueError(
f"`rois` must have static shape, got {rois.get_shape()}"
f"`rois` must have static shape, got {ops.shape(rois)}"
)
if num_rois < self.num_sampled_rois:
raise ValueError(
Expand All @@ -126,27 +123,27 @@ def call(
# [batch_size, num_rois] | [batch_size, num_rois]
matched_gt_cols, matched_vals = self.roi_matcher(similarity_mat)
# [batch_size, num_rois]
positive_matches = tf.math.equal(matched_vals, 1)
negative_matches = tf.math.equal(matched_vals, -1)
positive_matches = ops.equal(matched_vals, 1)
negative_matches = ops.equal(matched_vals, -1)
self._positives.update_state(
tf.reduce_sum(tf.cast(positive_matches, tf.float32), axis=-1)
ops.sum(ops.cast(positive_matches, "float32"), axis=-1)
)
self._negatives.update_state(
tf.reduce_sum(tf.cast(negative_matches, tf.float32), axis=-1)
ops.sum(ops.cast(negative_matches, "float32"), axis=-1)
)
# [batch_size, num_rois, 1]
background_mask = tf.expand_dims(
tf.logical_not(positive_matches), axis=-1
background_mask = ops.expand_dims(
ops.logical_not(positive_matches), axis=-1
)
# [batch_size, num_rois, 1]
matched_gt_classes = target_gather._target_gather(
gt_classes, matched_gt_cols
)
# also set all background matches to `background_class`
matched_gt_classes = tf.where(
matched_gt_classes = ops.where(
background_mask,
tf.cast(
self.background_class * tf.ones_like(matched_gt_classes),
ops.cast(
self.background_class * ops.ones_like(matched_gt_classes),
gt_classes.dtype,
),
matched_gt_classes,
Expand All @@ -163,9 +160,9 @@ def call(
variance=[0.1, 0.1, 0.2, 0.2],
)
# also set all background matches to 0 coordinates
encoded_matched_gt_boxes = tf.where(
encoded_matched_gt_boxes = ops.where(
background_mask,
tf.zeros_like(matched_gt_boxes),
ops.zeros_like(matched_gt_boxes),
encoded_matched_gt_boxes,
)
# [batch_size, num_rois]
Expand All @@ -176,7 +173,7 @@ def call(
self.positive_fraction,
)
# [batch_size, num_sampled_rois] in the range of [0, num_rois)
sampled_indicators, sampled_indices = tf.math.top_k(
sampled_indicators, sampled_indices = ops.top_k(
sampled_indicators, k=self.num_sampled_rois, sorted=True
)
# [batch_size, num_sampled_rois, 4]
Expand All @@ -192,12 +189,12 @@ def call(
# [batch_size, num_sampled_rois, 1]
# all negative samples will be ignored in regression
sampled_box_weights = target_gather._target_gather(
tf.cast(positive_matches[..., tf.newaxis], gt_boxes.dtype),
ops.cast(positive_matches[..., None], gt_boxes.dtype),
sampled_indices,
)
# [batch_size, num_sampled_rois, 1]
sampled_indicators = sampled_indicators[..., tf.newaxis]
sampled_class_weights = tf.cast(sampled_indicators, gt_classes.dtype)
sampled_indicators = sampled_indicators[..., None]
sampled_class_weights = ops.cast(sampled_indicators, gt_classes.dtype)
return (
sampled_rois,
sampled_gt_boxes,
Expand Down
Loading

0 comments on commit c0de067

Please sign in to comment.