Skip to content

Commit

Permalink
use splitting module
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Oct 1, 2024
1 parent 7006cd4 commit 855ac79
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/eva/core/data/splitting/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def random_split(
if train_ratio + val_ratio + (test_ratio or 0) != 1:
raise ValueError("The sum of the ratios must be equal to 1.")

np.random.seed(seed)
random_generator = np.random.default_rng(seed)
n_samples = len(samples)
indices = np.random.permutation(n_samples)
indices = random_generator.permutation(n_samples)

n_train = int(np.floor(train_ratio * n_samples))
n_val = n_samples - n_train if test_ratio == 0.0 else int(np.floor(val_ratio * n_samples)) or 1
Expand Down
30 changes: 14 additions & 16 deletions src/eva/vision/data/datasets/segmentation/lits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

import functools
import glob
import math
import os
from typing import Any, Callable, Dict, List, Literal, Tuple

import numpy as np
import torch
from torchvision import tv_tensors
from typing_extensions import override

from eva.core import utils
from eva.core.data import splitting
from eva.vision.data.datasets import _validators
from eva.vision.data.datasets.segmentation import base
from eva.vision.utils import io
Expand All @@ -23,18 +22,18 @@ class LiTS(base.ImageSegmentation):
Webpage: https://competitions.codalab.org/competitions/17094
"""

_train_frac: float = 0.7
_val_frac: float = 0.15
_test_frac: float = 0.15
_train_ratio: float = 0.7
_val_ratio: float = 0.15
_test_ratio: float = 0.15
"""Index ranges per split."""

_sample_every_n_slices: int | None = None
"""The amount of slices to sub-sample per 3D CT scan image."""

_expected_dataset_lengths: Dict[str | None, int] = {
"train": 39531,
"val": 11191,
"test": 7916,
"train": 38686,
"val": 11192,
"test": 8760,
None: 58638,
}
"""Dataset version and split to the expected size."""
Expand Down Expand Up @@ -67,8 +66,6 @@ def __init__(
self._root = root
self._split = split
self._seed = seed
self._random_generator = np.random.default_rng(seed=self._seed)

self._indices: List[Tuple[int, int]] = []

@property
Expand Down Expand Up @@ -164,13 +161,14 @@ def _create_indices(self) -> List[Tuple[int, int]]:

def _get_split_indices(self) -> List[int]:
"""Returns the sample indices for the specified dataset split."""
indices = self._random_generator.permutation(len(self._volume_files))
n_train = math.ceil(self._train_frac * len(self._volume_files))
n_val = math.ceil(self._val_frac * len(self._volume_files))
indices = list(range(len(self._volume_files)))
train_indices, val_indices, test_indices = splitting.random_split(
indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed
)
split_indices_dict = {
"train": indices[:n_train],
"val": indices[n_train : n_train + n_val],
"test": indices[n_train + n_val :],
"train": train_indices,
"val": val_indices,
"test": test_indices,
None: indices,
}
if self._split not in split_indices_dict:
Expand Down
11 changes: 6 additions & 5 deletions src/eva/vision/data/datasets/segmentation/lits_balanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class LiTSBalanced(lits.LiTS):
"""

_expected_dataset_lengths: Dict[str | None, int] = {
"train": 5602,
"val": 1516,
"test": 1258,
"train": 5514,
"val": 1332,
"test": 1530,
None: 8376,
}
"""Dataset version and split to the expected size."""
Expand Down Expand Up @@ -58,6 +58,7 @@ def _create_indices(self) -> List[Tuple[int, int]]:
"""
split_indices = set(self._get_split_indices())
indices: List[Tuple[int, int]] = []
random_generator = np.random.default_rng(seed=self._seed)

for sample_idx in range(len(self._volume_files)):
if sample_idx not in split_indices:
Expand All @@ -79,12 +80,12 @@ def _create_indices(self) -> List[Tuple[int, int]]:
n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum())
tumor_indices = list(np.where(liver_and_tumor_filter)[0])
tumor_indices = list(
self._random_generator.choice(tumor_indices, size=n_slice_samples, replace=False)
random_generator.choice(tumor_indices, size=n_slice_samples, replace=False)
)

liver_indices = list(np.where(liver_only_filter)[0])
liver_indices = list(
self._random_generator.choice(liver_indices, size=n_slice_samples, replace=False)
random_generator.choice(liver_indices, size=n_slice_samples, replace=False)
)

indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])
Expand Down

0 comments on commit 855ac79

Please sign in to comment.