Skip to content

Commit

Permalink
Fix loss computation for CenterPillar when batch_size > 1 (#2056)
Browse files Browse the repository at this point in the history
* Fix larger batch sizes for CenterPillar

* Another fix
  • Loading branch information
ianstenbit authored Sep 7, 2023
1 parent a9fea08 commit 2a0d753
Showing 1 changed file with 18 additions and 21 deletions.
39 changes: 18 additions & 21 deletions keras_cv/models/object_detection_3d/center_pillar.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,35 +137,32 @@ def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs):

# TODO(ianstenbit): loss heatmap threshold should be configurable.
box_regression_mask = (
ops.squeeze(
ops.take(
ops.reshape(heatmap, (ops.shape(heatmap)[0], -1)),
index[..., 0] * ops.shape(heatmap)[1] + index[..., 1],
axis=1,
),
axis=0,
ops.take_along_axis(
ops.reshape(heatmap, (heatmap.shape[0], -1)),
index[..., 0] * heatmap.shape[1] + index[..., 1],
axis=1,
)
> 0.95
)

box = ops.squeeze(
ops.take(
ops.reshape(box, (ops.shape(box)[0], -1, 7)),
index[..., 0] * ops.shape(box)[1] + index[..., 1],
axis=1,
box = ops.take_along_axis(
ops.reshape(box, (ops.shape(box)[0], -1, 7)),
ops.expand_dims(
index[..., 0] * ops.shape(box)[1] + index[..., 1], axis=-1
),
axis=0,
axis=1,
)
box_pred = ops.squeeze(
ops.take(
ops.reshape(
box_pred,
(ops.shape(box_pred)[0], -1, ops.shape(box_pred)[-1]),
),

box_pred = ops.take_along_axis(
ops.reshape(
box_pred,
(ops.shape(box_pred)[0], -1, ops.shape(box_pred)[-1]),
),
ops.expand_dims(
index[..., 0] * ops.shape(box_pred)[1] + index[..., 1],
axis=1,
axis=-1,
),
axis=0,
axis=1,
)

box_center_mask = heatmap > 0.99
Expand Down

0 comments on commit 2a0d753

Please sign in to comment.