diff --git a/test/test_utils.py b/test/test_utils.py index 49dc553de3e..ffcad425aeb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default(): assert_equal(result, expected) +def test_draw_keypoints_dtypes(): + image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8) + image_float = to_dtype(image_uint8, torch.float, scale=True) + + out_uint8 = utils.draw_keypoints(image_uint8, keypoints) + out_float = utils.draw_keypoints(image_float, keypoints) + + assert out_uint8.dtype == torch.uint8 + assert out_uint8 is not image_uint8 + + assert out_float.is_floating_point() + assert out_float is not image_float + + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) + + def test_draw_keypoints_errors(): h, w = 10, 10 img = torch.full((3, 100, 100), 0, dtype=torch.uint8) diff --git a/torchvision/utils.py b/torchvision/utils.py index 79e533d4663..734cb127db1 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -336,13 +336,13 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1]. Keypoints can be drawn for multiple instances at a time. This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances, in the format [x, y]. connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints @@ -363,7 +363,7 @@ def draw_keypoints( For more details, see :ref:`draw_keypoints_with_visibility`. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor with keypoints drawn. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -371,8 +371,8 @@ def draw_keypoints( # validate image if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: @@ -397,6 +397,12 @@ def draw_keypoints( f"Got {visibility.shape = } and {keypoints.shape = }" ) + original_dtype = image.dtype + if original_dtype.is_floating_point: + from torchvision.transforms.v2.functional import to_dtype # noqa + + image = to_dtype(image, dtype=torch.uint8, scale=True) + ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) @@ -428,7 +434,10 @@ def draw_keypoints( width=width, ) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) + if original_dtype.is_floating_point: + out = to_dtype(out, dtype=original_dtype, scale=True) + return out # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization