Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PASTIS dataset #315

Merged
merged 44 commits into main from datasets/pastis-r
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
0dac108
draft
isaaccorley Dec 20, 2021
ef46dd3
add dataset to __init__
isaaccorley Jan 8, 2022
e9935c3
reorganize datasets and datamodules
isaaccorley Jul 6, 2022
ebb4252
fix mypy errors
isaaccorley Jul 6, 2022
7718690
draft
isaaccorley Dec 20, 2021
7d3ecb1
add dataset to __init__
isaaccorley Jan 8, 2022
3a0ce52
reorganize datasets and datamodules
isaaccorley Jul 6, 2022
028d2e1
fix mypy errors
isaaccorley Jul 6, 2022
2bff295
Merge branch 'datasets/pastis-r' of github.com:isaaccorley/torchgeo i…
isaaccorley Apr 23, 2023
f54073e
refactor
isaaccorley Apr 23, 2023
0c2e635
Merge branch 'main' into datasets/pastis-r
isaaccorley Apr 24, 2023
a751e5c
Merge branch 'main' into datasets/pastis-r
isaaccorley Apr 24, 2023
472b914
Merge branch 'main' into datasets/pastis-r
isaaccorley Apr 25, 2023
f97fb16
Merge branch 'main' into datasets/pastis-r
isaaccorley May 4, 2023
f613c6b
Merge branch 'main' into datasets/pastis-r
isaaccorley May 4, 2023
99c277e
Adding docs
calebrob6 Jul 27, 2023
7ada35e
Merge branch 'main' into datasets/pastis-r
calebrob6 Jul 27, 2023
7116785
Adding plotting, cleaning up some stuff
calebrob6 Jul 27, 2023
42c3217
Black and isort
calebrob6 Jul 27, 2023
09b58b5
Fix the datamodule import
calebrob6 Jul 27, 2023
20c5932
Pyupgrade
calebrob6 Jul 28, 2023
60f33fe
Fixing some docstrings
calebrob6 Jul 28, 2023
ee50b0e
Flake8
calebrob6 Jul 28, 2023
c71074b
Isort
calebrob6 Jul 28, 2023
012f53b
Fix docstrings in datamodules
calebrob6 Jul 28, 2023
b4dbdf6
Fixing fns and docstring
calebrob6 Jul 28, 2023
cacfc62
Trying to fix the docs
calebrob6 Jul 28, 2023
2924fbe
Trying to fix docs
calebrob6 Jul 28, 2023
c02ab27
Adding tests
calebrob6 Jul 28, 2023
c7923c3
Black
calebrob6 Jul 28, 2023
2695665
newline
calebrob6 Jul 28, 2023
24047d3
Made the test dataset larger
calebrob6 Jul 28, 2023
92381b3
Remove the datamodules
calebrob6 Jul 28, 2023
a4c3294
Update docs/api/non_geo_datasets.csv
calebrob6 Jul 31, 2023
c4b51a7
Update torchgeo/datasets/pastis.py
calebrob6 Jul 31, 2023
0312f6e
Update torchgeo/datasets/pastis.py
calebrob6 Jul 31, 2023
c4df0b7
Update torchgeo/datasets/pastis.py
calebrob6 Jul 31, 2023
0c6b51c
Updating cmap
calebrob6 Jul 31, 2023
31e7851
Describe the different band combinations
calebrob6 Jul 31, 2023
85f98e5
Merging datasets
calebrob6 Jul 31, 2023
5cd1efb
Handle the instance segmentation case in plotting
calebrob6 Jul 31, 2023
7dfb4ce
Update torchgeo/datasets/pastis.py
calebrob6 Aug 1, 2023
09b7ac4
Made some code prettier
calebrob6 Aug 1, 2023
5c3c42c
Adding instance plotting
calebrob6 Aug 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ OSCD

.. autoclass:: OSCD

PASTIS
^^^^^^

.. autoclass:: PASTIS

PatternNet
^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
`NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
`PASTIS`_,I,Sentinel-1/2,"2,433",19,128x128xT,10,MSI
`PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
`Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
`ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
Expand Down
Binary file added tests/data/pastis/PASTIS-R.zip
Binary file not shown.
91 changes: 91 additions & 0 deletions tests/data/pastis/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
from typing import Union

import fiona
import numpy as np

# Height and width, in pixels, of every array written by create_file.
SIZE = 32
# Number of .npy files written per filename stem, and number of features
# written to metadata.geojson.
NUM_SAMPLES = 5
# Passed as the (exclusive) upper bound when create_file draws the number of
# time steps for a sample.
MAX_NUM_TIME_STEPS = 10
# Seed NumPy's global RNG so repeated runs draw the same random values.
np.random.seed(0)

# Recursive layout description: a dict maps a directory name to a nested
# hierarchy, a list holds the filename stems to create in that directory.
FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]

# Directory name -> filename stems generated inside that directory.
filenames: FILENAME_HIERARCHY = {
    "DATA_S2": ["S2"],
    "DATA_S1A": ["S1A"],
    "DATA_S1D": ["S1D"],
    "ANNOTATIONS": ["TARGET"],
    "INSTANCE_ANNOTATIONS": ["INSTANCES"],
}


def create_file(path: str) -> None:
    """Write ``NUM_SAMPLES`` random ``.npy`` arrays named ``{path}_{i}.npy``.

    The basename prefix of ``path`` selects the value range, shape and dtype
    of the generated array (``S2``, ``S1A``, ``S1D``, ``TARGET`` or
    ``INSTANCES``).

    Args:
        path: output path without the ``_{i}.npy`` suffix

    Raises:
        ValueError: if the basename does not start with a known prefix
    """
    for i in range(NUM_SAMPLES):
        new_path = f"{path}_{i}.npy"
        fn = os.path.basename(new_path)

        # Drawn for every file, including the ones that do not use it, so the
        # sequence of values pulled from the seeded RNG stays the same.
        # np.random.randint excludes the upper bound.
        t = np.random.randint(1, MAX_NUM_TIME_STEPS)

        # prefix -> (exclusive upper bound, array shape, stored dtype)
        specs = {
            "S2": (256, (t, 10, SIZE, SIZE), np.int16),
            "S1A": (256, (t, 3, SIZE, SIZE), np.float16),
            "S1D": (256, (t, 3, SIZE, SIZE), np.float16),
            "TARGET": (20, (3, SIZE, SIZE), np.uint8),
            "INSTANCES": (100, (SIZE, SIZE), np.int64),
        }
        for prefix, (high, shape, dtype) in specs.items():
            if fn.startswith(prefix):
                break
        else:
            # Previously this fell through to np.save with `data` unbound.
            raise ValueError(
                f"Unrecognized filename prefix for {fn!r}; "
                f"expected one of {sorted(specs)}"
            )

        data = np.random.randint(0, high, size=shape).astype(dtype)
        np.save(new_path, data)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    """Materialize ``hierarchy`` on disk underneath ``directory``.

    Args:
        directory: directory in which to create the entries
        hierarchy: either a dict mapping sub-directory names to nested
            hierarchies, or a list of filename stems handed to ``create_file``
    """
    if not isinstance(hierarchy, dict):
        # Leaf level: every entry is a filename stem inside ``directory``.
        for stem in hierarchy:
            create_file(os.path.join(directory, stem))
        return

    # Directory level: create each sub-directory, then descend into it.
    for name, children in hierarchy.items():
        subdir = os.path.join(directory, name)
        os.makedirs(subdir, exist_ok=True)
        create_directory(subdir, children)


if __name__ == "__main__":
    # 1. Random arrays for every modality / annotation directory.
    root = "PASTIS-R"
    create_directory(root, filenames)

    # 2. Metadata: one feature per sample, all sharing the same unit-square
    #    polygon, with the fold assigned as the sample index modulo 5.
    schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}}
    features = [
        {
            "geometry": {
                "type": "Polygon",
                "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]],
            },
            "id": str(idx),
            "properties": {"Fold": idx % 5, "ID_PATCH": idx},
        }
        for idx in range(NUM_SAMPLES)
    ]
    with fiona.open(
        os.path.join(root, "metadata.geojson"),
        "w",
        "GeoJSON",
        crs="EPSG:4326",
        schema=schema,
    ) as dst:
        for feature in features:
            dst.write(feature)

    # 3. Zip the generated directory next to this script.
    filename = "PASTIS-R.zip"
    base_name, _ = os.path.splitext(filename)
    shutil.make_archive(base_name, "zip", ".", root)

    # 4. Print the archive's MD5 checksum.
    with open(filename, "rb") as archive:
        digest = hashlib.md5(archive.read()).hexdigest()
    print(f"{filename}: {digest}")
110 changes: 110 additions & 0 deletions tests/datasets/test_pastis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import PASTIS


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Stand-in for the real downloader: copy a local file into ``root``.

    Args:
        url: path of the local file to copy
        root: destination passed to ``shutil.copy``
        *args: accepted and ignored
        **kwargs: accepted and ignored
    """
    shutil.copy(src=url, dst=root)


class TestPASTIS:
    """Tests for the PASTIS dataset, run against tests/data/pastis/PASTIS-R.zip."""

    @pytest.fixture(
        params=[
            {"folds": (0, 1), "bands": "s2", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1a", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1d", "mode": "instance"},
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> PASTIS:
        """Build a PASTIS dataset in ``tmp_path`` for each parameter set.

        The download helper, archive URL and expected MD5 are patched so the
        dataset is populated from the local test archive.
        """
        monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url)
        monkeypatch.setattr(PASTIS, "md5", "9b11ae132623a0d13f7f0775d2003703")
        monkeypatch.setattr(
            PASTIS, "url", os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        )

        config = request.param
        return PASTIS(
            str(tmp_path),
            config["folds"],
            config["bands"],
            config["mode"],
            nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_getitem_semantic(self, dataset: PASTIS) -> None:
        """A sample is a dict holding tensor ``image`` and ``mask`` entries."""
        sample = dataset[0]
        assert isinstance(sample, dict)
        assert isinstance(sample["image"], torch.Tensor)
        assert isinstance(sample["mask"], torch.Tensor)

    def test_getitem_instance(self, dataset: PASTIS) -> None:
        """With mode forced to "instance", boxes and labels are returned too."""
        dataset.mode = "instance"
        sample = dataset[0]
        assert isinstance(sample, dict)
        for key in ("image", "mask", "boxes", "label"):
            assert isinstance(sample[key], torch.Tensor)

    def test_len(self, dataset: PASTIS) -> None:
        """The fixture's fold selection yields two samples."""
        assert len(dataset) == 2

    def test_add(self, dataset: PASTIS) -> None:
        """Adding two datasets produces a ConcatDataset of combined length."""
        combined = dataset + dataset
        assert isinstance(combined, ConcatDataset)
        assert len(combined) == 4

    def test_already_extracted(self, dataset: PASTIS) -> None:
        """Constructing again on a root prepared by the fixture succeeds."""
        PASTIS(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        """An archive already present in root is used without ``download``."""
        archive = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        target = str(tmp_path)
        shutil.copy(archive, target)
        PASTIS(target)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without ``download`` raises RuntimeError."""
        with pytest.raises(RuntimeError, match="Dataset not found"):
            PASTIS(str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        """A bogus archive fails the checksum and raises RuntimeError."""
        bogus = os.path.join(tmp_path, "PASTIS-R.zip")
        with open(bogus, "w") as fh:
            fh.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            PASTIS(root=str(tmp_path), checksum=True)

    def test_invalid_fold(self) -> None:
        """Fold 6 is rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            PASTIS(folds=(6,))

    def test_invalid_mode(self) -> None:
        """An unknown mode is rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            PASTIS(mode="invalid")

    def test_plot(self, dataset: PASTIS) -> None:
        """Plotting works with a suptitle, without titles, and with predictions."""
        sample = dataset[0].copy()

        dataset.plot(sample, suptitle="Test")
        plt.close()

        dataset.plot(sample, show_titles=False)
        plt.close()

        # Reuse the targets as predictions to exercise the prediction panel.
        sample["prediction"] = sample["mask"].clone()
        if dataset.mode == "instance":
            sample["prediction_labels"] = sample["label"].clone()
        dataset.plot(sample)
        plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from .nlcd import NLCD
from .openbuildings import OpenBuildings
from .oscd import OSCD
from .pastis import PASTIS
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .reforestree import ReforesTree
Expand Down Expand Up @@ -194,6 +195,7 @@
"MillionAID",
"NASAMarineDebris",
"OSCD",
"PASTIS",
"PatternNet",
"Potsdam2D",
"RESISC45",
Expand Down
Loading