Skip to content

Commit

Permalink
DataModules: improve error messages (#1441)
Browse files Browse the repository at this point in the history
* DataModules: better error messages

* Fix batch size bug

* Fix type hints

* Use in one additional place

* Fix BatchGeoSampler batch size

* Increase test coverage
  • Loading branch information
adamjstewart authored Jul 7, 2023
1 parent b0ae5be commit 9d6d683
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 143 deletions.
76 changes: 69 additions & 7 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@


class CustomGeoDataset(GeoDataset):
def __init__(self, split: str = "train", download: bool = False) -> None:
def __init__(
self, split: str = "train", length: int = 1, download: bool = False
) -> None:
super().__init__()
self.index.insert(0, (0, 1, 2, 3, 4, 5))
for i in range(length):
self.index.insert(i, (0, 1, 2, 3, 4, 5))
self.res = 1

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Expand Down Expand Up @@ -58,14 +61,16 @@ def setup(self, stage: str) -> None:


class CustomNonGeoDataset(NonGeoDataset):
def __init__(self, split: str = "train", download: bool = False) -> None:
pass
def __init__(
self, split: str = "train", length: int = 1, download: bool = False
) -> None:
self.length = length

def __getitem__(self, index: int) -> dict[str, Tensor]:
return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)}

def __len__(self) -> int:
return 1
return self.length

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
return plt.figure()
Expand Down Expand Up @@ -134,7 +139,8 @@ def test_plot(self, datamodule: CustomGeoDataModule) -> None:

def test_no_datasets(self) -> None:
dm = CustomGeoDataModule()
msg = "CustomGeoDataModule.setup does not define a '{}_dataset'"
msg = r"CustomGeoDataModule\.setup must define one of "
msg += r"\('{0}_dataset', 'dataset'\)\."
with pytest.raises(MisconfigurationException, match=msg.format("train")):
dm.train_dataloader()
with pytest.raises(MisconfigurationException, match=msg.format("val")):
Expand All @@ -144,6 +150,48 @@ def test_no_datasets(self) -> None:
with pytest.raises(MisconfigurationException, match=msg.format("predict")):
dm.predict_dataloader()

def test_no_samplers(self) -> None:
dm = CustomGeoDataModule()
dm.dataset = CustomGeoDataset()
msg = r"CustomGeoDataModule\.setup must define one of "
msg += r"\('{0}_batch_sampler', '{0}_sampler', 'batch_sampler', 'sampler'\)\."
with pytest.raises(MisconfigurationException, match=msg.format("train")):
dm.train_dataloader()
with pytest.raises(MisconfigurationException, match=msg.format("val")):
dm.val_dataloader()
with pytest.raises(MisconfigurationException, match=msg.format("test")):
dm.test_dataloader()
with pytest.raises(MisconfigurationException, match=msg.format("predict")):
dm.predict_dataloader()

def test_zero_length_dataset(self) -> None:
dm = CustomGeoDataModule()
dm.dataset = CustomGeoDataset(length=0)
msg = r"CustomGeoDataModule\.dataset has length 0."
with pytest.raises(MisconfigurationException, match=msg):
dm.train_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.val_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.test_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.predict_dataloader()

def test_zero_length_sampler(self) -> None:
dm = CustomGeoDataModule()
dm.dataset = CustomGeoDataset()
dm.sampler = RandomGeoSampler(dm.dataset, 1, 1)
dm.sampler.length = 0
msg = r"CustomGeoDataModule\.sampler has length 0."
with pytest.raises(MisconfigurationException, match=msg):
dm.train_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.val_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.test_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.predict_dataloader()


class TestNonGeoDataModule:
@pytest.fixture
Expand Down Expand Up @@ -193,7 +241,8 @@ def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:

def test_no_datasets(self) -> None:
dm = CustomNonGeoDataModule()
msg = "CustomNonGeoDataModule.setup does not define a '{}_dataset'"
msg = r"CustomNonGeoDataModule\.setup must define one of "
msg += r"\('{0}_dataset', 'dataset'\)\."
with pytest.raises(MisconfigurationException, match=msg.format("train")):
dm.train_dataloader()
with pytest.raises(MisconfigurationException, match=msg.format("val")):
Expand All @@ -202,3 +251,16 @@ def test_no_datasets(self) -> None:
dm.test_dataloader()
with pytest.raises(MisconfigurationException, match=msg.format("predict")):
dm.predict_dataloader()

def test_zero_length_dataset(self) -> None:
dm = CustomNonGeoDataModule()
dm.dataset = CustomNonGeoDataset(length=0)
msg = r"CustomNonGeoDataModule\.dataset has length 0."
with pytest.raises(MisconfigurationException, match=msg):
dm.train_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.val_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.test_dataloader()
with pytest.raises(MisconfigurationException, match=msg):
dm.predict_dataloader()
Loading

0 comments on commit 9d6d683

Please sign in to comment.