diff --git a/src/eva/core/data/splitting/random.py b/src/eva/core/data/splitting/random.py index 2a471c04..274a1412 100644 --- a/src/eva/core/data/splitting/random.py +++ b/src/eva/core/data/splitting/random.py @@ -27,9 +27,9 @@ def random_split( if train_ratio + val_ratio + (test_ratio or 0) != 1: raise ValueError("The sum of the ratios must be equal to 1.") - np.random.seed(seed) + random_generator = np.random.default_rng(seed) n_samples = len(samples) - indices = np.random.permutation(n_samples) + indices = random_generator.permutation(n_samples) n_train = int(np.floor(train_ratio * n_samples)) n_val = n_samples - n_train if test_ratio == 0.0 else int(np.floor(val_ratio * n_samples)) or 1 diff --git a/src/eva/vision/data/datasets/segmentation/lits.py b/src/eva/vision/data/datasets/segmentation/lits.py index 74f8a9ae..504f4f0f 100644 --- a/src/eva/vision/data/datasets/segmentation/lits.py +++ b/src/eva/vision/data/datasets/segmentation/lits.py @@ -2,16 +2,15 @@ import functools import glob -import math import os from typing import Any, Callable, Dict, List, Literal, Tuple -import numpy as np import torch from torchvision import tv_tensors from typing_extensions import override from eva.core import utils +from eva.core.data import splitting from eva.vision.data.datasets import _validators from eva.vision.data.datasets.segmentation import base from eva.vision.utils import io @@ -23,18 +22,18 @@ class LiTS(base.ImageSegmentation): Webpage: https://competitions.codalab.org/competitions/17094 """ - _train_frac: float = 0.7 - _val_frac: float = 0.15 - _test_frac: float = 0.15 + _train_ratio: float = 0.7 + _val_ratio: float = 0.15 + _test_ratio: float = 0.15 """Index ranges per split.""" _sample_every_n_slices: int | None = None """The amount of slices to sub-sample per 3D CT scan image.""" _expected_dataset_lengths: Dict[str | None, int] = { - "train": 39531, - "val": 11191, - "test": 7916, + "train": 38686, + "val": 11192, + "test": 8760, None: 58638, } """Dataset version and split to the expected size.""" @@ -67,8 +66,6 @@ def __init__( self._root = root self._split = split self._seed = seed - self._random_generator = np.random.default_rng(seed=self._seed) - self._indices: List[Tuple[int, int]] = [] @property @@ -164,13 +161,14 @@ def _create_indices(self) -> List[Tuple[int, int]]: def _get_split_indices(self) -> List[int]: """Returns the sample indices for the specified dataset split.""" - indices = self._random_generator.permutation(len(self._volume_files)) - n_train = math.ceil(self._train_frac * len(self._volume_files)) - n_val = math.ceil(self._val_frac * len(self._volume_files)) + indices = list(range(len(self._volume_files))) + train_indices, val_indices, test_indices = splitting.random_split( + indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed + ) split_indices_dict = { - "train": indices[:n_train], - "val": indices[n_train : n_train + n_val], - "test": indices[n_train + n_val :], + "train": train_indices, + "val": val_indices, + "test": test_indices, None: indices, } if self._split not in split_indices_dict: diff --git a/src/eva/vision/data/datasets/segmentation/lits_balanced.py b/src/eva/vision/data/datasets/segmentation/lits_balanced.py index 467c6366..43227252 100644 --- a/src/eva/vision/data/datasets/segmentation/lits_balanced.py +++ b/src/eva/vision/data/datasets/segmentation/lits_balanced.py @@ -21,9 +21,9 @@ class LiTSBalanced(lits.LiTS): """ _expected_dataset_lengths: Dict[str | None, int] = { - "train": 5602, - "val": 1516, - "test": 1258, + "train": 5514, + "val": 1332, + "test": 1530, None: 8376, } """Dataset version and split to the expected size.""" @@ -58,6 +58,7 @@ def _create_indices(self) -> List[Tuple[int, int]]: """ split_indices = set(self._get_split_indices()) indices: List[Tuple[int, int]] = [] + random_generator = np.random.default_rng(seed=self._seed) for sample_idx in range(len(self._volume_files)): if sample_idx not in split_indices: @@ -79,12 +80,12 @@ def _create_indices(self) -> List[Tuple[int, int]]: n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum()) tumor_indices = list(np.where(liver_and_tumor_filter)[0]) tumor_indices = list( - self._random_generator.choice(tumor_indices, size=n_slice_samples, replace=False) + random_generator.choice(tumor_indices, size=n_slice_samples, replace=False) ) liver_indices = list(np.where(liver_only_filter)[0]) liver_indices = list( - self._random_generator.choice(liver_indices, size=n_slice_samples, replace=False) + random_generator.choice(liver_indices, size=n_slice_samples, replace=False) ) indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])