From 84d8cf3c14af42c5a9c407cc6221a4be82a73d0f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 08:11:09 +0800 Subject: [PATCH 01/12] enable gpu load nifti Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 127 ++++++++++++++++++++++++++++++++++- monai/transforms/io/array.py | 2 + 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b4ae562911..20cd46994d 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -14,12 +14,15 @@ import glob import os import re +import gzip +import io import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any +import torch import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -41,8 +44,10 @@ import pydicom from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage + import cupy as cp + import kvikio - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") @@ -50,8 +55,10 @@ PILImage, has_pil = optional_import("PIL.Image") pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) + cp, has_cp = optional_import("cupy") + kvikio, has_kvikio = optional_import("kvikio") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] class ImageReader(ABC): @@ -1024,6 +1031,122 @@ def _get_array_data(self, img): """ return np.asanyarray(img.dataobj, order="C") + + +@require_pkg(pkg_name="nibabel") +@require_pkg(pkg_name="cupy") +@require_pkg(pkg_name="kvikio") +class NibabelGPUReader(NibabelReader): + + def _gds_load(self, file_path): + file_size = os.path.getsize(file_path) + image = cp.empty(file_size, dtype=cp.uint8) + with kvikio.CuFile(file_path, "r") as f: + f.read(image) + + if file_path.endswith(".gz"): + # for compressed data, have to tansfer to CPU to decompress + # and then transfer back to GPU. It is not efficient compared to .nii file + # but it's still faster than Nibabel's default reader. + # TODO: can benchmark more, it may no need to do this since we don't have to use .gz + # since it's waste times especially in training + compressed_data = cp.asnumpy(image) + with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: + decompressed_data = gz_file.read() + + file_size = len(decompressed_data) + image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) + + return image + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Nibabel image object or list of Nibabel image objects. + + Args: + data: file name or a list of file names to read. + + """ + img_: list[Nifti1Image] = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + img = self._gds_load(name) + img_.append(img) # type: ignore + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img): + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to present the output metadata. + + Args: + img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. + + """ + compatible_meta: dict = {} + img_array = [] + for i in ensure_tuple(img): + header = self._get_header(i) + data_offset = header.get_data_offset() + data_shape = header.get_data_shape() + data_dtype = header.get_data_dtype() + affine = header.get_best_affine() + meta = dict(header) + meta[MetaKeys.AFFINE] = affine + meta[MetaKeys.ORIGINAL_AFFINE] = affine + # TODO: as_closest_canonical + # TODO: correct_nifti_header_if_necessary + meta[MetaKeys.SPATIAL_SHAPE] = data_shape + # TODO: figure out why always RAS for NibabelReader ? + # meta[MetaKeys.SPACE] = SpaceKeys.RAS + + data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F") + # TODO: check channel + # if self.squeeze_non_spatial_dims: + img_array.append(data) + if self.channel_dim is None: # default to "no_channel" or -1 + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(header, compatible_meta) + + return self._stack_images(img_array, compatible_meta), compatible_meta + + def _get_header(self, img): + """ + Get the all the metadata of the image and convert to dict type. + + Args: + img: a Nibabel image object loaded from an image file. + + """ + header_bytes = cp.asnumpy(img[:348]) + header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes)) + # swap to little endian as PyTorch doesn't support big endian + try: + header = header.as_byteswapped("<") + except ValueError: + pass + return header + + def _stack_images(self, image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return torch.cat(image_list, axis=channel_dim) + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return torch.stack(image_list, dim=0) class NumpyReader(ImageReader): diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4e71870fc9..eb0a0b88d8 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -35,6 +35,7 @@ ImageReader, ITKReader, NibabelReader, + NibabelGPUReader, NrrdReader, NumpyReader, PILReader, @@ -69,6 +70,7 @@ "numpyreader": NumpyReader, "pilreader": PILReader, "nibabelreader": NibabelReader, + "nibabelgpureader": NibabelGPUReader, } From ca1cfb81d953459eed3f20620c9b9999c5c95cc8 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 08:35:06 +0800 Subject: [PATCH 02/12] fix issue Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 20cd46994d..54a0fd3b4c 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1069,7 +1069,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): data: file name or a list of file names to read. """ - img_: list[Nifti1Image] = [] + img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -1113,12 +1113,12 @@ def get_data(self, img): # if self.squeeze_non_spatial_dims: img_array.append(data) if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(header, compatible_meta) + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(meta, compatible_meta) return self._stack_images(img_array, compatible_meta), compatible_meta From d3551cc1d1a61f82e765ac353a21fbbb95322694 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 2 Nov 2024 00:35:38 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 54a0fd3b4c..fe89fd3921 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1031,7 +1031,7 @@ def _get_array_data(self, img): """ return np.asanyarray(img.dataobj, order="C") - + @require_pkg(pkg_name="nibabel") @require_pkg(pkg_name="cupy") @@ -1053,12 +1053,12 @@ def _gds_load(self, file_path): compressed_data = cp.asnumpy(image) with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: decompressed_data = gz_file.read() - + file_size = len(decompressed_data) image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) return image - + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): """ Read image data from specified file or files, it can read a list of images @@ -1078,7 +1078,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img = self._gds_load(name) img_.append(img) # type: ignore return img_ if len(filenames) > 1 else img_[0] - + def get_data(self, img): """ Extract data array and metadata from loaded image and return them. @@ -1121,7 +1121,7 @@ def get_data(self, img): _copy_compatible_dict(meta, compatible_meta) return self._stack_images(img_array, compatible_meta), compatible_meta - + def _get_header(self, img): """ Get the all the metadata of the image and convert to dict type. @@ -1138,7 +1138,7 @@ def _get_header(self, img): except ValueError: pass return header - + def _stack_images(self, image_list: list, meta_dict: dict): if len(image_list) <= 1: return image_list[0] From 01a21e055acf5f4431d89ce13347ad30b5d8be35 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:10:45 +0800 Subject: [PATCH 04/12] update loadimage Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 108 ++++++++++++++--------------------- monai/transforms/io/array.py | 11 +++- 2 files changed, 53 insertions(+), 66 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index fe89fd3921..b9bacc303b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any import torch - +from monai.data.meta_tensor import MetaTensor import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -1038,13 +1038,22 @@ def _get_array_data(self, img): @require_pkg(pkg_name="kvikio") class NibabelGPUReader(NibabelReader): - def _gds_load(self, file_path): - file_size = os.path.getsize(file_path) + def read(self, filename: PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Nibabel image object or list of Nibabel image objects. + + Args: + data: file name. + + """ + file_size = os.path.getsize(filename) image = cp.empty(file_size, dtype=cp.uint8) - with kvikio.CuFile(file_path, "r") as f: + with kvikio.CuFile(filename, "r") as f: f.read(image) - if file_path.endswith(".gz"): + if filename.endswith(".gz"): # for compressed data, have to tansfer to CPU to decompress # and then transfer back to GPU. It is not efficient compared to .nii file # but it's still faster than Nibabel's default reader. @@ -1056,29 +1065,8 @@ def _gds_load(self, file_path): file_size = len(decompressed_data) image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) - return image - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Nibabel image object or list of Nibabel image objects. - - Args: - data: file name or a list of file names to read. - - """ - img_ = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = self._gds_load(name) - img_.append(img) # type: ignore - return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img): """ Extract data array and metadata from loaded image and return them. @@ -1088,39 +1076,38 @@ def get_data(self, img): and the metadata of the first image is used to present the output metadata. Args: - img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. + img: a Nibabel image object loaded from an image file. """ - compatible_meta: dict = {} - img_array = [] - for i in ensure_tuple(img): - header = self._get_header(i) - data_offset = header.get_data_offset() - data_shape = header.get_data_shape() - data_dtype = header.get_data_dtype() - affine = header.get_best_affine() - meta = dict(header) - meta[MetaKeys.AFFINE] = affine - meta[MetaKeys.ORIGINAL_AFFINE] = affine - # TODO: as_closest_canonical - # TODO: correct_nifti_header_if_necessary - meta[MetaKeys.SPATIAL_SHAPE] = data_shape - # TODO: figure out why always RAS for NibabelReader ? - # meta[MetaKeys.SPACE] = SpaceKeys.RAS - - data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F") - # TODO: check channel - # if self.squeeze_non_spatial_dims: - img_array.append(data) - if self.channel_dim is None: # default to "no_channel" or -1 - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(meta, compatible_meta) - return self._stack_images(img_array, compatible_meta), compatible_meta + # TODO: use a formal way for device + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + header = self._get_header(img) + data_offset = header.get_data_offset() + data_shape = header.get_data_shape() + data_dtype = header.get_data_dtype() + affine = header.get_best_affine() + meta = dict(header) + meta[MetaKeys.AFFINE] = affine + meta[MetaKeys.ORIGINAL_AFFINE] = affine + # TODO: as_closest_canonical + # TODO: correct_nifti_header_if_necessary + meta[MetaKeys.SPATIAL_SHAPE] = data_shape + # TODO: figure out why always RAS for NibabelReader ? + # meta[MetaKeys.SPACE] = SpaceKeys.RAS + + data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F") + # TODO: check channel + # if self.squeeze_non_spatial_dims: + if self.channel_dim is None: # default to "no_channel" or -1 + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + + return MetaTensor(data, affine=affine, meta=meta, device=device) def _get_header(self, img): """ @@ -1139,15 +1126,6 @@ def _get_header(self, img): pass return header - def _stack_images(self, image_list: list, meta_dict: dict): - if len(image_list) <= 1: - return image_list[0] - if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): - channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) - return torch.cat(image_list, axis=channel_dim) - meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 - return torch.stack(image_list, dim=0) - class NumpyReader(ImageReader): """ diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index eb0a0b88d8..52f98ce8ee 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -258,6 +258,16 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img, err = None, [] if reader is not None: + if isinstance(reader, NibabelGPUReader): + buffer = reader.read(filename) + img = reader.get_data(buffer) + # TODO: check ensure channel first + if self.ensure_channel_first: + img = EnsureChannelFirst()(img) + if self.image_only: + return img + return img, img.meta + img = reader.read(filename) # runtime specified reader else: for reader in self.readers[::-1]: @@ -288,7 +298,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered: {self.readers}.\n{msg}" ) - img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] From be77a45be56f840fd096dd76fda235b50553deaa Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:13:02 +0800 Subject: [PATCH 05/12] add init Signed-off-by: Yiheng Wang --- monai/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 340c5eb8fa..14d0dfb193 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -50,7 +50,7 @@ from .folder_layout import FolderLayout, FolderLayoutBase from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, From b4a747ce096f2a691a69fedcd94a782588730de5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:47:17 +0800 Subject: [PATCH 06/12] update filename Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 52f98ce8ee..9012f2bb80 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -259,7 +259,8 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader img, err = None, [] if reader is not None: if isinstance(reader, NibabelGPUReader): - buffer = reader.read(filename) + # TODO: handle multiple filenames later + buffer = reader.read(filename[0]) img = reader.get_data(buffer) # TODO: check ensure channel first if self.ensure_channel_first: From f6af1202bd7ab913f5775219874c2e1e55974333 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:53:20 +0800 Subject: [PATCH 07/12] update supported reader Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 9012f2bb80..f465fe60a6 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -70,7 +70,6 @@ "numpyreader": NumpyReader, "pilreader": PILReader, "nibabelreader": NibabelReader, - "nibabelgpureader": NibabelGPUReader, } From 009fdf7d60d449e767ae1108cfdb54d3b35a72c5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:59:39 +0800 Subject: [PATCH 08/12] update load image call Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index f465fe60a6..4e2fdfcda8 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -277,6 +277,16 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader break else: # try the user designated readers try: + if isinstance(reader, NibabelGPUReader): + # TODO: handle multiple filenames later + buffer = reader.read(filename[0]) + img = reader.get_data(buffer) + # TODO: check ensure channel first + if self.ensure_channel_first: + img = EnsureChannelFirst()(img) + if self.image_only: + return img + return img, img.meta img = reader.read(filename) except Exception as e: err.append(traceback.format_exc()) From 27d218a1a15f65cc96448a1438f4668a0a6e2831 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 11:36:15 +0800 Subject: [PATCH 09/12] remove useless header Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b9bacc303b..68ef5420ae 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1088,7 +1088,7 @@ def get_data(self, img): data_shape = header.get_data_shape() data_dtype = header.get_data_dtype() affine = header.get_best_affine() - meta = dict(header) + meta = {} meta[MetaKeys.AFFINE] = affine meta[MetaKeys.ORIGINAL_AFFINE] = affine # TODO: as_closest_canonical From 1baa31b85fc887dd009f600c9db69964669c0dc7 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 12:03:04 +0800 Subject: [PATCH 10/12] add filename Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4e2fdfcda8..455e38ac08 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -261,6 +261,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader # TODO: handle multiple filenames later buffer = reader.read(filename[0]) img = reader.get_data(buffer) + img.meta[Key.FILENAME_OR_OBJ] = filename[0] # TODO: check ensure channel first if self.ensure_channel_first: img = EnsureChannelFirst()(img) @@ -281,6 +282,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader # TODO: handle multiple filenames later buffer = reader.read(filename[0]) img = reader.get_data(buffer) + img.meta[Key.FILENAME_OR_OBJ] = filename[0] # TODO: check ensure channel first if self.ensure_channel_first: img = EnsureChannelFirst()(img) From f4531588232449ad9231aff797e857a474f88397 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 8 Nov 2024 08:11:20 +0000 Subject: [PATCH 11/12] reformat to add gpu load support on nibabelreader Signed-off-by: Yiheng Wang --- monai/data/__init__.py | 2 +- monai/data/image_reader.py | 143 +++++++++++------------------------ monai/data/meta_tensor.py | 13 +++- monai/transforms/io/array.py | 31 ++------ 4 files changed, 59 insertions(+), 130 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 14d0dfb193..340c5eb8fa 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -50,7 +50,7 @@ from .folder_layout import FolderLayout, FolderLayoutBase from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 68ef5420ae..ae94fcc053 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -58,7 +58,7 @@ cp, has_cp = optional_import("cupy") kvikio, has_kvikio = optional_import("kvikio") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] class ImageReader(ABC): @@ -155,6 +155,17 @@ def _stack_images(image_list: list, meta_dict: dict): return np.stack(image_list, axis=0) +def _stack_gpu_images(image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return cp.concatenate(image_list, axis=channel_dim) + # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return cp.stack(image_list, axis=0) + + @require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ @@ -887,12 +898,15 @@ def __init__( channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, + gpu_load: bool = False, **kwargs, ): super().__init__() self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims + # TODO: add warning if not have required libs + self.gpu_load = gpu_load self.kwargs = kwargs def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: @@ -923,6 +937,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + self.filenames = filenames kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -946,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img_array: list[np.ndarray] = [] compatible_meta: dict = {} - for i in ensure_tuple(img): + for i, filename in zip(ensure_tuple(img), self.filenames): header = self._get_meta_dict(i) header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) @@ -956,7 +971,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) header[MetaKeys.SPACE] = SpaceKeys.RAS - data = self._get_array_data(i) + data = self._get_array_data(i, filename) if self.squeeze_non_spatial_dims: for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1): if data.shape[d - 1] == 1: @@ -969,7 +984,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) - + if self.gpu_load: + return _stack_gpu_images(img_array, compatible_meta), compatible_meta return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> dict: @@ -1022,7 +1038,7 @@ def _get_spatial_shape(self, img): spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) - def _get_array_data(self, img): + def _get_array_data(self, img, filename): """ Get the raw array data of the image, converted to Numpy array. @@ -1030,103 +1046,32 @@ def _get_array_data(self, img): img: a Nibabel image object loaded from an image file. """ + if self.gpu_load: + file_size = os.path.getsize(filename) + image = cp.empty(file_size, dtype=cp.uint8) + # suggestion from Ming: more tests, diff size + # cucim + nifti + with kvikio.CuFile(filename, "r") as f: + f.read(image) + if filename.endswith(".gz"): + # for compressed data, have to tansfer to CPU to decompress + # and then transfer back to GPU. It is not efficient compared to .nii file + # but it's still faster than Nibabel's default reader. + # TODO: can benchmark more, it may no need to do this since we don't have to use .gz + # since it's waste times especially in training + compressed_data = cp.asnumpy(image) + with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: + decompressed_data = gz_file.read() + + file_size = len(decompressed_data) + image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) + data_shape = img.shape + data_offset = img.dataobj.offset + data_dtype = img.dataobj.dtype + return image[data_offset:].view(data_dtype).reshape(data_shape, order="F") return np.asanyarray(img.dataobj, order="C") -@require_pkg(pkg_name="nibabel") -@require_pkg(pkg_name="cupy") -@require_pkg(pkg_name="kvikio") -class NibabelGPUReader(NibabelReader): - - def read(self, filename: PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Nibabel image object or list of Nibabel image objects. - - Args: - data: file name. - - """ - file_size = os.path.getsize(filename) - image = cp.empty(file_size, dtype=cp.uint8) - with kvikio.CuFile(filename, "r") as f: - f.read(image) - - if filename.endswith(".gz"): - # for compressed data, have to tansfer to CPU to decompress - # and then transfer back to GPU. It is not efficient compared to .nii file - # but it's still faster than Nibabel's default reader. - # TODO: can benchmark more, it may no need to do this since we don't have to use .gz - # since it's waste times especially in training - compressed_data = cp.asnumpy(image) - with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: - decompressed_data = gz_file.read() - - file_size = len(decompressed_data) - image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) - return image - - def get_data(self, img): - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to present the output metadata. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - - # TODO: use a formal way for device - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - - header = self._get_header(img) - data_offset = header.get_data_offset() - data_shape = header.get_data_shape() - data_dtype = header.get_data_dtype() - affine = header.get_best_affine() - meta = {} - meta[MetaKeys.AFFINE] = affine - meta[MetaKeys.ORIGINAL_AFFINE] = affine - # TODO: as_closest_canonical - # TODO: correct_nifti_header_if_necessary - meta[MetaKeys.SPATIAL_SHAPE] = data_shape - # TODO: figure out why always RAS for NibabelReader ? - # meta[MetaKeys.SPACE] = SpaceKeys.RAS - - data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F") - # TODO: check channel - # if self.squeeze_non_spatial_dims: - if self.channel_dim is None: # default to "no_channel" or -1 - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - - return MetaTensor(data, affine=affine, meta=meta, device=device) - - def _get_header(self, img): - """ - Get the all the metadata of the image and convert to dict type. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - header_bytes = cp.asnumpy(img[:348]) - header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes)) - # swap to little endian as PyTorch doesn't support big endian - try: - header = header.as_byteswapped("<") - except ValueError: - pass - return header - - class NumpyReader(ImageReader): """ Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects. diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index ac171e8508..959108eb47 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -532,7 +532,12 @@ def clone(self, **kwargs): @staticmethod def ensure_torch_and_prune_meta( - im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "." + im: NdarrayTensor, + meta: dict | None, + simple_keys: bool = False, + pattern: str | None = None, + sep: str = ".", + device: None | str | torch.device = None, ): """ Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary, @@ -547,13 +552,13 @@ def ensure_torch_and_prune_meta( sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``. + device: target device to put the Tensor data. Returns: By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. """ - img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray - + img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` if not isinstance(img, MetaTensor): return img @@ -565,7 +570,7 @@ def ensure_torch_and_prune_meta( if simple_keys: # ensure affine is of type `torch.Tensor` if MetaKeys.AFFINE in meta: - meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking + meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking remove_extra_metadata(meta) # bc-breaking if pattern is not None: diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 455e38ac08..2eb00ab38d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -35,7 +35,6 @@ ImageReader, ITKReader, NibabelReader, - NibabelGPUReader, NrrdReader, NumpyReader, PILReader, @@ -140,6 +139,7 @@ def __init__( prune_meta_pattern: str | None = None, prune_meta_sep: str = ".", expanduser: bool = True, + device: None | str | torch.device = None, *args, **kwargs, ) -> None: @@ -164,6 +164,7 @@ def __init__( e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is. args: additional parameters for reader if providing a reader name. + device: target device to put the loaded image. kwargs: additional parameters for reader if providing a reader name. Note: @@ -185,6 +186,7 @@ def __init__( self.pattern = prune_meta_pattern self.sep = prune_meta_sep self.expanduser = expanduser + self.device = device self.readers: list[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default @@ -257,18 +259,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img, err = None, [] if reader is not None: - if isinstance(reader, NibabelGPUReader): - # TODO: handle multiple filenames later - buffer = reader.read(filename[0]) - img = reader.get_data(buffer) - img.meta[Key.FILENAME_OR_OBJ] = filename[0] - # TODO: check ensure channel first - if self.ensure_channel_first: - img = EnsureChannelFirst()(img) - if self.image_only: - return img - return img, img.meta - img = reader.read(filename) # runtime specified reader else: for reader in self.readers[::-1]: @@ -278,17 +268,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader break else: # try the user designated readers try: - if isinstance(reader, NibabelGPUReader): - # TODO: handle multiple filenames later - buffer = reader.read(filename[0]) - img = reader.get_data(buffer) - img.meta[Key.FILENAME_OR_OBJ] = filename[0] - # TODO: check ensure channel first - if self.ensure_channel_first: - img = EnsureChannelFirst()(img) - if self.image_only: - return img - return img, img.meta img = reader.read(filename) except Exception as e: err.append(traceback.format_exc()) @@ -312,7 +291,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) - img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] + img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0] if not isinstance(meta_data, dict): raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.") # make sure all elements in metadata are little endian @@ -320,7 +299,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader img = MetaTensor.ensure_torch_and_prune_meta( - img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep + img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device ) if self.ensure_channel_first: img = EnsureChannelFirst()(img) From 8d8ba0ff710415f69297aed1efac27ef449d530f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 08:11:43 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index ae94fcc053..d602af2217 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -22,8 +22,6 @@ from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any -import torch -from monai.data.meta_tensor import MetaTensor import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern