Merge branch 'google3'
Alexander Gorban committed May 22, 2023
2 parents 419b532 + d3c22b7 commit 656f759
Showing 22 changed files with 128 additions and 111 deletions.
61 changes: 30 additions & 31 deletions docs/pose_estimation_metric.md
@@ -16,7 +16,7 @@ subsets of them by selecting a group of keypoints and thresholds.
number of keypoints close to ground truth to the total number of keypoints. We
use thresholds `(0.05, 0.1, 0.2, 0.3, 0.4, 0.5)` relative to bounding
box scales, for example for a human-like 1x1x2m box, the box's scale will be
$(1 \cdot 1 \cdot 2)^\frac{1}{3} = 1.26$, so the $0.20$ threshold of this
$`(1 \cdot 1 \cdot 2)^\frac{1}{3} = 1.26`$, so the $`0.20`$ threshold of this
scale will be 25cm and keypoints with errors less than 25cm will be considered
correct. The metric takes values in the `[0, 1]` range (higher is better).
Useful to understand the distribution of errors.
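
The threshold arithmetic in this paragraph can be checked with a minimal sketch. The helper names `box_scale` and `pck` and the sample errors are hypothetical and are not taken from the WOD library; the sketch only assumes that the scale is the cube root of the box volume and that a keypoint counts as correct when its error is strictly below `threshold * scale`, as the text describes.

```python
def box_scale(length_m, width_m, height_m):
    """Cube root of the box volume, in meters (hypothetical helper)."""
    return (length_m * width_m * height_m) ** (1.0 / 3.0)


def pck(errors_m, scale_m, relative_threshold):
    """Fraction of keypoints whose error is below relative_threshold * scale."""
    if not errors_m:
        raise ValueError("at least one keypoint error is required")
    limit_m = relative_threshold * scale_m
    return sum(e < limit_m for e in errors_m) / len(errors_m)


# Human-like 1 x 1 x 2 m box from the text: the scale is roughly 1.26 m.
scale = box_scale(1.0, 1.0, 2.0)
# Made-up per-keypoint errors in meters, for illustration only.
errors = [0.05, 0.10, 0.24, 0.30]
print(round(scale, 2), round(0.2 * scale, 2), pck(errors, scale, 0.2))
```

With the 0.2 threshold the limit is about 25cm, so three of the four sample keypoints count as correct.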
@@ -70,20 +70,19 @@ unmatched keypoints (aka `mismatch_penalty`), expressed in meters.

We compute the PEM on a set of candidate pairs of predicted and ground truth
objects, for which at least one predicted keypoint is within a distance
threshold constant $C$ from the ground truth box. The final object assignment
threshold constant $`C`$ from the ground truth box. The final object assignment
is selected using the Hungarian method to minimize:

$$\textbf{PEM}(Y,\hat{Y}) = \frac{\sum_{i\in M}\left\|y_{i} -
\hat{y}_{i}\right\|_2 + C|U|}{|M| + |U|}$$

where $M$ - a set of indices of matched keypoints, $U$ - a set of indices of
where $`M`$ - a set of indices of matched keypoints, $`U`$ - a set of indices of
unmatched keypoints (ground truth keypoints without matching predicted keypoints
or predicted keypoints for unmatched objects); Sets
$`Y = \left\lbrace y_i\right\rbrace_{i \in M}`$
and
$`\hat{Y} = \left\lbrace\hat{y}_i\right\rbrace_{i \in M}`$
are ground truth and predicted 3D coordinates of keypoints; $C=0.25$ - a
constant penalty for an unmatched keypoint.
or predicted keypoints for unmatched objects); Sets
$`Y= \left\{y_i\right\}_{i \in M}`$ and
$`\hat{Y} = \left\{\hat{y}_i\right\}_{i \in M}`$ are ground truth
and predicted 3D coordinates of keypoints; $`C=0.25`$ - a constant penalty for
an unmatched keypoint.
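
As a rough illustration of the formula above, the following sketch evaluates PEM for a single object pair. The function name `pem`, its argument layout, and the sample coordinates are assumptions made for this example and do not mirror the library's API; only the formula and the default `C = 0.25` come from the text.

```python
import math


def pem(matched_pairs, num_unmatched, penalty_m=0.25):
    """(sum of L2 errors over M + C * |U|) / (|M| + |U|).

    matched_pairs: list of (ground_truth_xyz, predicted_xyz) tuples (set M).
    num_unmatched: number of unmatched keypoints (|U|).
    penalty_m: constant penalty C per unmatched keypoint, in meters.
    """
    total = len(matched_pairs) + num_unmatched
    if total == 0:
        raise ValueError("PEM is undefined without any keypoints")
    error_sum = sum(math.dist(y, y_hat) for y, y_hat in matched_pairs)
    return (error_sum + penalty_m * num_unmatched) / total


# Made-up data: two matched keypoints (errors 0.1 m and 0.5 m), two unmatched.
pairs = [
    ((0.0, 0.0, 0.0), (0.0, 0.0, 0.1)),
    ((1.0, 0.0, 0.0), (1.0, 0.3, 0.4)),
]
print(pem(pairs, num_unmatched=2))
```

Here the numerator is 0.1 + 0.5 + 2 * 0.25 = 1.1 and the denominator is 4, so the value is about 0.275 m.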


## Object Matching Algorithm
@@ -104,25 +103,25 @@ outputs three sets of objects:
However, matching is complicated by the fact that not all GT objects in WOD have
visible keypoints. To address this, two kinds of GT objects are distinguished:

- $GT_i$ - GT objects without any visible keypoints, which includes unlabeled
- $`GT_i`$ - GT objects without any visible keypoints, which includes unlabeled
or heavily occluded human objects.
- $GT_v$ - GT boxes with at least one
- $`GT_v`$ - GT boxes with at least one
visible keypoint.

| ![a toy example to illustrate $GT_v$ and $GT_i$](images/pem_matching_fig.png) |
| ![a toy example to illustrate $`GT_v`$ and $`GT_i`$](images/pem_matching_fig.png) |
| :-: |
| Fig 1. A toy scene |

On the Fig. 1 you can see:

- Ground truth objects:
- $GT_i$: $G_0$, $G_1$, $G_3$, $G_5$, $G_7$
- $GT_v$: $G_2$, $G_4$, $G_6$, $G_8$, $G_9$
- $`GT_i`$: $`G_0`$, $`G_1`$, $`G_3`$, $`G_5`$, $`G_7`$
- $`GT_v`$: $`G_2`$, $`G_4`$, $`G_6`$, $`G_8`$, $`G_9`$
- Predicted objects:
$P_0$, $P_1$, $P_2$, $P_3$, $P_4$, $P_5$, $P_6$, $P_7$
$`P_0`$, $`P_1`$, $`P_2`$, $`P_3`$, $`P_4`$, $`P_5`$, $`P_6`$, $`P_7`$

If a PR object corresponds to a $GT_i$ object, no penalty is assigned since the
MPJPE cannot be computed for such matches. Only matches between $GT_v$ objects and
If a PR object corresponds to a $`GT_i`$ object, no penalty is assigned since the
MPJPE cannot be computed for such matches. Only matches between $`GT_v`$ objects and
PR objects are considered for the computation of the PEM metric.

Since computing the PEM metric for all possible matches between GT and PR is not
@@ -132,7 +131,7 @@ challenge is the
[`MeanErrorMatcher`](src/waymo_open_dataset/metrics/python/keypoint_metrics.py),
which computes keypoint errors for each pair of candidate matches. It has two stages:

1. When keypoints clearly fall in $GT_i$ objects (see criterion in
1. When keypoints clearly fall in $`GT_i`$ objects (see criterion in
[keypoint_metrics.py](src/waymo_open_dataset/metrics/python/keypoint_metrics.py)),
remove them from considerations, without any penalties.
2. For all remaining candidate GTv ground truth boxes and detections pairs,
@@ -143,22 +142,22 @@ this:
- stage #1:
- Select pairs of GT and PR objects for which at least one PR keypoint is
inside GT box enlarged by 25cm.
- assume $PEM(G_4, P_5) > C$ and $PEM(G_6, P_6) < C$
- should exclude: $(G_0, P_0)$, $(G_1, P_1)$, $(G_3, P_3)$,
$(G_5, P_5)$ pairs.
- assume $`PEM(G_4, P_5) > C`$ and $`PEM(G_6, P_6) < C`$
- should exclude: $`(G_0, P_0)`$, $`(G_1, P_1)`$, $`(G_3, P_3)`$,
$`(G_5, P_5)`$ pairs.
- stage #2:
- consider only GTv objects
- compute errors for candidate pairs and populate the assignment error $A$
(aka cost matrix): $A_{k,j}=PEM(G_k, P_j)$ for
$(G_2, P_2)$, $(G_4, P_5)$, $(G_6, P_6)$, $(G_8, P_7)$,
$(G_9, P_7)$ and set the rest of the 8x7 matrix $A=\infty$.
- assuming $PEM(G_9, P_7) < PEM(G_8, P_7)$, the matching assignment should
- compute errors for candidate pairs and populate the assignment error $`A`$
(aka cost matrix): $`A_{k,j}=PEM(G_k, P_j)`$ for
$`(G_2, P_2)`$, $`(G_4, P_5)`$, $`(G_6, P_6)`$, $`(G_8, P_7)`$,
$`(G_9, P_7)`$ and set the rest of the 8x7 matrix $`A=\infty`$.
- assuming $`PEM(G_9, P_7) < PEM(G_8, P_7)`$, the matching assignment should
output the following pairs:
$(G_1, P_1)$, $(G_2, P_2)$, $(G_6, P_6)$, $(G_9, P_7)$
$`(G_1, P_1)`$, $`(G_2, P_2)`$, $`(G_6, P_6)`$, $`(G_9, P_7)`$
- the final output of the matcher should be:
$(G_2, P_2)$, $(G_6, P_6)$, $(G_9, P_7)$,
$(G_4, \emptyset)$, $(G_8, \emptyset)$,
$(\emptyset, P_4)$
$`(G_2, P_2)`$, $`(G_6, P_6)`$, $`(G_9, P_7)`$,
$`(G_4, \emptyset)`$, $`(G_8, \emptyset)`$,
$`(\emptyset, P_4)`$
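
The effect of the stage #2 cost matrix can be sketched on a reduced version of the toy scene. This is not the `MeanErrorMatcher` implementation: it replaces the Hungarian method with a brute-force search that is only practical for tiny inputs, the cost values are made up, and charging a fixed `unmatched_cost` for a ground truth object left without a prediction is an assumption of this sketch. Pairs missing from the `cost` dictionary play the role of the infinite entries of `A`.

```python
import itertools
import math


def best_assignment(gt_ids, pr_ids, cost, unmatched_cost):
    """Brute-force minimum-cost one-to-one assignment (hypothetical helper).

    cost maps (gt_id, pr_id) to an error for candidate pairs; any pair not in
    the mapping is treated as infinitely expensive and is never selected.
    Returns a dict gt_id -> pr_id, with None for unmatched ground truth.
    """
    # Each ground truth object picks a distinct prediction or a None slot.
    options = list(pr_ids) + [None] * len(gt_ids)
    best_total, best_choice = math.inf, None
    for choice in itertools.permutations(options, len(gt_ids)):
        total = sum(
            unmatched_cost if p is None else cost.get((g, p), math.inf)
            for g, p in zip(gt_ids, choice)
        )
        if total < best_total:
            best_total, best_choice = total, choice
    return dict(zip(gt_ids, best_choice))


# Made-up costs with PEM(G9, P7) < PEM(G8, P7), as assumed in the text.
cost = {
    ("G2", "P2"): 0.05,
    ("G6", "P6"): 0.10,
    ("G8", "P7"): 0.20,
    ("G9", "P7"): 0.08,
}
result = best_assignment(
    ["G2", "G6", "G8", "G9"], ["P2", "P6", "P7"], cost, unmatched_cost=0.25
)
print(result)
```

Because `P7` can serve only one ground truth object, the cheaper pair with `G9` is chosen and `G8` is left unmatched, which agrees with the walkthrough above for these objects.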

For the PEM metric, each ground-truth box – GTV and GTi – can only be
associated with a maximum of 1 detection. To maximize your PEM scores, you are
@@ -167,4 +166,4 @@ responsible for removing duplicate detections.
NOTE: The WOD library also implements the [`CppMatcher`](src/waymo_open_dataset/metrics/python/keypoint_metrics.py)
which maximizes total Intersection over Union (IoU) between predicted and ground
truth boxes. However, this matcher requires all predictions to have bounding
boxes and is provided only as a reference.
boxes and is provided only as a reference.
34 changes: 0 additions & 34 deletions src/waymo_open_dataset/metrics/ops/detection_metrics_ops_test.py
@@ -398,40 +398,6 @@ def testVelocityBreakdown(self):
self.assertTrue(np.all(aph >= -EPSILON))
self.assertTrue(np.all(aph <= 1.0 + EPSILON))

def testStateBasedRequiresScoreCutoffs(self):
config = metrics_pb2.Config()
config_text = """
num_desired_score_cutoffs: 11
breakdown_generator_ids: ONE_SHARD
difficulties {
}
matcher_type: TYPE_HUNGARIAN
iou_thresholds: 0.5
iou_thresholds: 0.5
iou_thresholds: 0.5
iou_thresholds: 0.5
iou_thresholds: 0.5
box_type: TYPE_3D
"""
text_format.Parse(config_text, config)
k, n, m = 10, 0, 0
pd_bbox, pd_type, pd_frameid, pd_score, _ = self._GenerateRandomBBoxes(k, m)
gt_bbox, gt_type, gt_frameid, _, gt_speed = self._GenerateRandomBBoxes(k, n)
with self.assertRaisesRegex(
ValueError, 'requires that score cutoffs are set explicitly'
):
self._GetAP(
pd_bbox,
pd_type,
pd_frameid,
pd_score,
gt_bbox,
gt_type,
gt_frameid,
gt_speed,
config,
)


if __name__ == '__main__':
tf.compat.v1.disable_eager_execution()
1 change: 0 additions & 1 deletion src/waymo_open_dataset/metrics/ops/metrics_ops.cc
@@ -47,7 +47,6 @@ REGISTER_OP("DetectionMetrics")
.Output("breakdown: uint8")
.Attr("config: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {

return ::tensorflow::Status();
})
.Doc(R"doc(
55 changes: 44 additions & 11 deletions src/waymo_open_dataset/metrics/ops/py_metrics_ops.py
@@ -13,15 +13,10 @@
# limitations under the License.
# ==============================================================================
"""Waymo Open Dataset tensorflow ops python interface."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

metrics_module = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile('metrics_ops.so'))
gen_metrics_ops = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile('metrics_ops.so'))


def detection_metrics(prediction_bbox,
@@ -40,7 +35,7 @@ def detection_metrics(prediction_bbox,
num_gt_boxes = tf.shape(ground_truth_bbox)[0]
ground_truth_speed = tf.zeros((num_gt_boxes, 2), dtype=tf.float32)

return metrics_module.detection_metrics(
return gen_metrics_ops.detection_metrics(
prediction_bbox=prediction_bbox,
prediction_type=prediction_type,
prediction_score=prediction_score,
@@ -54,6 +49,44 @@ def detection_metrics(prediction_bbox,
config=config)


def detection_metrics_state(
prediction_bbox,
prediction_type,
prediction_score,
prediction_frame_id,
prediction_overlap_nlz,
ground_truth_bbox,
ground_truth_type,
ground_truth_frame_id,
ground_truth_difficulty,
config,
ground_truth_speed=None,
):
"""Wraps detection metrics state. See metrics_ops.cc for documentation."""
if ground_truth_speed is None:
num_gt_boxes = tf.shape(ground_truth_bbox)[0]
ground_truth_speed = tf.zeros((num_gt_boxes, 2), dtype=tf.float32)

return gen_metrics_ops.detection_metrics_state(
prediction_bbox=prediction_bbox,
prediction_type=prediction_type,
prediction_score=prediction_score,
prediction_frame_id=prediction_frame_id,
prediction_overlap_nlz=prediction_overlap_nlz,
ground_truth_bbox=ground_truth_bbox,
ground_truth_type=ground_truth_type,
ground_truth_frame_id=ground_truth_frame_id,
ground_truth_difficulty=ground_truth_difficulty,
ground_truth_speed=ground_truth_speed,
config=config,
)


def detection_metrics_result(state, config):
"""Wraps detection metrics result. See metrics_ops.cc for documentation."""
return gen_metrics_ops.detection_metrics_result(state=state, config=config)


def motion_metrics(prediction_trajectory,
prediction_score,
ground_truth_trajectory,
@@ -75,7 +108,7 @@ def motion_metrics(prediction_trajectory,
batch_size = tf.shape(ground_truth_trajectory)[0]
scenario_id = tf.strings.as_string(tf.range(batch_size))

return metrics_module.motion_metrics(
return gen_metrics_ops.motion_metrics(
prediction_trajectory=prediction_trajectory,
prediction_score=prediction_score,
ground_truth_trajectory=ground_truth_trajectory,
@@ -111,7 +144,7 @@ def tracking_metrics(prediction_bbox,
if prediction_overlap_nlz is None:
prediction_overlap_nlz = tf.zeros_like(prediction_frame_id, dtype=tf.bool)

return metrics_module.tracking_metrics(
return gen_metrics_ops.tracking_metrics(
prediction_bbox=prediction_bbox,
prediction_type=prediction_type,
prediction_score=prediction_score,
@@ -131,4 +164,4 @@ def tracking_metrics(prediction_bbox,

def match(prediction_boxes, groundtruth_boxes, config):
"""Wraps match. See metrics_ops.cc for full documentation."""
return metrics_module.match(prediction_boxes, groundtruth_boxes, config)
return gen_metrics_ops.match(prediction_boxes, groundtruth_boxes, config)
4 changes: 2 additions & 2 deletions src/waymo_open_dataset/pip_pkg_scripts/BUILD
@@ -197,7 +197,7 @@ py_wheel(
"tensorflow_graphics==2021.12.3",
"tensorflow_probability==0.19.0",
],
version = "1.5.1",
version = "1.5.2",
deps = [
":all_python_modules",
],
@@ -213,7 +213,7 @@ py_wheel(
"numpy==1.21.5",
"tensorflow==2.11",
],
version = "1.5.1",
version = "1.5.2",
deps = [
"@deeplab2",
],
2 changes: 1 addition & 1 deletion src/waymo_open_dataset/pip_pkg_scripts/README.md
@@ -7,7 +7,7 @@ docker build \
--tag=open_dataset_pip\
-f src/waymo_open_dataset/waymo_open_dataset/pip_pkg_scripts/build.Dockerfile\
--build-arg USERNAME=$USER\
--build-arg USER_UID=$(id -u $USER) .
--build-arg USER_UID=$`(id -u `$USER) .
docker run --mount type=bind,source=/tmp/wod,target=/tmp/wod open_dataset_pip
```
1 change: 0 additions & 1 deletion src/waymo_open_dataset/utils/BUILD
@@ -202,7 +202,6 @@ py_library(
deps = [
requirement("tensorflow"),
"//waymo_open_dataset/protos:scenario_proto_py_pb2",
"//waymo_open_dataset/protos:sim_agents_submission_proto_py_pb2",
],
)

8 changes: 4 additions & 4 deletions src/waymo_open_dataset/utils/keypoint_data.py
@@ -505,15 +505,15 @@ def create_laser_box_tensors(
def _stack_optional_headings(
values: List[BoundingBoxTensors], axis: int = 0
) -> Optional[tf.Tensor]:
no_heading = sum([1 if v.heading else 0 for v in values])
if no_heading == 0:
num_have_heading = sum([v.heading is not None for v in values])
if not num_have_heading:
return None
elif no_heading == len(values):
elif num_have_heading == len(values):
return tf.stack([v.heading for v in values], axis)
else:
raise AssertionError(
'Either all or none of the boxes need to have heading, got'
f' {no_heading} without heading out of {len(values)}'
f' {num_have_heading} with heading out of {len(values)}'
)


24 changes: 24 additions & 0 deletions src/waymo_open_dataset/utils/keypoint_data_test.py
@@ -343,6 +343,30 @@ def test_returns_none_box_for_pose_prediction_without_box(self):

self.assertIsNone(tensors.box)

def test_populates_heading_for_a_box_with_zero_heading(self):
# A special case to verify there is no implicit boolean conversion for
# heading = tf.constant(0), e.g. we are not using `if heading:` anywhere.
heading = 0
kp1 = _util.laser_keypoint(_LEFT_SHOULDER, location_m=(4, 5, 6))
kp2 = _util.laser_keypoint(_RIGHT_SHOULDER, location_m=(6, 5, 4))

poses = [
_lib.PoseLabel(
object_type=_lib.ObjectType.TYPE_PEDESTRIAN,
box=_util.laser_box((1, 2, 3), (4, 5, 6), heading),
keypoints=keypoint_pb2.LaserKeypoints(keypoint=[kp1, kp2]),
)
]

tensors = _lib.create_pose_estimation_tensors(
poses,
default_location=tf.zeros(3, dtype=tf.float32),
order=[_LEFT_SHOULDER, _RIGHT_SHOULDER],
)

self.assertIsNotNone(tensors.box)
self.assertIsNotNone(tensors.box.heading) # pylint: disable=attribute-error


if __name__ == '__main__':
tf.test.main()
Original file line number Diff line number Diff line change
@@ -30,14 +30,9 @@
# ==============================================================================
"""Camera model tensorflow ops python interface."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
gen_camera_model_ops = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile('camera_model_ops.so'))

camera_model_module = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile('camera_model_ops.so'))

world_to_image = camera_model_module.world_to_image
image_to_world = camera_model_module.image_to_world
world_to_image = gen_camera_model_ops.world_to_image
image_to_world = gen_camera_model_ops.image_to_world
Original file line number Diff line number Diff line change
@@ -24,12 +24,13 @@ class CameraSegmentationMetricsTest(tf.test.TestCase):

def setUp(self):
super().setUp()
self._panoptic_label_divisor = 100000
self._true_panoptic_labels = [
np.array([[1, 2, 3, 4]], dtype=np.int32),
np.array([[2, 3, 4, 5]], dtype=np.int32)]
np.array([[1, 2, 3, 4]], dtype=np.int32) * self._panoptic_label_divisor,
np.array([[2, 3, 4, 5]], dtype=np.int32) * self._panoptic_label_divisor]
self._pred_panoptic_labels = [
np.array([[4, 2, 3, 1]], dtype=np.int32),
np.array([[0, 0, 0, 0]], dtype=np.int32)]
np.array([[4, 2, 3, 1]], dtype=np.int32) * self._panoptic_label_divisor,
np.array([[0, 0, 0, 0]], dtype=np.int32) * self._panoptic_label_divisor]
self._num_cameras_covered = [
np.array([[1, 1, 1, 1]], dtype=np.int32),
np.array([[2, 2, 2, 2]], dtype=np.int32),
@@ -42,7 +43,8 @@ def setUp(self):
def test_get_eval_config(self):
eval_config = camera_segmentation_metrics.get_eval_config()
self.assertIsNotNone(eval_config)
self.assertEqual(eval_config.panoptic_label_divisor, 100000)
self.assertEqual(eval_config.panoptic_label_divisor,
self._panoptic_label_divisor)

def test_get_metric_object_by_sequence_basic(self):
metric_obj = camera_segmentation_metrics.get_metric_object_by_sequence(