diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index b5c242c1e64..6578f6c1c67 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -8,6 +8,7 @@ import numpy as np import torch from PIL import Image +from PIL.Image import Image as PILImage from torch import Tensor try: @@ -123,7 +124,7 @@ def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} -def to_tensor(pic) -> Tensor: +def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor: """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This function does not support torchscript.