Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ioangatop committed Apr 22, 2024
2 parents 44691e1 + bc57f9e commit 5fde5fe
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 13 deletions.
8 changes: 4 additions & 4 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ maintainers = [
]
requires-python = ">=3.10"
dependencies = [
"lightning>=2.2.1",
"lightning>=2.2.2",
"jsonargparse[omegaconf]>=4.27.4",
"tensorboard>=2.16.2",
"loguru>=0.7.2",
Expand Down
3 changes: 1 addition & 2 deletions src/eva/vision/data/datasets/segmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import abc
from typing import Any, Callable, Dict, List, Tuple

import numpy as np
from torchvision import tv_tensors
from typing_extensions import override

Expand Down Expand Up @@ -54,7 +53,7 @@ def load_image(self, index: int) -> tv_tensors.Image:
index: The index of the data sample to load.
Returns:
An image torchvision tensor.
An image torchvision tensor (channels, height, width).
"""

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def load_masks(self, index: int) -> tv_tensors.Mask:
list_of_mask_arrays = [io.read_nifti_slice(path, slice_index) for path in mask_paths]
masks = np.concatenate(list_of_mask_arrays, axis=-1)
return tv_tensors.Mask(masks.transpose(2, 0, 1))

def _get_masks_dir(self, index: int) -> str:
"""Returns the directory of the corresponding masks."""
sample_dir = self._get_sample_dir(index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import os
from typing import Literal

import numpy as np
import pytest
from torchvision import tv_tensors

from eva.vision.data import datasets

Expand Down Expand Up @@ -35,10 +35,10 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i
assert len(sample) == 2
# assert the format of the `image` and `mask`
image, mask = sample
assert isinstance(image, np.ndarray)
assert image.shape == (16, 16, 3)
assert isinstance(mask, np.ndarray)
assert mask.shape == (16, 16, 3)
assert isinstance(image, tv_tensors.Image)
assert image.shape == (3, 16, 16)
assert isinstance(mask, tv_tensors.Mask)
assert mask.shape == (3, 16, 16)


@pytest.fixture(scope="function")
Expand Down

0 comments on commit 5fde5fe

Please sign in to comment.