Skip to content

Commit

Permalink
Add GeoNRW dataset (#2209)
Browse files Browse the repository at this point in the history
* dataset and module

* test with training

* add tests

* start the fight with mypy

* kick off tests

* class var ruff

* don't download

* forgot tests data

* already downloaded

* coverage

* review

* mypy

* docs

* docs

* suggestion

* plotting

* versionadded: 2 digits

* Type hint unnecessary

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
nilsleh and adamjstewart authored Aug 27, 2024
1 parent 6d758ab commit 2d6e27e
Show file tree
Hide file tree
Showing 36 changed files with 606 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ FireRisk

.. autoclass:: FireRiskDataModule

GeoNRW
^^^^^^

.. autoclass:: GeoNRWDataModule

GID-15
^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,11 @@ Forest Damage

.. autoclass:: ForestDamage

GeoNRW
^^^^^^^

.. autoclass:: GeoNRW

GID-15
^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
`FireRisk`_,C,NAIP Aerial,"CC-BY-NC-4.0","91,872",7,"320x320",1,RGB
`Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB
`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM"
`GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB
`IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB
`Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB
Expand Down
16 changes: 16 additions & 0 deletions tests/conf/geonrw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 3
num_classes: 11
num_filters: 1
ignore_index: null
data:
class_path: GeoNRWDataModule
init_args:
batch_size: 1
dict_kwargs:
root: "tests/data/geonrw"
Binary file added tests/data/geonrw/aachen/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/aachen/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/aachen/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/aachen/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/aachen/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/aachen/1_1_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/1_1_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/1_1_seg.tif
Binary file not shown.
87 changes: 87 additions & 0 deletions tests/data/geonrw/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
import tarfile

import numpy as np
from PIL import Image

# Constants
IMAGE_SIZE = (100, 100)
TRAIN_CITIES = ['aachen', 'bergisch', 'bielefeld']
TEST_CITIES = ['duesseldorf']
CLASSES = [
'background',
'forest',
'water',
'agricultural',
'residential,commercial,industrial',
'grassland,swamp,shrubbery',
'railway,trainstation',
'highway,squares',
'airport,shipyard',
'roads',
'buildings',
]
NUM_SAMPLES_PER_CITY = 2


def create_directories(cities: list[str]) -> None:
for city in cities:
if os.path.exists(city):
shutil.rmtree(city)
os.makedirs(city, exist_ok=True)


def generate_dummy_data(cities: list[str]) -> None:
for city in cities:
for i in range(NUM_SAMPLES_PER_CITY):
utm_coords = f'{i}_{i}'
rgb_image = np.random.randint(0, 256, (*IMAGE_SIZE, 3), dtype=np.uint8)
dem_image = np.random.randint(0, 256, IMAGE_SIZE, dtype=np.uint8)
seg_image = np.random.randint(0, len(CLASSES), IMAGE_SIZE, dtype=np.uint8)

Image.fromarray(rgb_image).save(os.path.join(city, f'{utm_coords}_rgb.jp2'))
Image.fromarray(dem_image).save(os.path.join(city, f'{utm_coords}_dem.tif'))
Image.fromarray(seg_image).save(os.path.join(city, f'{utm_coords}_seg.tif'))


def create_tarball(output_filename: str, source_dirs: list[str]) -> None:
with tarfile.open(output_filename, 'w:gz') as tar:
for source_dir in source_dirs:
tar.add(source_dir, arcname=os.path.basename(source_dir))


def calculate_md5(filename: str) -> str:
hash_md5 = hashlib.md5()
with open(filename, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
hash_md5.update(chunk)
return hash_md5.hexdigest()


# Main function
def main() -> None:
train_cities = TRAIN_CITIES
test_cities = TEST_CITIES

create_directories(train_cities)
create_directories(test_cities)

generate_dummy_data(train_cities)
generate_dummy_data(test_cities)

tarball_name = 'nrw_dataset.tar.gz'
create_tarball(tarball_name, train_cities + test_cities)

md5sum = calculate_md5(tarball_name)
print(f'MD5 checksum: {md5sum}')


if __name__ == '__main__':
main()
Binary file added tests/data/geonrw/duesseldorf/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/1_1_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/nrw_dataset.tar.gz
Binary file not shown.
74 changes: 74 additions & 0 deletions tests/datasets/test_geonrw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, GeoNRW


class TestGeoNRW:
@pytest.fixture(params=['train', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> GeoNRW:
md5 = '6ffc014d4b345bba3076e8d76ab481fa'
monkeypatch.setattr(GeoNRW, 'md5', md5)
url = os.path.join('tests', 'data', 'geonrw', 'nrw_dataset.tar.gz')
monkeypatch.setattr(GeoNRW, 'url', url)
monkeypatch.setattr(GeoNRW, 'train_list', ['aachen', 'bergisch', 'bielefeld'])
monkeypatch.setattr(GeoNRW, 'test_list', ['duesseldorf'])
root = tmp_path
split = request.param
transforms = nn.Identity()
return GeoNRW(root, split, transforms, download=True, checksum=True)

def test_getitem(self, dataset: GeoNRW) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert x['image'].shape[0] == 3
assert isinstance(x['mask'], torch.Tensor)
assert x['image'].shape[-2:] == x['mask'].shape[-2:]

def test_len(self, dataset: GeoNRW) -> None:
if dataset.split == 'train':
assert len(dataset) == 6
else:
assert len(dataset) == 2

def test_already_downloaded(self, dataset: GeoNRW) -> None:
GeoNRW(root=dataset.root)

def test_not_yet_extracted(self, tmp_path: Path) -> None:
filename = 'nrw_dataset.tar.gz'
dir = os.path.join('tests', 'data', 'geonrw')
shutil.copyfile(
os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
)
GeoNRW(root=str(tmp_path))

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
GeoNRW(split='foo')

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
GeoNRW(tmp_path)

def test_plot(self, dataset: GeoNRW) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

sample = dataset[0]
sample['prediction'] = torch.clone(sample['mask'])
dataset.plot(sample, suptitle='Prediction')
plt.close()
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TestSemanticSegmentationTask:
'chesapeake_cvpr_7',
'deepglobelandcover',
'etci2021',
'geonrw',
'gid15',
'inria',
'l7irish',
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .fair1m import FAIR1MDataModule
from .fire_risk import FireRiskDataModule
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .geonrw import GeoNRWDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .iobench import IOBenchDataModule
Expand Down Expand Up @@ -73,6 +74,7 @@
'EuroSAT100DataModule',
'FAIR1MDataModule',
'FireRiskDataModule',
'GeoNRWDataModule',
'GID15DataModule',
'InriaAerialImageLabelingDataModule',
'LandCoverAIDataModule',
Expand Down
67 changes: 67 additions & 0 deletions torchgeo/datamodules/geonrw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""GeoNRW datamodule."""

import os
from typing import Any

import kornia.augmentation as K
from torch.utils.data import Subset

from ..datasets import GeoNRW
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import group_shuffle_split


class GeoNRWDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the GeoNRW dataset.
Implements 80/20 train/val splits based on city locations.
See :func:`setup` for more details.
.. versionadded: 0.6
"""

def __init__(
self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any
) -> None:
"""Initialize a new GeoNRWDataModule instance.
Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
size: resize images of input size 1000x1000 to size x size
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.GeoNRW`.
"""
super().__init__(GeoNRW, batch_size, num_workers, **kwargs)

self.train_aug = AugmentationSequential(
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
)

self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask'])

self.size = size

def setup(self, stage: str) -> None:
"""Set up datasets.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if stage in ['fit', 'validate']:
dataset = GeoNRW(split='train', **self.kwargs)
city_paths = [os.path.dirname(path) for path in dataset.file_list]
train_indices, val_indices = group_shuffle_split(
city_paths, test_size=0.2, random_state=0
)
self.train_dataset = Subset(dataset, train_indices)
self.val_dataset = Subset(dataset, val_indices)
if stage in ['test']:
self.test_dataset = GeoNRW(split='test', **self.kwargs)
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
UnionDataset,
VectorDataset,
)
from .geonrw import GeoNRW
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .idtrees import IDTReeS
Expand Down Expand Up @@ -213,6 +214,7 @@
'FAIR1M',
'FireRisk',
'ForestDamage',
'GeoNRW',
'GID15',
'IDTReeS',
'InriaAerialImageLabeling',
Expand Down
Loading

0 comments on commit 2d6e27e

Please sign in to comment.