Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GeoNRW dataset #2209

Merged
merged 20 commits into from
Aug 27, 2024
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ FireRisk

.. autoclass:: FireRiskDataModule

GeoNRW
^^^^^^

.. autoclass:: GeoNRWDataModule

GID-15
^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,11 @@ Forest Damage

.. autoclass:: ForestDamage

GeoNRW
^^^^^^^

.. autoclass:: GeoNRW

GID-15
^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
`FireRisk`_,C,NAIP Aerial,"CC-BY-NC-4.0","91,872",7,"320x320",1,RGB
`Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB
`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM"
`GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB
`IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB
`Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB
Expand Down
16 changes: 16 additions & 0 deletions tests/conf/geonrw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 3
num_classes: 11
num_filters: 1
ignore_index: null
data:
class_path: GeoNRWDataModule
init_args:
batch_size: 1
dict_kwargs:
root: "tests/data/geonrw"
Binary file added tests/data/geonrw/aachen/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/aachen/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/aachen/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/aachen/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/aachen/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/aachen/1_1_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bergisch/1_1_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/bielefeld/1_1_seg.tif
Binary file not shown.
87 changes: 87 additions & 0 deletions tests/data/geonrw/data.py
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
import tarfile

import numpy as np
from PIL import Image

# Constants
IMAGE_SIZE = (100, 100)
TRAIN_CITIES = ['aachen', 'bergisch', 'bielefeld']
TEST_CITIES = ['duesseldorf']
CLASSES = [
'background',
'forest',
'water',
'agricultural',
'residential,commercial,industrial',
'grassland,swamp,shrubbery',
'railway,trainstation',
'highway,squares',
'airport,shipyard',
'roads',
'buildings',
]
NUM_SAMPLES_PER_CITY = 2


def create_directories(cities: list[str]) -> None:
for city in cities:
if os.path.exists(city):
shutil.rmtree(city)
os.makedirs(city, exist_ok=True)


def generate_dummy_data(cities: list[str]) -> None:
for city in cities:
for i in range(NUM_SAMPLES_PER_CITY):
utm_coords = f'{i}_{i}'
rgb_image = np.random.randint(0, 256, (*IMAGE_SIZE, 3), dtype=np.uint8)
dem_image = np.random.randint(0, 256, IMAGE_SIZE, dtype=np.uint8)
seg_image = np.random.randint(0, len(CLASSES), IMAGE_SIZE, dtype=np.uint8)

Image.fromarray(rgb_image).save(os.path.join(city, f'{utm_coords}_rgb.jp2'))
Image.fromarray(dem_image).save(os.path.join(city, f'{utm_coords}_dem.tif'))
Image.fromarray(seg_image).save(os.path.join(city, f'{utm_coords}_seg.tif'))


def create_tarball(output_filename: str, source_dirs: list[str]) -> None:
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
with tarfile.open(output_filename, 'w:gz') as tar:
for source_dir in source_dirs:
tar.add(source_dir, arcname=os.path.basename(source_dir))


def calculate_md5(filename: str) -> str:
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
hash_md5 = hashlib.md5()
with open(filename, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
hash_md5.update(chunk)
return hash_md5.hexdigest()


# Main function
def main() -> None:
train_cities = TRAIN_CITIES
test_cities = TEST_CITIES
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

create_directories(train_cities)
create_directories(test_cities)

generate_dummy_data(train_cities)
generate_dummy_data(test_cities)

tarball_name = 'nrw_dataset.tar.gz'
create_tarball(tarball_name, train_cities + test_cities)

md5sum = calculate_md5(tarball_name)
print(f'MD5 checksum: {md5sum}')


if __name__ == '__main__':
main()
Binary file added tests/data/geonrw/duesseldorf/0_0_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/0_0_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/0_0_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/1_1_dem.tif
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/1_1_rgb.jp2
Binary file not shown.
Binary file added tests/data/geonrw/duesseldorf/1_1_seg.tif
Binary file not shown.
Binary file added tests/data/geonrw/nrw_dataset.tar.gz
Binary file not shown.
74 changes: 74 additions & 0 deletions tests/datasets/test_geonrw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, GeoNRW


class TestGeoNRW:
@pytest.fixture(params=['train', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> GeoNRW:
md5 = '6ffc014d4b345bba3076e8d76ab481fa'
monkeypatch.setattr(GeoNRW, 'md5', md5)
url = os.path.join('tests', 'data', 'geonrw', 'nrw_dataset.tar.gz')
monkeypatch.setattr(GeoNRW, 'url', url)
monkeypatch.setattr(GeoNRW, 'train_list', ['aachen', 'bergisch', 'bielefeld'])
monkeypatch.setattr(GeoNRW, 'test_list', ['duesseldorf'])
root = tmp_path
split = request.param
transforms = nn.Identity()
return GeoNRW(root, split, transforms, download=True, checksum=True)

def test_getitem(self, dataset: GeoNRW) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert x['image'].shape[0] == 3
assert isinstance(x['mask'], torch.Tensor)
assert x['image'].shape[-2:] == x['mask'].shape[-2:]

def test_len(self, dataset: GeoNRW) -> None:
if dataset.split == 'train':
assert len(dataset) == 6
else:
assert len(dataset) == 2

def test_already_downloaded(self, dataset: GeoNRW) -> None:
GeoNRW(root=dataset.root)

def test_not_yet_extracted(self, tmp_path: Path) -> None:
filename = 'nrw_dataset.tar.gz'
dir = os.path.join('tests', 'data', 'geonrw')
shutil.copyfile(
os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
)
GeoNRW(root=str(tmp_path))

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
GeoNRW(split='foo')

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
GeoNRW(tmp_path)

def test_plot(self, dataset: GeoNRW) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

sample = dataset[0]
sample['prediction'] = torch.clone(sample['mask'])
dataset.plot(sample, suptitle='Prediction')
plt.close()
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TestSemanticSegmentationTask:
'chesapeake_cvpr_7',
'deepglobelandcover',
'etci2021',
'geonrw',
'gid15',
'inria',
'l7irish',
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .fair1m import FAIR1MDataModule
from .fire_risk import FireRiskDataModule
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .geonrw import GeoNRWDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .iobench import IOBenchDataModule
Expand Down Expand Up @@ -73,6 +74,7 @@
'EuroSAT100DataModule',
'FAIR1MDataModule',
'FireRiskDataModule',
'GeoNRWDataModule',
'GID15DataModule',
'InriaAerialImageLabelingDataModule',
'LandCoverAIDataModule',
Expand Down
67 changes: 67 additions & 0 deletions torchgeo/datamodules/geonrw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""GeoNRW datamodule."""

import os
from typing import Any

import kornia.augmentation as K
from torch.utils.data import Subset

from ..datasets import GeoNRW
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import group_shuffle_split


class GeoNRWDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the GeoNRW dataset.

Implements 80/20 train/val splits based on city locations.
See :func:`setup` for more details.
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

.. versionadded: 0.6.0
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any
) -> None:
"""Initialize a new GeoNRWDataModule instance.

Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
size: resize images of input size 1000x1000 to size x size
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.GeoNRW`.
"""
super().__init__(GeoNRW, batch_size, num_workers, **kwargs)

self.train_aug = AugmentationSequential(
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
)

self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask'])

self.size = size

def setup(self, stage: str) -> None:
"""Set up datasets.

Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if stage in ['fit', 'validate']:
dataset = GeoNRW(split='train', **self.kwargs)
city_paths = [os.path.dirname(path) for path in dataset.file_list]
train_indices, val_indices = group_shuffle_split(
city_paths, test_size=0.2, random_state=0
)
self.train_dataset = Subset(dataset, train_indices)
self.val_dataset = Subset(dataset, val_indices)
if stage in ['test']:
self.test_dataset = GeoNRW(split='test', **self.kwargs)
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
UnionDataset,
VectorDataset,
)
from .geonrw import GeoNRW
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .idtrees import IDTReeS
Expand Down Expand Up @@ -213,6 +214,7 @@
'FAIR1M',
'FireRisk',
'ForestDamage',
'GeoNRW',
'GID15',
'IDTReeS',
'InriaAerialImageLabeling',
Expand Down
Loading