diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 818d9f78b08..a269aea2604 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -25,6 +25,7 @@ jobs: os: windows test-infra-repository: pytorch/test-infra test-infra-ref: main + with-xpu: enable build: needs: generate-matrix strategy: diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml deleted file mode 100644 index 8e4889b9ba7..00000000000 --- a/.github/workflows/update-viablestrict.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Update viable/strict - -on: - pull_request: - paths: - - .github/workflows/update-viablestrict.yml - schedule: - - cron: 10,40 * * * * - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }} - cancel-in-progress: false - -jobs: - do_update_viablestrict: - uses: pytorch/test-infra/.github/workflows/update-viablestrict.yml@main - with: - repository: pytorch/vision - required_checks: "Build Linux,Build M1,Build Macos,Build Windows,Tests,CMake,Lint,Docs" - test-infra-ref: main - secrets: - ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }} - GITHUB_DEPLOY_KEY : ${{ secrets.VISION_GITHUB_DEPLOY_KEY }} diff --git a/README.md b/README.md index 60583c45256..1076a7a186d 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ versions. | `torch` | `torchvision` | Python | | ------------------ | ------------------ | ------------------- | | `main` / `nightly` | `main` / `nightly` | `>=3.9`, `<=3.12` | +| `2.5` | `0.20` | `>=3.9`, `<=3.12` | | `2.4` | `0.19` | `>=3.8`, `<=3.12` | | `2.3` | `0.18` | `>=3.8`, `<=3.12` | | `2.2` | `0.17` | `>=3.8`, `<=3.11` | diff --git a/docs/source/io.rst b/docs/source/io.rst index 638f310bf69..6a76f95e897 100644 --- a/docs/source/io.rst +++ b/docs/source/io.rst @@ -3,34 +3,46 @@ Decoding / Encoding images and videos .. currentmodule:: torchvision.io -The :mod:`torchvision.io` package provides functions for performing IO -operations. They are currently specific to reading and writing images and -videos. +The :mod:`torchvision.io` module provides utilities for decoding and encoding +images and videos. -Images ------- +Image Decoding +-------------- Torchvision currently supports decoding JPEG, PNG, WEBP and GIF images. JPEG decoding can also be done on CUDA GPUs. -For encoding, JPEG (cpu and CUDA) and PNG are supported. +The main entry point is the :func:`~torchvision.io.decode_image` function, which +you can use as an alternative to ``PIL.Image.open()``. It will decode images +straight into image Tensors, thus saving you the conversion step and allowing you to +run transforms/pre-processing natively on tensors. + +.. code:: + + from torchvision.io import decode_image + + img = decode_image("path_to_image", mode="RGB") + img.dtype # torch.uint8 + + # Or + raw_encoded_bytes = ... # read encoded bytes from your file system + img = decode_image(raw_encoded_bytes, mode="RGB") + + +:func:`~torchvision.io.decode_image` will automatically detect the image format, +and call the corresponding decoder. You can also use the lower-level +format-specific decoders which can be more powerful, e.g. if you want to +encode/decode JPEGs on CUDA. .. autosummary:: :toctree: generated/ :template: function.rst - read_image decode_image - encode_jpeg decode_jpeg - write_jpeg + encode_png decode_gif decode_webp - encode_png - decode_png - write_png - read_file - write_file ..
autosummary:: :toctree: generated/ @@ -38,11 +50,52 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported. ImageReadMode +Obsolete decoding function: + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + read_image + +Image Encoding +-------------- + +For encoding, JPEG (CPU and CUDA) and PNG are supported. + + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + encode_jpeg + write_jpeg + encode_png + write_png + +IO operations +------------- + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + read_file + write_file Video ----- +.. warning:: + + Torchvision supports video decoding through different APIs listed below, + some of which are still in BETA stage. In the near future, we intend to + centralize PyTorch's video decoding capabilities within the `torchcodec + `_ project. We encourage you to try + it out and share your feedback, as the torchvision video decoders will + eventually be deprecated. + .. autosummary:: :toctree: generated/ :template: function.rst @@ -52,45 +105,14 @@ Video write_video -Fine-grained video API -^^^^^^^^^^^^^^^^^^^^^^ +**Fine-grained video API** In addition to the :mod:`read_video` function, we provide a high-performance lower-level API for more fine-grained control compared to the :mod:`read_video` function. It does all this whilst fully supporting torchscript. -.. betastatus:: fine-grained video API - .. autosummary:: :toctree: generated/ :template: class.rst VideoReader - - -Example of inspecting a video: - -.. code:: python - - import torchvision - video_path = "path to a test video" - # Constructor allocates memory and a threaded decoder - # instance per video. At the moment it takes two arguments: - # path to the video file, and a wanted stream. - reader = torchvision.io.VideoReader(video_path, "video") - - # The information about the video can be retrieved using the - # `get_metadata()` method. It returns a dictionary for every stream, with - # duration and other relevant metadata (often frame rate) - reader_md = reader.get_metadata() - - # metadata is structured as a dict of dicts with following structure - # {"stream_type": {"attribute": [attribute per stream]}} - # - # following would print out the list of frame rates for every present video stream - print(reader_md["video"]["fps"]) - - # we explicitly select the stream we would like to operate on. In - # the constructor we select a default video stream, but - # in practice, we can set whichever stream we would like - video.set_current_stream("video:0") diff --git a/docs/source/models.rst b/docs/source/models.rst index 15540778602..53e8d87609e 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -226,10 +226,10 @@ Here is an example of how to use the pre-trained image classification models: .. code:: python - from torchvision.io import read_image + from torchvision.io import decode_image from torchvision.models import resnet50, ResNet50_Weights - img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg") + img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg") # Step 1: Initialize model with the best available weights weights = ResNet50_Weights.DEFAULT @@ -283,10 +283,10 @@ Here is an example of how to use the pre-trained quantized image classification ..
code:: python - from torchvision.io import read_image + from torchvision.io import decode_image from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights - img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg") + img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg") # Step 1: Initialize model with the best available weights weights = ResNet50_QuantizedWeights.DEFAULT @@ -339,11 +339,11 @@ Here is an example of how to use the pre-trained semantic segmentation models: .. code:: python - from torchvision.io.image import read_image + from torchvision.io.image import decode_image from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights from torchvision.transforms.functional import to_pil_image - img = read_image("gallery/assets/dog1.jpg") + img = decode_image("gallery/assets/dog1.jpg") # Step 1: Initialize model with the best available weights weights = FCN_ResNet50_Weights.DEFAULT @@ -411,12 +411,12 @@ Here is an example of how to use the pre-trained object detection models: .. code:: python - from torchvision.io.image import read_image + from torchvision.io.image import decode_image from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights from torchvision.utils import draw_bounding_boxes from torchvision.transforms.functional import to_pil_image - img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg") + img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg") # Step 1: Initialize model with the best available weights weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT diff --git a/gallery/others/plot_repurposing_annotations.py b/gallery/others/plot_repurposing_annotations.py index 9d723064ee4..2c2e10ffb2a 100644 --- a/gallery/others/plot_repurposing_annotations.py +++ b/gallery/others/plot_repurposing_annotations.py @@ -66,12 +66,12 @@ def show(imgs): # We will take images and masks from the `PenFudan Dataset `_. 
-from torchvision.io import read_image +from torchvision.io import decode_image img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png") mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png") -img = read_image(img_path) -mask = read_image(mask_path) +img = decode_image(img_path) +mask = decode_image(mask_path) # %% @@ -181,8 +181,8 @@ def __getitem__(self, idx): img_path = os.path.join(self.root, "PNGImages", self.imgs[idx]) mask_path = os.path.join(self.root, "PedMasks", self.masks[idx]) - img = read_image(img_path) - mask = read_image(mask_path) + img = decode_image(img_path) + mask = decode_image(mask_path) img = F.convert_image_dtype(img, dtype=torch.float) mask = F.convert_image_dtype(mask, dtype=torch.float) diff --git a/gallery/others/plot_scripted_tensor_transforms.py b/gallery/others/plot_scripted_tensor_transforms.py index 5c49a7ca894..da2213347e3 100644 --- a/gallery/others/plot_scripted_tensor_transforms.py +++ b/gallery/others/plot_scripted_tensor_transforms.py @@ -21,7 +21,7 @@ import torch.nn as nn import torchvision.transforms as v1 -from torchvision.io import read_image +from torchvision.io import decode_image plt.rcParams["savefig.bbox"] = 'tight' torch.manual_seed(1) @@ -39,8 +39,8 @@ # :class:`torch.nn.Sequential` instead of # :class:`~torchvision.transforms.v2.Compose`: -dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg')) -dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg')) +dog1 = decode_image(str(ASSETS_PATH / 'dog1.jpg')) +dog2 = decode_image(str(ASSETS_PATH / 'dog2.jpg')) transforms = torch.nn.Sequential( v1.RandomCrop(224), diff --git a/gallery/others/plot_visualization_utils.py b/gallery/others/plot_visualization_utils.py index d0a214a7340..72c35b53717 100644 --- a/gallery/others/plot_visualization_utils.py +++ b/gallery/others/plot_visualization_utils.py @@ -42,11 +42,11 @@ def show(imgs): # image of dtype ``uint8`` as input. 
from torchvision.utils import make_grid -from torchvision.io import read_image +from torchvision.io import decode_image from pathlib import Path -dog1_int = read_image(str(Path('../assets') / 'dog1.jpg')) -dog2_int = read_image(str(Path('../assets') / 'dog2.jpg')) +dog1_int = decode_image(str(Path('../assets') / 'dog1.jpg')) +dog2_int = decode_image(str(Path('../assets') / 'dog2.jpg')) dog_list = [dog1_int, dog2_int] grid = make_grid(dog_list) @@ -362,9 +362,9 @@ def show(imgs): # from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights -from torchvision.io import read_image +from torchvision.io import decode_image -person_int = read_image(str(Path("../assets") / "person1.jpg")) +person_int = decode_image(str(Path("../assets") / "person1.jpg")) weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT transforms = weights.transforms() diff --git a/gallery/transforms/plot_transforms_getting_started.py b/gallery/transforms/plot_transforms_getting_started.py index 0faf79c46af..2696a9e57e7 100644 --- a/gallery/transforms/plot_transforms_getting_started.py +++ b/gallery/transforms/plot_transforms_getting_started.py @@ -21,14 +21,14 @@ plt.rcParams["savefig.bbox"] = 'tight' from torchvision.transforms import v2 -from torchvision.io import read_image +from torchvision.io import decode_image torch.manual_seed(1) # If you're trying to run that on Colab, you can download the assets and the # helpers from https://github.com/pytorch/vision/tree/main/gallery/ from helpers import plot -img = read_image(str(Path('../assets') / 'astronaut.jpg')) +img = decode_image(str(Path('../assets') / 'astronaut.jpg')) print(f"{type(img) = }, {img.dtype = }, {img.shape = }") # %% diff --git a/packaging/windows/internal/vc_env_helper.bat b/packaging/windows/internal/vc_env_helper.bat index d3484a66e9f..699876beb8a 100644 --- a/packaging/windows/internal/vc_env_helper.bat +++ b/packaging/windows/internal/vc_env_helper.bat @@ -28,6 +28,8 @@ if "%VSDEVCMD_ARGS%" == "" ( @echo on +if "%CU_VERSION%" == "xpu" call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" + set DISTUTILS_USE_SDK=1 set args=%1 diff --git a/setup.py b/setup.py index dbe8ce58aa2..05a07c826fc 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) BUILD_CUDA_SOURCES = (torch.cuda.is_available() and ((CUDA_HOME is not None) or IS_ROCM)) or FORCE_CUDA -PACKAGE_NAME = "torchvision" +package_name = os.getenv("TORCHVISION_PACKAGE_NAME", "torchvision") print("Torchvision build configuration:") print(f"{FORCE_CUDA = }") @@ -98,7 +98,7 @@ def get_dist(pkgname): except DistributionNotFound: return None - pytorch_dep = "torch" + pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch") if os.getenv("PYTORCH_VERSION"): pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") @@ -366,7 +366,7 @@ def make_image_extension(): else: warnings.warn("Building torchvision without AVIF support") - if USE_NVJPEG and torch.cuda.is_available(): + if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() if nvjpeg_found: @@ -376,6 +376,8 @@ def make_image_extension(): Extension = CUDAExtension else: warnings.warn("Building torchvision without NVJPEG support") + elif USE_NVJPEG: + warnings.warn("Building torchvision without NVJPEG support") return Extension( name="torchvision.image", @@ -559,7 +561,7 @@ def run(self): version, sha = get_version() write_version_file(version, sha) - 
print(f"Building wheel {PACKAGE_NAME}-{version}") + print(f"Building wheel {package_name}-{version}") with open("README.md") as f: readme = f.read() @@ -571,7 +573,7 @@ def run(self): ] setup( - name=PACKAGE_NAME, + name=package_name, version=version, author="PyTorch Core Team", author_email="soumith@pytorch.org", @@ -581,7 +583,7 @@ def run(self): long_description_content_type="text/markdown", license="BSD", packages=find_packages(exclude=("test",)), - package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so", "prototype/datasets/_builtin/*.categories"]}, + package_data={package_name: ["*.dll", "*.dylib", "*.so", "prototype/datasets/_builtin/*.categories"]}, zip_safe=False, install_requires=get_requirements(), extras_require={ diff --git a/test/smoke_test.py b/test/smoke_test.py index f98d019bea5..3a44ae3efe9 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -6,7 +6,7 @@ import torch import torchvision -from torchvision.io import decode_jpeg, decode_webp, read_file, read_image +from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file from torchvision.models import resnet50, ResNet50_Weights @@ -21,13 +21,13 @@ def smoke_test_torchvision() -> None: def smoke_test_torchvision_read_decode() -> None: - img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) + img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) if img_jpg.shape != (3, 606, 517): raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}") - img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png")) + img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png")) if img_png.shape != (4, 471, 354): raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") - img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp")) + img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp")) if img_webp.shape != (3, 100, 100): raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}") @@ -54,7 +54,7 @@ def smoke_test_compile() -> None: def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: - img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device) + img = decode_image(str(SCRIPT_DIR / ".." 
/ "gallery" / "assets" / "dog2.jpg")).to(device) # Step 1: Initialize model with the best available weights weights = ResNet50_Weights.DEFAULT diff --git a/test/test_image.py b/test/test_image.py index 4d14af638a0..f3c2984b348 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -4,6 +4,7 @@ import os import re import sys +from contextlib import nullcontext from pathlib import Path import numpy as np @@ -13,6 +14,7 @@ import torchvision.transforms.v2.functional as F from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence +from torchvision._internally_replaced_utils import IN_FBCODE from torchvision.io.image import ( _decode_avif, _decode_heic, @@ -1044,5 +1046,45 @@ def test_decode_heic(decode_fun, scripted): img += 123 # make sure image buffer wasn't freed by underlying decoding lib +@pytest.mark.parametrize("input_type", ("Path", "str", "tensor")) +@pytest.mark.parametrize("scripted", (False, True)) +def test_decode_image_path(input_type, scripted): + # Check that decode_image can support not just tensors as input + path = next(get_images(IMAGE_ROOT, ".jpg")) + if input_type == "Path": + input = Path(path) + elif input_type == "str": + input = path + elif input_type == "tensor": + input = read_file(path) + else: + raise ValueError("Oops") + + if scripted and input_type == "Path": + pytest.xfail(reason="Can't pass a Path when scripting") + + decode_fun = torch.jit.script(decode_image) if scripted else decode_image + decode_fun(input) + + +def test_mode_str(): + # Make sure decode_image supports string modes. We just test decode_image, + # not all of the decoding functions, but they should all support that too. + # Torchscript fails when passing strings, which is expected. 
+ path = next(get_images(IMAGE_ROOT, ".png")) + assert decode_image(path, mode="RGB").shape[0] == 3 + assert decode_image(path, mode="rGb").shape[0] == 3 + assert decode_image(path, mode="GRAY").shape[0] == 1 + assert decode_image(path, mode="RGBA").shape[0] == 4 + + +def test_avif_heic_fbcode(): + cm = nullcontext() if IN_FBCODE else pytest.raises(ImportError, match="cannot import") + with cm: + from torchvision.io import decode_heic # noqa + with cm: + from torchvision.io import decode_avif # noqa + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f9218c3e840..e16c0677c9f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6169,3 +6169,50 @@ def test_transform_sequence_len_error(self, quality): def test_transform_invalid_quality_error(self, quality): with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"): transforms.JPEG(quality=quality) + + +class TestUtils: + # TODO: Still need to test has_all, has_any, check_type and get_bounding_boxes + @pytest.mark.parametrize( + "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize( + "make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_query_size_and_query_chw(self, make_input1, make_input2, query): + size = (32, 64) + input1 = make_input1(size) + input2 = make_input2(size) + + if query is transforms.query_chw and not any( + transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + for inpt in (input1, input2) + ): + return + + expected = size if query is transforms.query_size else ((3,) + size) + assert query([input1, input2]) == expected + + @pytest.mark.parametrize( + "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize( + "make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_different_sizes(self, make_input1, make_input2, query): + input1 = make_input1((10, 10)) + input2 = make_input2((20, 20)) + if query is transforms.query_chw and not all( + transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + for inpt in (input1, input2) + ): + return + with pytest.raises(ValueError, match="Found multiple"): + query([input1, input2]) + + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_no_valid_input(self, query): + with pytest.raises(TypeError, match="No image"): + query(["blah"]) diff --git a/torchvision/_internally_replaced_utils.py b/torchvision/_internally_replaced_utils.py index d9a6e261ea2..e0fa72489f1 100644 --- a/torchvision/_internally_replaced_utils.py +++ b/torchvision/_internally_replaced_utils.py @@ -6,6 +6,7 @@ _HOME = os.path.join(_get_torch_home(), "datasets", "vision") _USE_SHARDED_DATASETS = False +IN_FBCODE = False def _download_file_from_remote_location(fpath: str, url: str) -> None: diff --git a/torchvision/csrc/io/image/cpu/decode_avif.cpp b/torchvision/csrc/io/image/cpu/decode_avif.cpp index 3cb326e2f11..c3ecd581e42 100644 --- a/torchvision/csrc/io/image/cpu/decode_avif.cpp +++ 
b/torchvision/csrc/io/image/cpu/decode_avif.cpp @@ -52,7 +52,6 @@ torch::Tensor decode_avif( result == AVIF_RESULT_OK, "avifDecoderParse failed: ", avifResultToString(result)); - printf("avif num images = %d\n", decoder->imageCount); TORCH_CHECK( decoder->imageCount == 1, "Avif file contains more than one image"); diff --git a/torchvision/datasets/stanford_cars.py b/torchvision/datasets/stanford_cars.py index c029ed0d358..6264de82eb7 100644 --- a/torchvision/datasets/stanford_cars.py +++ b/torchvision/datasets/stanford_cars.py @@ -15,6 +15,7 @@ class StanfordCars(VisionDataset): has been split roughly in a 50-50 split The original URL is https://ai.stanford.edu/~jkrause/cars/car_dataset.html, but it is broken. + Follow the instructions under the ``download`` argument to obtain and use the dataset offline. .. note:: @@ -29,8 +30,12 @@ class StanfordCars(VisionDataset): target and transforms it. download (bool, optional): This parameter exists for backward compatibility but it does not download the dataset, since the original URL is not available anymore. The dataset - seems to be available on Kaggle so you can try to manually download it using - `these instructions `_. + seems to be available on Kaggle so you can try to manually download and configure it using + `these instructions `_, + or use an integrated + `dataset on Kaggle `_. + In both cases, first download and configure the dataset locally, and use the dataset with + ``download=False``. """ def __init__( diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index a604ea1fdb6..0dcbd7e9cea 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -74,3 +74,10 @@ "Video", "VideoReader", ] + +from .._internally_replaced_utils import IN_FBCODE + +if IN_FBCODE: + from .image import _decode_avif as decode_avif, _decode_heic as decode_heic + + __all__ += ["decode_avif", "decode_heic"] diff --git a/torchvision/io/image.py b/torchvision/io/image.py index f1df0d52672..cb48d0e6816 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -20,19 +20,25 @@ class ImageReadMode(Enum): - """ - Support for various modes while reading images. + """Allow automatic conversion to RGB, RGBA, etc. while decoding. + + .. note:: + + You don't need to use this enum; you can just pass strings to all + ``mode`` parameters, e.g. ``mode="RGB"``. - Use ``ImageReadMode.UNCHANGED`` for loading the image as-is, - ``ImageReadMode.GRAY`` for converting to grayscale, - ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency, - ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for - RGB with transparency. + The different available modes are the following. + + - UNCHANGED: loads the image as-is + - RGB: converts to RGB + - RGBA: converts to RGB with transparency (also aliased as RGB_ALPHA) + - GRAY: converts to grayscale + - GRAY_ALPHA: converts to grayscale with transparency .. note:: - Some decoders won't support all possible values, e.g. a decoder may only - support "RGB" and "RGBA" mode. + Some decoders won't support all possible values, e.g. GRAY and + GRAY_ALPHA are only supported for PNG and JPEG images. """ UNCHANGED = 0 @@ -40,12 +46,12 @@ class ImageReadMode(Enum): GRAY_ALPHA = 2 RGB = 3 RGB_ALPHA = 4 + RGBA = RGB_ALPHA # Alias for convenience def read_file(path: str) -> torch.Tensor: """ - Reads and outputs the bytes contents of a file as a uint8 Tensor - with one dimension. + Return the byte contents of a file as a uint8 1D Tensor.
Args: path (str or ``pathlib.Path``): the path to the file to be read @@ -61,8 +67,7 @@ def read_file(path: str) -> torch.Tensor: def write_file(filename: str, data: torch.Tensor) -> None: """ - Writes the contents of an uint8 tensor with one dimension to a - file. + Write the content of a uint8 1D tensor to a file. Args: filename (str or ``pathlib.Path``): the path to the file to be written @@ -92,10 +97,9 @@ def decode_png( Args: input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the PNG image. - mode (ImageReadMode): the read mode used for optionally - converting the image. Default: ``ImageReadMode.UNCHANGED``. - See `ImageReadMode` class for more information on various - available modes. + mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB". + Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode` + for available modes. apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor. Default: False. @@ -104,6 +108,8 @@ """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_png) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation) return output @@ -153,8 +159,7 @@ def decode_jpeg( device: Union[str, torch.device] = "cpu", apply_exif_orientation: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Decode JPEG image(s) into 3 dimensional RGB or grayscale Tensor(s). + """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU or CUDA. The values of the output tensor are uint8 between 0 and 255. @@ -168,12 +173,9 @@ Args: input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing the raw bytes of the JPEG image. The tensor(s) must be on CPU, regardless of the ``device`` parameter. - mode (ImageReadMode): the read mode used for optionally - converting the image(s). The supported modes are: ``ImageReadMode.UNCHANGED``, - ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB`` - Default: ``ImageReadMode.UNCHANGED``. - See ``ImageReadMode`` class for more information on various - available modes. + mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB". + Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode` + for available modes. device (str or torch.device): The device on which the decoded image will be stored. If a cuda device is specified, the image will be decoded with `nvjpeg `_. This is only @@ -198,6 +200,8 @@ _log_api_usage_once(decode_jpeg) if isinstance(device, str): device = torch.device(device) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] if isinstance(input, list): if len(input) == 0: @@ -223,9 +227,7 @@ def encode_jpeg( input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75 ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Takes a (list of) input tensor(s) in CHW layout and returns a (list of) buffer(s) with the contents - of the corresponding JPEG file(s). + """Encode RGB tensor(s) into raw encoded JPEG bytes, on CPU or CUDA. .. note:: Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``. 
@@ -277,13 +279,13 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): def decode_image( - input: torch.Tensor, + input: Union[torch.Tensor, str], mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False, ) -> torch.Tensor: - """ - Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the - appropriate operation to decode the image into a Tensor. + """Decode an image into a uint8 tensor, from a path or from raw encoded bytes. + + Currently supported image formats are JPEG, PNG, GIF and WEBP. The values of the output tensor are in uint8 in [0, 255] for most cases. @@ -295,12 +297,12 @@ tensor. Args: - input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the - image. - mode (ImageReadMode): the read mode used for optionally converting the image. - Default: ``ImageReadMode.UNCHANGED``. - See ``ImageReadMode`` class for more information on various - available modes. Only applies to JPEG and PNG images. + input (Tensor or str or ``pathlib.Path``): The image to decode. If a + tensor is passed, it must be a one-dimensional uint8 tensor containing + the raw bytes of the image. Otherwise, this must be a path to the image file. + mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB". + Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode` + for available modes. apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor. Only applies to JPEG and PNG images. Default: False. @@ -309,6 +311,10 @@ """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_image) + if not isinstance(input, torch.Tensor): + input = read_file(str(input)) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation) return output @@ -318,30 +324,7 @@ def read_image( mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False, ) -> torch.Tensor: - """ - Reads a JPEG, PNG, WEBP, or GIF image into a Tensor. - - The values of the output tensor are in uint8 in [0, 255] for most cases. - - If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535] - (supported from torchvision ``0.21``. Since uint16 support is limited in - pytorch, we recommend calling - :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True`` - after this function to convert the decoded image into a uint8 or float - tensor. - - Args: - path (str or ``pathlib.Path``): path of the image. - mode (ImageReadMode): the read mode used for optionally converting the image. - Default: ``ImageReadMode.UNCHANGED``. - See ``ImageReadMode`` class for more information on various - available modes. Only applies to JPEG and PNG images. - apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor. - Only applies to JPEG and PNG images. Default: False. - - Returns: - output (Tensor[image_channels, image_height, image_width]) - """ + """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead.""" if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(read_image) data = read_file(path) @@ -380,15 +363,17 @@ def decode_webp( Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the WEBP image. - mode (ImageReadMode): The read mode used for optionally - converting the image color space. Default: ``ImageReadMode.UNCHANGED``. 
- Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB". + Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode` + for available modes. Returns: Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_webp) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] return torch.ops.image.decode_webp(input, mode.value) @@ -409,15 +394,17 @@ def _decode_avif( Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the AVIF image. - mode (ImageReadMode): The read mode used for optionally - converting the image color space. Default: ``ImageReadMode.UNCHANGED``. - Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB". + Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode` + for available modes. Returns: Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(_decode_avif) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] return torch.ops.image.decode_avif(input, mode.value) @@ -435,13 +422,15 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN Args: input (Tensor[1]): a one dimensional contiguous uint8 tensor containing the raw bytes of the HEIC image. - mode (ImageReadMode): The read mode used for optionally - converting the image color space. Default: ``ImageReadMode.UNCHANGED``. - Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB". + Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode` + for available modes. Returns: Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(_decode_heic) + if isinstance(mode, str): + mode = ImageReadMode[mode.upper()] return torch.ops.image.decode_heic(input, mode.value) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index c8f7d2ebde2..9f768ed555d 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -62,7 +62,20 @@ def write_video( audio_options: Optional[Dict[str, Any]] = None, ) -> None: """ - Writes a 4d tensor in [T, H, W, C] format in a video file + Writes a 4d tensor in [T, H, W, C] format into a video file. + + This function relies on PyAV (therefore, ultimately FFmpeg) to encode + videos; you can get more fine-grained control by referring to the other + options at your disposal within `the FFmpeg wiki + `_. + + .. warning:: + + In the near future, we intend to centralize PyTorch's video decoding + capabilities within the `torchcodec + `_ project. We encourage you to + try it out and share your feedback, as the torchvision video decoders + will eventually be deprecated. Args: filename (str): path where the video will be saved video_array (Tensor[T, H, W, C]): tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format fps (Number): video frames per second video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc. - options (Dict): dictionary containing options to be passed into the PyAV video stream + options (Dict): dictionary containing options to be passed into the PyAV video stream. 
+ The list of options is codec-dependent and can + be found in `the FFmpeg wiki `_. audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels and N is the number of samples audio_fps (Number): audio sample rate, typically 44100 or 48000 audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc. - audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream + audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream. + The list of options is codec-dependent and can + be found in `the FFmpeg wiki `_. + + Examples:: + >>> # Creating libx264 video with CRF 17, for visually lossless footage: + >>> + >>> from torchvision.io import write_video + >>> # 1000 frames of 100x100, 3-channel images. + >>> vid = torch.randint(0, 256, (1000, 100, 100, 3), dtype=torch.uint8) + >>> write_video("video.mp4", vid, fps=30, options={"crf": "17"}) + """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(write_video) @@ -243,6 +269,14 @@ def read_video( """ Reads a video from a file, returning both the video frames and the audio frames + .. warning:: + + In the near future, we intend to centralize PyTorch's video decoding + capabilities within the `torchcodec + `_ project. We encourage you to + try it out and share your feedback, as the torchvision video decoders + will eventually be deprecated. + Args: filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts. start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): @@ -367,6 +401,14 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[in """ List the video frames timestamps. + .. warning:: + + In the near future, we intend to centralize PyTorch's video decoding + capabilities within the `torchcodec + `_ project. We encourage you to + try it out and share your feedback, as the torchvision video decoders + will eventually be deprecated. + Note that the function decodes the whole video frame-by-frame. Args: diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py index 505909fd984..cf319fe288e 100644 --- a/torchvision/io/video_reader.py +++ b/torchvision/io/video_reader.py @@ -52,6 +52,14 @@ class VideoReader: backends: video_reader, pyav, and cuda. Backends can be set via `torchvision.set_video_backend` function. + .. warning:: + + In the near future, we intend to centralize PyTorch's video decoding + capabilities within the `torchcodec + `_ project. We encourage you to + try it out and share your feedback, as the torchvision video decoders + will eventually be deprecated. + .. betastatus:: VideoReader class Example: diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 33d83f1fe3f..2d66917b6ea 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -55,5 +55,6 @@ ) from ._temporal import UniformTemporalSubsample from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor +from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/utils.py b/torchvision/utils.py index 33ac826e5ce..b69edcb572e 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -604,9 +604,9 @@ def _parse_colors( f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
) elif not isinstance(colors, (tuple, str)): - raise ValueError("`colors` must be a tuple or a string, or a list thereof, but got {colors}.") + raise ValueError(f"`colors` must be a tuple or a string, or a list thereof, but got {colors}.") elif isinstance(colors, tuple) and len(colors) != 3: - raise ValueError("If passed as tuple, colors should be an RGB triplet, but got {colors}.") + raise ValueError(f"If passed as tuple, colors should be an RGB triplet, but got {colors}.") else: # colors specifies a single color for all objects colors = [colors] * num_objects
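A minimal sketch of the user-facing decoding API this diff introduces (path-like inputs and case-insensitive string modes for ``decode_image``); the asset path below is hypothetical:

.. code:: python

    from pathlib import Path

    import torch
    from torchvision.io import decode_image, read_file

    # decode_image now accepts str and pathlib.Path inputs directly; non-tensor
    # inputs are routed through read_file() before decoding (Path inputs are not
    # scriptable, per the new test_decode_image_path test).
    img = decode_image("assets/dog1.jpg", mode="RGB")
    img = decode_image(Path("assets/dog1.jpg"), mode="rgb")  # mode strings are upper-cased

    # Raw encoded bytes keep working as before.
    raw = read_file("assets/dog1.jpg")
    img = decode_image(raw, mode="RGB")

    assert img.dtype == torch.uint8 and img.shape[0] == 3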
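Likewise, a sketch of the ``transforms.v2`` helpers made public here (``check_type``, ``get_bounding_boxes``, ``has_all``, ``has_any``, ``query_chw``, ``query_size``), assuming the behavior pinned down by the new ``TestUtils`` tests:

.. code:: python

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    img = tv_tensors.Image(torch.randint(0, 256, (3, 32, 64), dtype=torch.uint8))
    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0, 0, 10, 10]]), format="XYXY", canvas_size=(32, 64)
    )

    # Both query helpers scan a flat list of inputs and raise ValueError
    # ("Found multiple ...") when the sizes of the inputs disagree.
    assert v2.query_size([img, boxes]) == (32, 64)
    assert v2.query_chw([img]) == (3, 32, 64)

    # check_type / has_any inspect the types present in a sample.
    assert v2.check_type(img, (tv_tensors.Image,))
    assert v2.has_any([img, boxes], tv_tensors.BoundingBoxes)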