diff --git a/test/test_utils.py b/test/test_utils.py index ffcad425aeb..ac394b51d63 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -116,6 +116,23 @@ def test_draw_boxes(): assert_equal(img, img_cp) +@pytest.mark.parametrize("fill", [True, False]) +def test_draw_boxes_dtypes(fill): + img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8) + out_uint8 = utils.draw_bounding_boxes(img_uint8, boxes, fill=fill) + + assert img_uint8 is not out_uint8 + assert out_uint8.dtype == torch.uint8 + + img_float = to_dtype(img_uint8, torch.float, scale=True) + out_float = utils.draw_bounding_boxes(img_float, boxes, fill=fill) + + assert img_float is not out_float + assert out_float.is_floating_point() + + torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1) + + @pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)]) def test_draw_boxes_colors(colors): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) @@ -152,7 +169,6 @@ def test_draw_boxes_grayscale(): def test_draw_invalid_boxes(): img_tp = ((1, 1, 1), (1, 2, 3)) - img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) @@ -162,8 +178,6 @@ def test_draw_invalid_boxes(): with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) - with pytest.raises(ValueError, match="Tensor uint8 expected"): - utils.draw_bounding_boxes(img_wrong1, boxes) with pytest.raises(ValueError, match="Pass individual images, not batches"): utils.draw_bounding_boxes(img_wrong2, boxes) with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"): diff --git a/torchvision/utils.py b/torchvision/utils.py index 734cb127db1..94b3ec65c87 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -164,12 +164,12 @@ def draw_bounding_boxes( ) -> torch.Tensor: """ - Draws bounding boxes on given image. - The values of the input image should be uint8 between 0 and 255. + Draws bounding boxes on given RGB image. + The image values should be uint8 in [0, 255] or float in [0, 1]. If fill is True, Resulting Tensor should be saved as PNG image. Args: - image (Tensor): Tensor of shape (C x H x W) and dtype uint8. + image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float. boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. @@ -188,13 +188,14 @@ def draw_bounding_boxes( Returns: img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. """ + import torchvision.transforms.v2.functional as F # noqa if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(draw_bounding_boxes) if not isinstance(image, torch.Tensor): raise TypeError(f"Tensor expected, got {type(image)}") - elif image.dtype != torch.uint8: - raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif not (image.dtype == torch.uint8 or image.is_floating_point()): + raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: @@ -230,8 +231,11 @@ def draw_bounding_boxes( if image.size(0) == 1: image = torch.tile(image, (3, 1, 1)) - ndarr = image.permute(1, 2, 0).cpu().numpy() - img_to_draw = Image.fromarray(ndarr) + original_dtype = image.dtype + if original_dtype.is_floating_point: + image = F.to_dtype(image, dtype=torch.uint8, scale=True) + + img_to_draw = F.to_pil_image(image) img_boxes = boxes.to(torch.int64).tolist() if fill: @@ -250,7 +254,10 @@ def draw_bounding_boxes( margin = width + 1 draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) - return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + out = F.pil_to_tensor(img_to_draw) + if original_dtype.is_floating_point: + out = F.to_dtype(out, dtype=original_dtype, scale=True) + return out @torch.no_grad()