From 90239782c36ad0540addde68e22e76c31e8403b0 Mon Sep 17 00:00:00 2001 From: roman807 Date: Fri, 19 Apr 2024 08:33:49 +0200 Subject: [PATCH 1/3] Update lightning to `2.2.2` (#371) --- pdm.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pdm.lock b/pdm.lock index d97b9fee..dee63384 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "all", "typecheck", "lint", "vision", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:c10565d042a67f776991cc9989f4cc9aee39c75161829ad679c1b394d2bdb906" +content_hash = "sha256:ffa7b5a5665a4fd7d142fd7c8cf32c533a041f9823b79d24f24a228ec144a1f7" [[package]] name = "absl-py" @@ -760,7 +760,7 @@ files = [ [[package]] name = "lightning" -version = "2.2.1" +version = "2.3.0.dev20240407" requires_python = ">=3.8" summary = "The Deep Learning framework to train, deploy, and ship AI products Lightning fast." groups = ["default"] @@ -777,8 +777,8 @@ dependencies = [ "typing-extensions<6.0,>=4.4.0", ] files = [ - {file = "lightning-2.2.1-py3-none-any.whl", hash = "sha256:fec9b49d29a6019e8fe49e825082bab8d5ea3fde8e4b36dcf5c8896c2bdb86c3"}, - {file = "lightning-2.2.1.tar.gz", hash = "sha256:b3e46d596b32cafd1fb9b21fdba1b1767df97b1af5cc702693d1c51df60b19aa"}, + {file = "lightning-2.3.0.dev20240407-py3-none-any.whl", hash = "sha256:27fa1f37a5ab12b917590f833baeea3e02c3a979f49f9a14bb35e2fd0ae29cfd"}, + {file = "lightning-2.3.0.dev20240407.tar.gz", hash = "sha256:6aab115c1c22a75d359f79db3457d7488682900aa03426a9534ef1a535195310"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 71ea6fdd..59241a78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ maintainers = [ ] requires-python = ">=3.10" dependencies = [ - "lightning>=2.2.1", + "lightning>=2.2.2", "jsonargparse[omegaconf]>=4.27.4", "tensorboard>=2.16.2", "loguru>=0.7.2", From a421ee28746bf424a4a17ef444a91523622c2bea Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 22 Apr 2024 11:09:06 +0200 Subject: [PATCH 2/3] reformat image segmentation datasets --- configs/vision/dino_mvit2/offline/bach.yaml | 113 ++++++++++++++++ configs/vision/dino_mvit2/offline/crc.yaml | 111 ++++++++++++++++ configs/vision/dino_mvit2/offline/mhist.yaml | 108 +++++++++++++++ .../dino_mvit2/offline/patch_camelyon.yaml | 125 ++++++++++++++++++ .../vision/data/datasets/segmentation/base.py | 59 +++------ .../segmentation/total_segmentator.py | 33 ++--- src/eva/vision/utils/io/nifti.py | 2 +- .../segmentation/test_total_segmentator.py | 10 +- 8 files changed, 497 insertions(+), 64 deletions(-) create mode 100644 configs/vision/dino_mvit2/offline/bach.yaml create mode 100644 configs/vision/dino_mvit2/offline/crc.yaml create mode 100644 configs/vision/dino_mvit2/offline/mhist.yaml create mode 100644 configs/vision/dino_mvit2/offline/patch_camelyon.yaml diff --git a/configs/vision/dino_mvit2/offline/bach.yaml b/configs/vision/dino_mvit2/offline/bach.yaml new file mode 100644 index 00000000..4b74260c --- /dev/null +++ b/configs/vision/dino_mvit2/offline/bach.yaml @@ -0,0 +1,113 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/bach} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: 400 + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.EmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/bach + dataloader_idx_map: + 0: train + 1: val + backbone: + class_path: eva.models.ModelFromFunction + init_args: + path: timm.create_model + arguments: + model_name: ${oc.env:MODEL_NAME, mvitv2_small} + num_classes: 0 + pretrained: ${oc.env:PRETRAINED, false} + checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: torch.nn.Linear + init_args: + in_features: ${oc.env:IN_FEATURES, 768} + out_features: &NUM_CLASSES 4 + criterion: torch.nn.CrossEntropyLoss + optimizer: + class_path: torch.optim.SGD + init_args: + lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625} + momentum: 0.9 + weight_decay: 0.0 + lr_scheduler: + class_path: torch.optim.lr_scheduler.CosineAnnealingLR + init_args: + T_max: *MAX_STEPS + eta_min: 0.0 + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.MulticlassClassificationMetrics + init_args: + num_classes: *NUM_CLASSES +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: val + predict: + - class_path: eva.vision.datasets.BACH + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data}/bach + split: train + download: true + # Set `download: true` to download the dataset from https://zenodo.org/records/3632035 + # The BACH dataset is distributed under the following license + # Attribution-NonCommercial-NoDerivs 4.0 International license + # (see: https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) + image_transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.BACH + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + shuffle: true + val: + batch_size: *BATCH_SIZE + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} diff --git a/configs/vision/dino_mvit2/offline/crc.yaml b/configs/vision/dino_mvit2/offline/crc.yaml new file mode 100644 index 00000000..9bf3e9fc --- /dev/null +++ b/configs/vision/dino_mvit2/offline/crc.yaml @@ -0,0 +1,111 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/crc} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: 24 + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.EmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/crc + dataloader_idx_map: + 0: train + 1: val + backbone: + class_path: eva.models.ModelFromFunction + init_args: + path: timm.create_model + arguments: + model_name: ${oc.env:MODEL_NAME, mvitv2_small} + num_classes: 0 + pretrained: ${oc.env:PRETRAINED, false} + checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: torch.nn.Linear + init_args: + in_features: ${oc.env:IN_FEATURES, 768} + out_features: &NUM_CLASSES 9 + criterion: torch.nn.CrossEntropyLoss + optimizer: + class_path: torch.optim.SGD + init_args: + lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01} + momentum: 0.9 + weight_decay: 0.0 + lr_scheduler: + class_path: torch.optim.lr_scheduler.CosineAnnealingLR + init_args: + T_max: *MAX_STEPS + eta_min: 0.0 + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.MulticlassClassificationMetrics + init_args: + num_classes: *NUM_CLASSES +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: val + predict: + - class_path: eva.vision.datasets.CRC + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data}/crc + split: train + download: true + # Set `download: true` to download the dataset from https://zenodo.org/records/1214456 + # The CRC dataset is distributed under the following license: "CC BY 4.0 LEGAL CODE" + # (see: https://creativecommons.org/licenses/by/4.0/legalcode) + image_transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.CRC + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096} + shuffle: true + val: + batch_size: *BATCH_SIZE + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} diff --git a/configs/vision/dino_mvit2/offline/mhist.yaml b/configs/vision/dino_mvit2/offline/mhist.yaml new file mode 100644 index 00000000..76f769ab --- /dev/null +++ b/configs/vision/dino_mvit2/offline/mhist.yaml @@ -0,0 +1,108 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/mhist} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: 51 + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.EmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/mhist + dataloader_idx_map: + 0: train + 1: test + backbone: + class_path: eva.models.ModelFromFunction + init_args: + path: timm.create_model + arguments: + model_name: ${oc.env:MODEL_NAME, mvitv2_small} + num_classes: 0 + pretrained: ${oc.env:PRETRAINED, false} + checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: torch.nn.Linear + init_args: + in_features: ${oc.env:IN_FEATURES, 768} + out_features: 1 + criterion: torch.nn.BCEWithLogitsLoss + optimizer: + class_path: torch.optim.SGD + init_args: + lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625} + momentum: 0.9 + weight_decay: 0.0 + lr_scheduler: + class_path: torch.optim.lr_scheduler.CosineAnnealingLR + init_args: + T_max: *MAX_STEPS + eta_min: 0.0 + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.BinaryClassificationMetrics +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + target_transforms: + class_path: eva.core.data.transforms.ArrayToFloatTensor + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: test + predict: + - class_path: eva.vision.datasets.MHIST + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data}/mhist + split: train + image_transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.MHIST + init_args: + <<: *PREDICT_DATASET_ARGS + split: test + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} + shuffle: true + val: + batch_size: *BATCH_SIZE + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} diff --git a/configs/vision/dino_mvit2/offline/patch_camelyon.yaml b/configs/vision/dino_mvit2/offline/patch_camelyon.yaml new file mode 100644 index 00000000..47c0db11 --- /dev/null +++ b/configs/vision/dino_mvit2/offline/patch_camelyon.yaml @@ -0,0 +1,125 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/patch_camelyon} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: 9 + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.EmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/patch_camelyon + dataloader_idx_map: + 0: train + 1: val + 2: test + backbone: + class_path: eva.models.ModelFromFunction + init_args: + path: timm.create_model + arguments: + model_name: ${oc.env:MODEL_NAME, mvitv2_small} + num_classes: 0 + pretrained: ${oc.env:PRETRAINED, false} + checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.HeadModule + init_args: + head: + class_path: torch.nn.Linear + init_args: + in_features: ${oc.env:IN_FEATURES, 768} + out_features: 1 + criterion: torch.nn.BCEWithLogitsLoss + optimizer: + class_path: torch.optim.SGD + init_args: + lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01} + momentum: 0.9 + weight_decay: 0.0 + lr_scheduler: + class_path: torch.optim.lr_scheduler.CosineAnnealingLR + init_args: + T_max: *MAX_STEPS + eta_min: 0.0 + metrics: + common: + - class_path: eva.metrics.AverageLoss + - class_path: eva.metrics.BinaryClassificationMetrics +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + target_transforms: + class_path: eva.core.data.transforms.ArrayToFloatTensor + val: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: val + test: + class_path: eva.datasets.EmbeddingsClassificationDataset + init_args: + <<: *DATASET_ARGS + split: test + predict: + - class_path: eva.vision.datasets.PatchCamelyon + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data}/patch_camelyon + split: train + download: true + # Set `download: true` to download the dataset from https://zenodo.org/records/1494286 + # The PatchCamelyon dataset is distributed under the following license: + # "Creative Commons Zero v1.0 Universal" + # (see: https://choosealicense.com/licenses/cc0-1.0/) + image_transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.PatchCamelyon + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + - class_path: eva.vision.datasets.PatchCamelyon + init_args: + <<: *PREDICT_DATASET_ARGS + split: test + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096} + shuffle: true + val: + batch_size: *BATCH_SIZE + test: + batch_size: *BATCH_SIZE + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} diff --git a/src/eva/vision/data/datasets/segmentation/base.py b/src/eva/vision/data/datasets/segmentation/base.py index f8ebaf64..d6fd5264 100644 --- a/src/eva/vision/data/datasets/segmentation/base.py +++ b/src/eva/vision/data/datasets/segmentation/base.py @@ -3,38 +3,28 @@ import abc from typing import Any, Callable, Dict, List, Tuple -import numpy as np +from torchvision import tv_tensors from typing_extensions import override from eva.vision.data.datasets import vision -class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC): +class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC): """Image segmentation abstract dataset.""" def __init__( self, - image_transforms: Callable | None = None, - target_transforms: Callable | None = None, - image_target_transforms: Callable | None = None, + transforms: Callable | None = None, ) -> None: """Initializes the image segmentation base class. Args: - image_transforms: A function/transform that takes in an image - and returns a transformed version. - target_transforms: A function/transform that takes in the target - and transforms it. - image_target_transforms: A function/transforms that takes in an + transforms: A function/transforms that takes in an image and a label and returns the transformed versions of both. - This transform happens after the `image_transforms` and - `target_transforms`. """ super().__init__() - self._image_transforms = image_transforms - self._target_transforms = target_transforms - self._image_target_transforms = image_target_transforms + self._transforms = transforms @property def classes(self) -> List[str] | None: @@ -56,25 +46,26 @@ def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, An """ @abc.abstractmethod - def load_image(self, index: int) -> np.ndarray: + def load_image(self, index: int) -> tv_tensors.Image: """Loads and returns the `index`'th image sample. Args: index: The index of the data sample to load. Returns: - The image as a numpy array. + An image torchvision tensor (channels, height, width). """ @abc.abstractmethod - def load_mask(self, index: int) -> np.ndarray: - """Returns the `index`'th target mask sample. + def load_masks(self, index: int) -> tv_tensors.Mask: + """Returns the `index`'th target masks sample. Args: - index: The index of the data sample target mask to load. + index: The index of the data sample target masks to load. Returns: - The sample mask as a stack of binary mask arrays (label, height, width). + The sample masks as a stack of binary torchvision mask + tensors (label, height, width). """ @abc.abstractmethod @@ -83,30 +74,24 @@ def __len__(self) -> int: raise NotImplementedError @override - def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask]: image = self.load_image(index) - mask = self.load_mask(index) - return self._apply_transforms(image, mask) + masks = self.load_masks(index) + return self._apply_transforms(image, masks) def _apply_transforms( - self, image: np.ndarray, target: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + self, image: tv_tensors.Image, masks: tv_tensors.Mask + ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]: """Applies the transforms to the provided data and returns them. Args: image: The desired image. - target: The target of the image. + masks: The target masks of the image. Returns: - A tuple with the image and the target transformed. + A tuple with the image and the masks transformed. """ - if self._image_transforms is not None: - image = self._image_transforms(image) + if self._transforms is not None: + image, masks = self._transforms(image, masks) - if self._target_transforms is not None: - target = self._target_transforms(target) - - if self._image_target_transforms is not None: - image, target = self._image_target_transforms(image, target) - - return image, target + return image, masks diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 7b291b9f..261479a5 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, List, Literal, Tuple import numpy as np +from torchvision import tv_tensors from torchvision.datasets import utils from typing_extensions import override @@ -50,9 +51,7 @@ def __init__( split: Literal["train", "val"] | None, version: Literal["small", "full"] = "small", download: bool = False, - image_transforms: Callable | None = None, - target_transforms: Callable | None = None, - image_target_transforms: Callable | None = None, + transforms: Callable | None = None, ) -> None: """Initialize dataset. @@ -65,20 +64,10 @@ def __init__( Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not exist yet on disk. - image_transforms: A function/transform that takes in an image - and returns a transformed version. - target_transforms: A function/transform that takes in the target - and transforms it. - image_target_transforms: A function/transforms that takes in an - image and a label and returns the transformed versions of both. - This transform happens after the `image_transforms` and - `target_transforms`. + transforms: A function/transforms that takes in an image and a target + mask and returns the transformed versions of both. """ - super().__init__( - image_transforms=image_transforms, - target_transforms=target_transforms, - image_target_transforms=image_target_transforms, - ) + super().__init__(transforms=transforms) self._root = root self._split = split @@ -134,19 +123,21 @@ def __len__(self) -> int: return len(self._indices) * self._n_slices_per_image @override - def load_image(self, index: int) -> np.ndarray: + def load_image(self, index: int) -> tv_tensors.Image: image_path = self._get_image_path(index) slice_index = self._get_sample_slice_index(index) image_array = io.read_nifti_slice(image_path, slice_index) - return image_array.repeat(3, axis=2) + image_rgb_array = image_array.repeat(3, axis=2) + return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1)) @override - def load_mask(self, index: int) -> np.ndarray: + def load_masks(self, index: int) -> tv_tensors.Mask: masks_dir = self._get_masks_dir(index) slice_index = self._get_sample_slice_index(index) mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) - masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths] - return np.concatenate(masks, axis=-1) + list_of_mask_arrays = [io.read_nifti_slice(path, slice_index) for path in mask_paths] + masks = np.concatenate(list_of_mask_arrays, axis=-1) + return tv_tensors.Mask(masks.transpose(2, 0, 1)) def _get_masks_dir(self, index: int) -> str: """Returns the directory of the corresponding masks.""" diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index f4884079..7fc771fb 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -17,7 +17,7 @@ def read_nifti_slice(path: str, slice_index: int) -> npt.NDArray[Any]: return the full 3D image. Returns: - The image as a numpy array. + The image as a numpy array (height, width, channels). Raises: FileExistsError: If the path does not exist or it is unreachable. diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py index eb40b94f..2e8f3abe 100644 --- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py +++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py @@ -3,8 +3,8 @@ import os from typing import Literal -import numpy as np import pytest +from torchvision import tv_tensors from eva.vision.data import datasets @@ -35,10 +35,10 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i assert len(sample) == 2 # assert the format of the `image` and `mask` image, mask = sample - assert isinstance(image, np.ndarray) - assert image.shape == (16, 16, 3) - assert isinstance(mask, np.ndarray) - assert mask.shape == (16, 16, 3) + assert isinstance(image, tv_tensors.Image) + assert image.shape == (3, 16, 16) + assert isinstance(mask, tv_tensors.Mask) + assert mask.shape == (3, 16, 16) @pytest.fixture(scope="function") From bc57f9e4e2dddbd621cc4e2613b4f5a6edc2c61b Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 22 Apr 2024 11:11:24 +0200 Subject: [PATCH 3/3] remove dev files --- configs/vision/dino_mvit2/offline/bach.yaml | 113 ---------------- configs/vision/dino_mvit2/offline/crc.yaml | 111 ---------------- configs/vision/dino_mvit2/offline/mhist.yaml | 108 --------------- .../dino_mvit2/offline/patch_camelyon.yaml | 125 ------------------ 4 files changed, 457 deletions(-) delete mode 100644 configs/vision/dino_mvit2/offline/bach.yaml delete mode 100644 configs/vision/dino_mvit2/offline/crc.yaml delete mode 100644 configs/vision/dino_mvit2/offline/mhist.yaml delete mode 100644 configs/vision/dino_mvit2/offline/patch_camelyon.yaml diff --git a/configs/vision/dino_mvit2/offline/bach.yaml b/configs/vision/dino_mvit2/offline/bach.yaml deleted file mode 100644 index 4b74260c..00000000 --- a/configs/vision/dino_mvit2/offline/bach.yaml +++ /dev/null @@ -1,113 +0,0 @@ ---- -trainer: - class_path: eva.Trainer - init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/bach} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - filename: best - save_last: true - save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - min_delta: 0 - patience: 400 - monitor: *MONITOR_METRIC - mode: *MONITOR_METRIC_MODE - - class_path: eva.callbacks.EmbeddingsWriter - init_args: - output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/bach - dataloader_idx_map: - 0: train - 1: val - backbone: - class_path: eva.models.ModelFromFunction - init_args: - path: timm.create_model - arguments: - model_name: ${oc.env:MODEL_NAME, mvitv2_small} - num_classes: 0 - pretrained: ${oc.env:PRETRAINED, false} - checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} - logger: - - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: *OUTPUT_ROOT - name: "" -model: - class_path: eva.HeadModule - init_args: - head: - class_path: torch.nn.Linear - init_args: - in_features: ${oc.env:IN_FEATURES, 768} - out_features: &NUM_CLASSES 4 - criterion: torch.nn.CrossEntropyLoss - optimizer: - class_path: torch.optim.SGD - init_args: - lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625} - momentum: 0.9 - weight_decay: 0.0 - lr_scheduler: - class_path: torch.optim.lr_scheduler.CosineAnnealingLR - init_args: - T_max: *MAX_STEPS - eta_min: 0.0 - metrics: - common: - - class_path: eva.metrics.AverageLoss - - class_path: eva.metrics.MulticlassClassificationMetrics - init_args: - num_classes: *NUM_CLASSES -data: - class_path: eva.DataModule - init_args: - datasets: - train: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: &DATASET_ARGS - root: *DATASET_EMBEDDINGS_ROOT - manifest_file: manifest.csv - split: train - val: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: - <<: *DATASET_ARGS - split: val - predict: - - class_path: eva.vision.datasets.BACH - init_args: &PREDICT_DATASET_ARGS - root: ${oc.env:DATA_ROOT, ./data}/bach - split: train - download: true - # Set `download: true` to download the dataset from https://zenodo.org/records/3632035 - # The BACH dataset is distributed under the following license - # Attribution-NonCommercial-NoDerivs 4.0 International license - # (see: https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) - image_transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop - init_args: - size: ${oc.env:RESIZE_DIM, 224} - mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} - std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - - class_path: eva.vision.datasets.BACH - init_args: - <<: *PREDICT_DATASET_ARGS - split: val - dataloaders: - train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} - shuffle: true - val: - batch_size: *BATCH_SIZE - predict: - batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} diff --git a/configs/vision/dino_mvit2/offline/crc.yaml b/configs/vision/dino_mvit2/offline/crc.yaml deleted file mode 100644 index 9bf3e9fc..00000000 --- a/configs/vision/dino_mvit2/offline/crc.yaml +++ /dev/null @@ -1,111 +0,0 @@ ---- -trainer: - class_path: eva.Trainer - init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/crc} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - filename: best - save_last: true - save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - min_delta: 0 - patience: 24 - monitor: *MONITOR_METRIC - mode: *MONITOR_METRIC_MODE - - class_path: eva.callbacks.EmbeddingsWriter - init_args: - output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/crc - dataloader_idx_map: - 0: train - 1: val - backbone: - class_path: eva.models.ModelFromFunction - init_args: - path: timm.create_model - arguments: - model_name: ${oc.env:MODEL_NAME, mvitv2_small} - num_classes: 0 - pretrained: ${oc.env:PRETRAINED, false} - checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} - logger: - - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: *OUTPUT_ROOT - name: "" -model: - class_path: eva.HeadModule - init_args: - head: - class_path: torch.nn.Linear - init_args: - in_features: ${oc.env:IN_FEATURES, 768} - out_features: &NUM_CLASSES 9 - criterion: torch.nn.CrossEntropyLoss - optimizer: - class_path: torch.optim.SGD - init_args: - lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01} - momentum: 0.9 - weight_decay: 0.0 - lr_scheduler: - class_path: torch.optim.lr_scheduler.CosineAnnealingLR - init_args: - T_max: *MAX_STEPS - eta_min: 0.0 - metrics: - common: - - class_path: eva.metrics.AverageLoss - - class_path: eva.metrics.MulticlassClassificationMetrics - init_args: - num_classes: *NUM_CLASSES -data: - class_path: eva.DataModule - init_args: - datasets: - train: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: &DATASET_ARGS - root: *DATASET_EMBEDDINGS_ROOT - manifest_file: manifest.csv - split: train - val: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: - <<: *DATASET_ARGS - split: val - predict: - - class_path: eva.vision.datasets.CRC - init_args: &PREDICT_DATASET_ARGS - root: ${oc.env:DATA_ROOT, ./data}/crc - split: train - download: true - # Set `download: true` to download the dataset from https://zenodo.org/records/1214456 - # The CRC dataset is distributed under the following license: "CC BY 4.0 LEGAL CODE" - # (see: https://creativecommons.org/licenses/by/4.0/legalcode) - image_transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop - init_args: - mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} - std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - - class_path: eva.vision.datasets.CRC - init_args: - <<: *PREDICT_DATASET_ARGS - split: val - dataloaders: - train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096} - shuffle: true - val: - batch_size: *BATCH_SIZE - predict: - batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} diff --git a/configs/vision/dino_mvit2/offline/mhist.yaml b/configs/vision/dino_mvit2/offline/mhist.yaml deleted file mode 100644 index 76f769ab..00000000 --- a/configs/vision/dino_mvit2/offline/mhist.yaml +++ /dev/null @@ -1,108 +0,0 @@ ---- -trainer: - class_path: eva.Trainer - init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/mhist} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - filename: best - save_last: true - save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - min_delta: 0 - patience: 51 - monitor: *MONITOR_METRIC - mode: *MONITOR_METRIC_MODE - - class_path: eva.callbacks.EmbeddingsWriter - init_args: - output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/mhist - dataloader_idx_map: - 0: train - 1: test - backbone: - class_path: eva.models.ModelFromFunction - init_args: - path: timm.create_model - arguments: - model_name: ${oc.env:MODEL_NAME, mvitv2_small} - num_classes: 0 - pretrained: ${oc.env:PRETRAINED, false} - checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} - logger: - - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: *OUTPUT_ROOT - name: "" -model: - class_path: eva.HeadModule - init_args: - head: - class_path: torch.nn.Linear - init_args: - in_features: ${oc.env:IN_FEATURES, 768} - out_features: 1 - criterion: torch.nn.BCEWithLogitsLoss - optimizer: - class_path: torch.optim.SGD - init_args: - lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625} - momentum: 0.9 - weight_decay: 0.0 - lr_scheduler: - class_path: torch.optim.lr_scheduler.CosineAnnealingLR - init_args: - T_max: *MAX_STEPS - eta_min: 0.0 - metrics: - common: - - class_path: eva.metrics.AverageLoss - - class_path: eva.metrics.BinaryClassificationMetrics -data: - class_path: eva.DataModule - init_args: - datasets: - train: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: &DATASET_ARGS - root: *DATASET_EMBEDDINGS_ROOT - manifest_file: manifest.csv - split: train - target_transforms: - class_path: eva.core.data.transforms.ArrayToFloatTensor - val: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: - <<: *DATASET_ARGS - split: test - predict: - - class_path: eva.vision.datasets.MHIST - init_args: &PREDICT_DATASET_ARGS - root: ${oc.env:DATA_ROOT, ./data}/mhist - split: train - image_transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop - init_args: - size: ${oc.env:RESIZE_DIM, 224} - mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} - std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - - class_path: eva.vision.datasets.MHIST - init_args: - <<: *PREDICT_DATASET_ARGS - split: test - dataloaders: - train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} - shuffle: true - val: - batch_size: *BATCH_SIZE - predict: - batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} diff --git a/configs/vision/dino_mvit2/offline/patch_camelyon.yaml b/configs/vision/dino_mvit2/offline/patch_camelyon.yaml deleted file mode 100644 index 47c0db11..00000000 --- a/configs/vision/dino_mvit2/offline/patch_camelyon.yaml +++ /dev/null @@ -1,125 +0,0 @@ ---- -trainer: - class_path: eva.Trainer - init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/patch_camelyon} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - filename: best - save_last: true - save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - min_delta: 0 - patience: 9 - monitor: *MONITOR_METRIC - mode: *MONITOR_METRIC_MODE - - class_path: eva.callbacks.EmbeddingsWriter - init_args: - output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/patch_camelyon - dataloader_idx_map: - 0: train - 1: val - 2: test - backbone: - class_path: eva.models.ModelFromFunction - init_args: - path: timm.create_model - arguments: - model_name: ${oc.env:MODEL_NAME, mvitv2_small} - num_classes: 0 - pretrained: ${oc.env:PRETRAINED, false} - checkpoint_path: ${oc.env:CHECKPOINT_PATH, null} - logger: - - class_path: lightning.pytorch.loggers.TensorBoardLogger - init_args: - save_dir: *OUTPUT_ROOT - name: "" -model: - class_path: eva.HeadModule - init_args: - head: - class_path: torch.nn.Linear - init_args: - in_features: ${oc.env:IN_FEATURES, 768} - out_features: 1 - criterion: torch.nn.BCEWithLogitsLoss - optimizer: - class_path: torch.optim.SGD - init_args: - lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01} - momentum: 0.9 - weight_decay: 0.0 - lr_scheduler: - class_path: torch.optim.lr_scheduler.CosineAnnealingLR - init_args: - T_max: *MAX_STEPS - eta_min: 0.0 - metrics: - common: - - class_path: eva.metrics.AverageLoss - - class_path: eva.metrics.BinaryClassificationMetrics -data: - class_path: eva.DataModule - init_args: - datasets: - train: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: &DATASET_ARGS - root: *DATASET_EMBEDDINGS_ROOT - manifest_file: manifest.csv - split: train - target_transforms: - class_path: eva.core.data.transforms.ArrayToFloatTensor - val: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: - <<: *DATASET_ARGS - split: val - test: - class_path: eva.datasets.EmbeddingsClassificationDataset - init_args: - <<: *DATASET_ARGS - split: test - predict: - - class_path: eva.vision.datasets.PatchCamelyon - init_args: &PREDICT_DATASET_ARGS - root: ${oc.env:DATA_ROOT, ./data}/patch_camelyon - split: train - download: true - # Set `download: true` to download the dataset from https://zenodo.org/records/1494286 - # The PatchCamelyon dataset is distributed under the following license: - # "Creative Commons Zero v1.0 Universal" - # (see: https://choosealicense.com/licenses/cc0-1.0/) - image_transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop - init_args: - size: ${oc.env:RESIZE_DIM, 224} - mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} - std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - - class_path: eva.vision.datasets.PatchCamelyon - init_args: - <<: *PREDICT_DATASET_ARGS - split: val - - class_path: eva.vision.datasets.PatchCamelyon - init_args: - <<: *PREDICT_DATASET_ARGS - split: test - dataloaders: - train: - batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096} - shuffle: true - val: - batch_size: *BATCH_SIZE - test: - batch_size: *BATCH_SIZE - predict: - batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}