From 7e2a6212946ee979b94f55700d680aacbba3247b Mon Sep 17 00:00:00 2001 From: GsnMithra Date: Sat, 17 Feb 2024 17:43:08 +0530 Subject: [PATCH 1/4] draw_keypoints() float support --- test/test_utils.py | 18 ++++++++++++++++++ torchvision/utils.py | 10 +++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 49dc553de3e..126bf62c7a5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -247,6 +247,24 @@ def test_draw_segmentation_masks(colors, alpha, device): torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0) +def test_draw_keypoints_dtypes(): + image_uint8 = torch.full((3, 100, 100), 0, dtype=torch.uint8) + image_float = to_dtype(image_uint8, torch.float, scale=True) + + keypoints_cp = keypoints.clone() + + out_uint8 = utils.draw_keypoints(image_uint8, keypoints) + out_float = utils.draw_keypoints(image_float, keypoints) + + assert out_uint8.dtype == torch.uint8 + assert out_uint8 is not image_uint8 + + assert out_float.is_floating_point() + assert out_float is not image_float + + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) + + def test_draw_segmentation_masks_dtypes(): num_masks, h, w = 2, 100, 100 diff --git a/torchvision/utils.py b/torchvision/utils.py index 79e533d4663..ea1c17230b7 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -371,8 +371,8 @@ def draw_keypoints( # validate image if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -397,6 +397,10 @@ def draw_keypoints( f"Got {visibility.shape = } and {keypoints.shape = }" ) + original_dtype = image.dtype + if image.is_floating_point(): + image = (image * 255).to(dtype=torch.uint8) + ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) @@ -428,7 +432,7 @@ def draw_keypoints( width=width, ) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=original_dtype) # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization From d9666930319439cc7d952b2371882cdfbd460360 Mon Sep 17 00:00:00 2001 From: GsnMithra Date: Sat, 17 Feb 2024 17:50:53 +0530 Subject: [PATCH 2/4] formatting --- test/test_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 126bf62c7a5..cbc7a66fae1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -250,18 +250,18 @@ def test_draw_segmentation_masks(colors, alpha, device): def test_draw_keypoints_dtypes(): image_uint8 = torch.full((3, 100, 100), 0, dtype=torch.uint8) image_float = to_dtype(image_uint8, torch.float, scale=True) - + keypoints_cp = keypoints.clone() - + out_uint8 = utils.draw_keypoints(image_uint8, keypoints) out_float = utils.draw_keypoints(image_float, keypoints) - + assert out_uint8.dtype == torch.uint8 assert out_uint8 is not image_uint8 - + assert out_float.is_floating_point() assert out_float is not image_float - + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) From 1f2b77ca13a205a166ce27fc111b699eb7c5d5b2 Mon Sep 17 00:00:00 2001 From: GsnMithra Date: Sat, 17 Feb 2024 18:05:00 +0530 Subject: [PATCH 3/4] method description update --- torchvision/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index ea1c17230b7..1ad78fe14c3 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -336,13 +336,13 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1]. Keypoints can be drawn for multiple instances at a time. This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances, in the format [x, y]. connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints @@ -363,7 +363,7 @@ def draw_keypoints( For more details, see :ref:`draw_keypoints_with_visibility`. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor with keypoints drawn. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): From a6d9f9e3e1a3bd09b486f8b8e561213bf8486080 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 4 Mar 2024 13:36:21 +0000 Subject: [PATCH 4/4] Address comments --- test/test_utils.py | 34 ++++++++++++++++------------------ torchvision/utils.py | 11 ++++++++--- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index cbc7a66fae1..ffcad425aeb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -247,24 +247,6 @@ def test_draw_segmentation_masks(colors, alpha, device): torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0) -def test_draw_keypoints_dtypes(): - image_uint8 = torch.full((3, 100, 100), 0, dtype=torch.uint8) - image_float = to_dtype(image_uint8, torch.float, scale=True) - - keypoints_cp = keypoints.clone() - - out_uint8 = utils.draw_keypoints(image_uint8, keypoints) - out_float = utils.draw_keypoints(image_float, keypoints) - - assert out_uint8.dtype == torch.uint8 - assert out_uint8 is not image_uint8 - - assert out_float.is_floating_point() - assert out_float is not image_float - - torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) - - def test_draw_segmentation_masks_dtypes(): num_masks, h, w = 2, 100, 100 @@ -450,6 +432,22 @@ def test_draw_keypoints_visibility_default(): assert_equal(result, expected) +def test_draw_keypoints_dtypes(): + image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8) + image_float = to_dtype(image_uint8, torch.float, scale=True) + + out_uint8 = utils.draw_keypoints(image_uint8, keypoints) + out_float = utils.draw_keypoints(image_float, keypoints) + + assert out_uint8.dtype == torch.uint8 + assert out_uint8 is not image_uint8 + + assert out_float.is_floating_point() + assert out_float is not image_float + + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) + + def test_draw_keypoints_errors(): h, w = 10, 10 img = torch.full((3, 100, 100), 0, dtype=torch.uint8) diff --git a/torchvision/utils.py b/torchvision/utils.py index 1ad78fe14c3..734cb127db1 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -398,8 +398,10 @@ def draw_keypoints( ) original_dtype = image.dtype - if image.is_floating_point(): - image = (image * 255).to(dtype=torch.uint8) + if original_dtype.is_floating_point: + from torchvision.transforms.v2.functional import to_dtype # noqa + + image = to_dtype(image, dtype=torch.uint8, scale=True) ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) @@ -432,7 +434,10 @@ def draw_keypoints( width=width, ) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=original_dtype) + out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) + if original_dtype.is_floating_point: + out = to_dtype(out, dtype=original_dtype, scale=True) + return out # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization