Skip to content

Commit

Permalink
add option save coords file to wsi dataset classes
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Oct 9, 2024
1 parent a5c8f18 commit c41872d
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ data:
height: 224
target_mpp: 0.25
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ data:
height: 224
target_mpp: 0.25
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
1 change: 1 addition & 0 deletions configs/vision/pathology/offline/classification/panda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ data:
height: 224
target_mpp: 0.5
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ data:
height: 224
target_mpp: 0.5
split: train
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
image_transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
init_args:
Expand Down
7 changes: 3 additions & 4 deletions src/eva/core/callbacks/writers/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,14 @@ def _get_item_metadata(

def _check_if_exists(self) -> None:
"""Checks if the output directory already exists and if it should be overwritten."""
try:
os.makedirs(self._output_dir, exist_ok=self._overwrite)
except FileExistsError as e:
os.makedirs(self._output_dir, exist_ok=True)
if os.path.exists(os.path.join(self._output_dir, "manifest.csv")) and not self._overwrite:
raise FileExistsError(
f"The embeddings output directory already exists: {self._output_dir}. This "
"either means that they have been computed before or that a wrong output "
"directory is being used. Consider using `eva fit` instead, selecting a "
"different output directory or setting overwrite=True."
) from e
)
os.makedirs(self._output_dir, exist_ok=True)


Expand Down
7 changes: 4 additions & 3 deletions src/eva/vision/data/datasets/classification/camelyon16.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
target_mpp: float = 0.5,
backend: str = "openslide",
image_transforms: Callable | None = None,
coords_path: str | None = None,
seed: int = 42,
) -> None:
"""Initializes the dataset.
Expand All @@ -100,6 +101,7 @@ def __init__(
target_mpp: Target microns per pixel (mpp) for the patches.
backend: The backend to use for reading the whole-slide images.
image_transforms: Transforms to apply to the extracted image patches.
coords_path: File path to save the patch coordinates as .csv.
seed: Random seed for reproducibility.
"""
self._split = split
Expand All @@ -119,6 +121,7 @@ def __init__(
target_mpp=target_mpp,
backend=backend,
image_transforms=image_transforms,
coords_path=coords_path,
)

@property
Expand Down Expand Up @@ -207,9 +210,7 @@ def load_target(self, index: int) -> torch.Tensor:

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata
return wsi.MultiWsiDataset.load_metadata(self, index)

def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
"""Loads the file paths of the corresponding dataset split."""
Expand Down
7 changes: 4 additions & 3 deletions src/eva/vision/data/datasets/classification/panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
target_mpp: float = 0.5,
backend: str = "openslide",
image_transforms: Callable | None = None,
coords_path: str | None = None,
seed: int = 42,
) -> None:
"""Initializes the dataset.
Expand All @@ -62,6 +63,7 @@ def __init__(
target_mpp: Target microns per pixel (mpp) for the patches.
backend: The backend to use for reading the whole-slide images.
image_transforms: Transforms to apply to the extracted image patches.
coords_path: File path to save the patch coordinates as .csv.
seed: Random seed for reproducibility.
"""
self._split = split
Expand All @@ -80,6 +82,7 @@ def __init__(
target_mpp=target_mpp,
backend=backend,
image_transforms=image_transforms,
coords_path=coords_path,
)

@property
Expand Down Expand Up @@ -132,9 +135,7 @@ def load_target(self, index: int) -> torch.Tensor:

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata
return wsi.MultiWsiDataset.load_metadata(self, index)

def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
"""Loads the file paths of the corresponding dataset split."""
Expand Down
7 changes: 4 additions & 3 deletions src/eva/vision/data/datasets/classification/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
split: Literal["train", "val", "test"] | None = None,
image_transforms: Callable | None = None,
column_mapping: Dict[str, str] = default_column_mapping,
coords_path: str | None = None,
):
"""Initializes the dataset.
Expand All @@ -51,6 +52,7 @@ def __init__(
split: The split of the dataset to load.
image_transforms: Transforms to apply to the extracted image patches.
column_mapping: Mapping of the columns in the manifest file.
coords_path: File path to save the patch coordinates as .csv.
"""
self._split = split
self._column_mapping = self.default_column_mapping | column_mapping
Expand All @@ -66,6 +68,7 @@ def __init__(
target_mpp=target_mpp,
backend=backend,
image_transforms=image_transforms,
coords_path=coords_path,
)

@override
Expand All @@ -88,9 +91,7 @@ def load_target(self, index: int) -> np.ndarray:

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata
return wsi.MultiWsiDataset.load_metadata(self, index)

def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
df = pd.read_csv(manifest_path)
Expand Down
21 changes: 21 additions & 0 deletions src/eva/vision/data/datasets/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from typing import Any, Callable, Dict, List

import pandas as pd
from loguru import logger
from torch.utils.data import dataset as torch_datasets
from torchvision import tv_tensors
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
overwrite_mpp: float | None = None,
backend: str = "openslide",
image_transforms: Callable | None = None,
coords_path: str | None = None,
):
"""Initializes a new dataset instance.
Expand All @@ -129,6 +131,7 @@ def __init__(
sampler: The sampler to use for sampling patch coordinates.
backend: The backend to use for reading the whole-slide images.
image_transforms: Transforms to apply to the extracted image patches.
coords_path: File path to save the patch coordinates as .csv.
"""
super().__init__()

Expand All @@ -141,6 +144,7 @@ def __init__(
self._sampler = sampler
self._backend = backend
self._image_transforms = image_transforms
self._coords_path = coords_path

self._concat_dataset: torch_datasets.ConcatDataset

Expand All @@ -157,6 +161,7 @@ def cumulative_sizes(self) -> List[int]:
@override
def configure(self) -> None:
    """Builds the concatenated dataset and then saves the patch coordinates."""
    wsi_datasets = self._load_datasets()
    self._concat_dataset = torch_datasets.ConcatDataset(datasets=wsi_datasets)
    self._save_coords_to_file()

@override
def __len__(self) -> int:
Expand All @@ -170,6 +175,12 @@ def __getitem__(self, index: int) -> tv_tensors.Image:
def filename(self, index: int) -> str:
    """Returns the base name of the file that the sample at `index` belongs to."""
    dataset_position = self._get_dataset_idx(index)
    file_path = self._file_paths[dataset_position]
    return os.path.basename(file_path)

def load_metadata(self, index: int) -> Dict[str, Any]:
    """Loads the metadata for the patch at the specified index.

    Args:
        index: Global sample index across all concatenated datasets.

    Returns:
        The patch metadata of the owning dataset, merged with a `wsi_id` entry
        holding the part of the file name before its first dot.
    """
    dataset_index = self._get_dataset_idx(index)
    sample_index = self._get_sample_idx(index)
    owning_dataset = self.datasets[dataset_index]
    patch_metadata = owning_dataset.load_metadata(sample_index)
    # Everything after the first "." in the file name is dropped.
    wsi_id, *_ = self.filename(index).split(".")
    return {"wsi_id": wsi_id} | patch_metadata

def _load_datasets(self) -> list[WsiDataset]:
logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
wsi_datasets = []
Expand Down Expand Up @@ -200,3 +211,13 @@ def _get_dataset_idx(self, index: int) -> int:
def _get_sample_idx(self, index: int) -> int:
    """Maps a global sample index to the index within its owning dataset."""
    dataset_idx = self._get_dataset_idx(index)
    if dataset_idx == 0:
        # Nothing precedes the first dataset, so no offset is subtracted.
        return index
    preceding_total = self.cumulative_sizes[dataset_idx - 1]
    return index - preceding_total

def _save_coords_to_file(self):
    """Writes the patch coordinates of all datasets to `self._coords_path`.

    Does nothing when no coordinates path is set. Otherwise one row per
    dataset is written as .csv, made of the matching file path plus the
    entries of that dataset's coordinates dictionary. Missing parent
    directories of the target file are created first.
    """
    if self._coords_path is None:
        return

    rows = []
    for position, wsi_dataset in enumerate(self.datasets):
        row = {"file": self._file_paths[position]}
        rows.append(row | wsi_dataset._coords.to_dict())

    parent_dir = os.path.abspath(os.path.join(self._coords_path, os.pardir))
    os.makedirs(parent_dir, exist_ok=True)
    frame = pd.DataFrame(rows)
    frame.to_csv(self._coords_path, index=False)
    logger.info(f"Saved patch coordinates to: {self._coords_path}")
10 changes: 9 additions & 1 deletion src/eva/vision/data/wsi/patching/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import functools
from typing import List, Tuple
from typing import Any, Dict, List, Tuple

from eva.vision.data.wsi import backends
from eva.vision.data.wsi.patching import samplers
Expand Down Expand Up @@ -75,6 +75,14 @@ def from_file(

return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))

def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
"""Convert the coordinates to a dictionary."""
include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
coord_dict = dataclasses.asdict(self)
if include_keys:
coord_dict = {key: coord_dict[key] for key in include_keys}
return coord_dict


@functools.lru_cache(LRU_CACHE_SIZE)
def get_cached_coords(
Expand Down
12 changes: 11 additions & 1 deletion tests/eva/vision/data/datasets/test_wsi.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""WsiDataset & MultiWsiDataset tests."""

import os
import pathlib
from typing import Tuple

import pandas as pd
import pytest

from eva.vision.data import datasets
Expand Down Expand Up @@ -69,14 +71,16 @@ def test_patch_shape(width: int, height: int, target_mpp: float, root: str, back
assert dataset[0].shape == (3, scaled_width, scaled_height)


def test_multi_dataset(root: str):
def test_multi_dataset(root: str, tmp_path: pathlib.Path):
"""Test MultiWsiDataset with multiple whole-slide image paths."""
file_paths = [
os.path.join(root, "0/a.tiff"),
os.path.join(root, "0/b.tiff"),
os.path.join(root, "1/a.tiff"),
]

# get tmp csv file path for coords
coords_path = (tmp_path / "coords.csv").as_posix()
width, height = 32, 32
dataset = datasets.MultiWsiDataset(
root=root,
Expand All @@ -86,6 +90,7 @@ def test_multi_dataset(root: str):
target_mpp=0.25,
sampler=samplers.GridSampler(max_samples=None),
backend="openslide",
coords_path=coords_path,
)
dataset.setup()

Expand All @@ -94,6 +99,11 @@ def test_multi_dataset(root: str):
assert len(dataset) == _expected_n_patches(layer_shape, width, height, (0, 0)) * len(file_paths)
assert dataset.cumulative_sizes == [64, 128, 192]

assert os.path.exists(coords_path)
df_coords = pd.read_csv(coords_path)
assert "file" in df_coords.columns
assert "x_y" in df_coords.columns


def _expected_n_patches(layer_shape, width, height, overlap):
"""Calculate the expected number of patches."""
Expand Down

0 comments on commit c41872d

Please sign in to comment.