diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index e68f45164a1..1fd6eb8983d 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -21,9 +21,12 @@ class CustomGeoDataset(GeoDataset): - def __init__(self, split: str = "train", download: bool = False) -> None: + def __init__( + self, split: str = "train", length: int = 1, download: bool = False + ) -> None: super().__init__() - self.index.insert(0, (0, 1, 2, 3, 4, 5)) + for i in range(length): + self.index.insert(i, (0, 1, 2, 3, 4, 5)) self.res = 1 def __getitem__(self, query: BoundingBox) -> dict[str, Any]: @@ -58,14 +61,16 @@ def setup(self, stage: str) -> None: class CustomNonGeoDataset(NonGeoDataset): - def __init__(self, split: str = "train", download: bool = False) -> None: - pass + def __init__( + self, split: str = "train", length: int = 1, download: bool = False + ) -> None: + self.length = length def __getitem__(self, index: int) -> dict[str, Tensor]: return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)} def __len__(self) -> int: - return 1 + return self.length def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: return plt.figure() @@ -134,7 +139,8 @@ def test_plot(self, datamodule: CustomGeoDataModule) -> None: def test_no_datasets(self) -> None: dm = CustomGeoDataModule() - msg = "CustomGeoDataModule.setup does not define a '{}_dataset'" + msg = r"CustomGeoDataModule\.setup must define one of " + msg += r"\('{0}_dataset', 'dataset'\)\." 
with pytest.raises(MisconfigurationException, match=msg.format("train")): dm.train_dataloader() with pytest.raises(MisconfigurationException, match=msg.format("val")): @@ -144,6 +150,48 @@ def test_no_datasets(self) -> None: with pytest.raises(MisconfigurationException, match=msg.format("predict")): dm.predict_dataloader() + def test_no_samplers(self) -> None: + dm = CustomGeoDataModule() + dm.dataset = CustomGeoDataset() + msg = r"CustomGeoDataModule\.setup must define one of " + msg += r"\('{0}_batch_sampler', '{0}_sampler', 'batch_sampler', 'sampler'\)\." + with pytest.raises(MisconfigurationException, match=msg.format("train")): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("val")): + dm.val_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("test")): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("predict")): + dm.predict_dataloader() + + def test_zero_length_dataset(self) -> None: + dm = CustomGeoDataModule() + dm.dataset = CustomGeoDataset(length=0) + msg = r"CustomGeoDataModule\.dataset has length 0." + with pytest.raises(MisconfigurationException, match=msg): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.val_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.predict_dataloader() + + def test_zero_length_sampler(self) -> None: + dm = CustomGeoDataModule() + dm.dataset = CustomGeoDataset() + dm.sampler = RandomGeoSampler(dm.dataset, 1, 1) + dm.sampler.length = 0 + msg = r"CustomGeoDataModule\.sampler has length 0." 
+ with pytest.raises(MisconfigurationException, match=msg): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.val_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.predict_dataloader() + class TestNonGeoDataModule: @pytest.fixture @@ -193,7 +241,8 @@ def test_plot(self, datamodule: CustomNonGeoDataModule) -> None: def test_no_datasets(self) -> None: dm = CustomNonGeoDataModule() - msg = "CustomNonGeoDataModule.setup does not define a '{}_dataset'" + msg = r"CustomNonGeoDataModule\.setup must define one of " + msg += r"\('{0}_dataset', 'dataset'\)\." with pytest.raises(MisconfigurationException, match=msg.format("train")): dm.train_dataloader() with pytest.raises(MisconfigurationException, match=msg.format("val")): @@ -202,3 +251,16 @@ def test_no_datasets(self) -> None: dm.test_dataloader() with pytest.raises(MisconfigurationException, match=msg.format("predict")): dm.predict_dataloader() + + def test_zero_length_dataset(self) -> None: + dm = CustomNonGeoDataModule() + dm.dataset = CustomNonGeoDataset(length=0) + msg = r"CustomNonGeoDataModule\.dataset has length 0." + with pytest.raises(MisconfigurationException, match=msg): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.val_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg): + dm.predict_dataloader() diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index f733467f042..c0e1b714710 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -87,6 +87,33 @@ def prepare_data(self) -> None: if self.kwargs.get("download", False): self.dataset_class(**self.kwargs) + def _valid_attribute(self, *args: str) -> Any: + """Find a valid attribute with length > 0. 
+ + Args: + args: One or more names of attributes to check. + + Returns: + The first valid attribute found. + + Raises: + MisconfigurationException: If no attribute is defined, or has length 0. + """ + for arg in args: + obj = getattr(self, arg) + + if obj is None: + continue + + if not obj: + msg = f"{self.__class__.__name__}.{arg} has length 0." + raise MisconfigurationException(msg) + + return obj + + msg = f"{self.__class__.__name__}.setup must define one of {args}." + raise MisconfigurationException(msg) + def on_after_batch_transfer( self, batch: dict[str, Tensor], dataloader_idx: int ) -> dict[str, Tensor]: @@ -101,14 +128,15 @@ def on_after_batch_transfer( """ if self.trainer: if self.trainer.training: - aug = self.train_aug or self.aug + split = "train" elif self.trainer.validating or self.trainer.sanity_checking: - aug = self.val_aug or self.aug + split = "val" elif self.trainer.testing: - aug = self.test_aug or self.aug + split = "test" elif self.trainer.predicting: - aug = self.predict_aug or self.aug + split = "predict" + aug = self._valid_attribute(f"{split}_aug", "aug") batch = aug(batch) return batch @@ -220,6 +248,41 @@ def setup(self, stage: str) -> None: self.test_dataset, self.patch_size, self.patch_size ) + def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders. + + Args: + split: Either 'train', 'val', 'test', or 'predict'. + + Returns: + A collection of data loaders specifying samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + dataset or sampler, or if the dataset or sampler has length 0. 
+ """ + dataset = self._valid_attribute(f"{split}_dataset", "dataset") + sampler = self._valid_attribute( + f"{split}_batch_sampler", f"{split}_sampler", "batch_sampler", "sampler" + ) + batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") + + if isinstance(sampler, BatchGeoSampler): + batch_size = 1 + batch_sampler = sampler + sampler = None + else: + batch_sampler = None + + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for training. @@ -228,27 +291,9 @@ def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'train_dataset'. + dataset or sampler, or if the dataset or sampler has length 0. """ - dataset = self.train_dataset or self.dataset - sampler = self.train_sampler or self.sampler - batch_sampler = self.train_batch_sampler or self.batch_sampler - if dataset is not None and (sampler or batch_sampler) is not None: - batch_size = self.train_batch_size or self.batch_size - if batch_sampler is not None: - batch_size = 1 - sampler = None - return DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("train") def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for validation. @@ -258,27 +303,9 @@ def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'val_dataset'. + dataset or sampler, or if the dataset or sampler has length 0. 
""" - dataset = self.val_dataset or self.dataset - sampler = self.val_sampler or self.sampler - batch_sampler = self.val_batch_sampler or self.batch_sampler - if dataset is not None and (sampler or batch_sampler) is not None: - batch_size = self.val_batch_size or self.batch_size - if batch_sampler is not None: - batch_size = 1 - sampler = None - return DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("val") def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for testing. @@ -288,27 +315,9 @@ def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'test_dataset'. + dataset or sampler, or if the dataset or sampler has length 0. """ - dataset = self.test_dataset or self.dataset - sampler = self.test_sampler or self.sampler - batch_sampler = self.test_batch_sampler or self.batch_sampler - if dataset is not None and (sampler or batch_sampler) is not None: - batch_size = self.test_batch_size or self.batch_size - if batch_sampler is not None: - batch_size = 1 - sampler = None - return DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("test") def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for prediction. 
@@ -318,27 +327,9 @@ def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'predict_dataset'. + dataset or sampler, or if the dataset or sampler has length 0. """ - dataset = self.predict_dataset or self.dataset - sampler = self.predict_sampler or self.sampler - batch_sampler = self.predict_batch_sampler or self.batch_sampler - if dataset is not None and (sampler or batch_sampler) is not None: - batch_size = self.predict_batch_size or self.batch_size - if batch_sampler is not None: - batch_size = 1 - sampler = None - return DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("predict") def transfer_batch_to_device( self, batch: dict[str, Tensor], device: torch.device, dataloader_idx: int @@ -412,6 +403,29 @@ def setup(self, stage: str) -> None: split="test", **self.kwargs ) + def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders. + + Args: + split: Either 'train', 'val', 'test', or 'predict'. + + Returns: + A collection of data loaders specifying samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + dataset, or if the dataset has length 0. + """ + dataset = self._valid_attribute(f"{split}_dataset", "dataset") + batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=split == "train", + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for training. 
@@ -420,20 +434,9 @@ def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'train_dataset'. + dataset, or if the dataset has length 0. """ - dataset = self.train_dataset or self.dataset - if dataset is not None: - return DataLoader( - dataset=dataset, - batch_size=self.train_batch_size or self.batch_size, - shuffle=True, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("train") def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for validation. @@ -443,20 +446,9 @@ def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'val_dataset'. + dataset, or if the dataset has length 0. """ - dataset = self.val_dataset or self.dataset - if dataset is not None: - return DataLoader( - dataset=dataset, - batch_size=self.val_batch_size or self.batch_size, - shuffle=False, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("val") def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for testing. @@ -466,20 +458,9 @@ def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'test_dataset'. + dataset, or if the dataset has length 0. 
""" - dataset = self.test_dataset or self.dataset - if dataset is not None: - return DataLoader( - dataset=dataset, - batch_size=self.test_batch_size or self.batch_size, - shuffle=False, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("test") def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for prediction. @@ -489,17 +470,6 @@ def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: Raises: MisconfigurationException: If :meth:`setup` does not define a - 'predict_dataset'. + dataset, or if the dataset has length 0. """ - dataset = self.predict_dataset or self.dataset - if dataset is not None: - return DataLoader( - dataset=dataset, - batch_size=self.predict_batch_size or self.batch_size, - shuffle=False, - num_workers=self.num_workers, - collate_fn=self.collate_fn, - ) - else: - msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" - raise MisconfigurationException(msg) + return self._dataloader_factory("predict")