Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable gpu load nifti #8188

Draft
wants to merge 13 commits into
base: dev
Choose a base branch
from
56 changes: 50 additions & 6 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
import glob
import os
import re
import gzip
import io
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

Expand All @@ -41,15 +42,19 @@
import pydicom
from nibabel.nifti1 import Nifti1Image
from PIL import Image as PILImage
import cupy as cp
import kvikio

has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True
has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True
else:
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
nib, has_nib = optional_import("nibabel")
Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image")
PILImage, has_pil = optional_import("PIL.Image")
pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)
cp, has_cp = optional_import("cupy")
kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]

Expand Down Expand Up @@ -148,6 +153,17 @@ def _stack_images(image_list: list, meta_dict: dict):
return np.stack(image_list, axis=0)


def _stack_gpu_images(image_list: list, meta_dict: dict):
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
return cp.concatenate(image_list, axis=channel_dim)
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
return cp.stack(image_list, axis=0)


@require_pkg(pkg_name="itk")
class ITKReader(ImageReader):
"""
Expand Down Expand Up @@ -880,12 +896,15 @@ def __init__(
channel_dim: str | int | None = None,
as_closest_canonical: bool = False,
squeeze_non_spatial_dims: bool = False,
gpu_load: bool = False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_gpu

**kwargs,
):
super().__init__()
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.as_closest_canonical = as_closest_canonical
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
# TODO: add warning if not have required libs
self.gpu_load = gpu_load
self.kwargs = kwargs

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
Expand Down Expand Up @@ -916,6 +935,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
self.filenames = filenames
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
Expand All @@ -939,7 +959,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img_array: list[np.ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
for i, filename in zip(ensure_tuple(img), self.filenames):
header = self._get_meta_dict(i)
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
Expand All @@ -949,7 +969,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
header[MetaKeys.SPACE] = SpaceKeys.RAS
data = self._get_array_data(i)
data = self._get_array_data(i, filename)
if self.squeeze_non_spatial_dims:
for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
if data.shape[d - 1] == 1:
Expand All @@ -962,7 +982,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
else:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

if self.gpu_load:
return _stack_gpu_images(img_array, compatible_meta), compatible_meta
return _stack_images(img_array, compatible_meta), compatible_meta

def _get_meta_dict(self, img) -> dict:
Expand Down Expand Up @@ -1015,14 +1036,37 @@ def _get_spatial_shape(self, img):
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

def _get_array_data(self, img):
def _get_array_data(self, img, filename):
"""
Get the raw array data of the image, converted to Numpy array.

Args:
img: a Nibabel image object loaded from an image file.

"""
if self.gpu_load:
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
# suggestion from Ming: more tests, diff size
# cucim + nifti
with kvikio.CuFile(filename, "r") as f:
f.read(image)
if filename.endswith(".gz"):
# for compressed data, have to tansfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# but it's still faster than Nibabel's default reader.
# TODO: can benchmark more, it may no need to do this since we don't have to use .gz
# since it's waste times especially in training
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

file_size = len(decompressed_data)
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
data_shape = img.shape
data_offset = img.dataobj.offset
data_dtype = img.dataobj.dtype
return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
return np.asanyarray(img.dataobj, order="C")


Expand Down
13 changes: 9 additions & 4 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,12 @@ def clone(self, **kwargs):

@staticmethod
def ensure_torch_and_prune_meta(
im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
im: NdarrayTensor,
meta: dict | None,
simple_keys: bool = False,
pattern: str | None = None,
sep: str = ".",
device: None | str | torch.device = None,
):
"""
Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
Expand All @@ -547,13 +552,13 @@ def ensure_torch_and_prune_meta(
sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.
device: target device to put the Tensor data.

Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray

img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray
# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
Expand All @@ -565,7 +570,7 @@ def ensure_torch_and_prune_meta(
if simple_keys:
# ensure affine is of type `torch.Tensor`
if MetaKeys.AFFINE in meta:
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking
remove_extra_metadata(meta) # bc-breaking

if pattern is not None:
Expand Down
8 changes: 5 additions & 3 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
prune_meta_pattern: str | None = None,
prune_meta_sep: str = ".",
expanduser: bool = True,
device: None | str | torch.device = None,
*args,
**kwargs,
) -> None:
Expand All @@ -163,6 +164,7 @@ def __init__(
e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.
args: additional parameters for reader if providing a reader name.
device: target device to put the loaded image.
kwargs: additional parameters for reader if providing a reader name.

Note:
Expand All @@ -184,6 +186,7 @@ def __init__(
self.pattern = prune_meta_pattern
self.sep = prune_meta_sep
self.expanduser = expanduser
self.device = device

self.readers: list[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
Expand Down Expand Up @@ -286,18 +289,17 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered: {self.readers}.\n{msg}"
)

img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0]
if not isinstance(meta_data, dict):
raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")

meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
img = MetaTensor.ensure_torch_and_prune_meta(
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device
)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
Expand Down
Loading