PASTIS dataset #315
Merged

Changes shown below are from 39 of the PR's 44 commits:
0dac108  draft (isaaccorley)
ef46dd3  add dataset to __init__ (isaaccorley)
e9935c3  reorganize datasets and datamodules (isaaccorley)
ebb4252  fix mypy errors (isaaccorley)
7718690  draft (isaaccorley)
7d3ecb1  add dataset to __init__ (isaaccorley)
3a0ce52  reorganize datasets and datamodules (isaaccorley)
028d2e1  fix mypy errors (isaaccorley)
2bff295  Merge branch 'datasets/pastis-r' of github.com:isaaccorley/torchgeo i… (isaaccorley)
f54073e  refactor (isaaccorley)
0c2e635  Merge branch 'main' into datasets/pastis-r (isaaccorley)
a751e5c  Merge branch 'main' into datasets/pastis-r (isaaccorley)
472b914  Merge branch 'main' into datasets/pastis-r (isaaccorley)
f97fb16  Merge branch 'main' into datasets/pastis-r (isaaccorley)
f613c6b  Merge branch 'main' into datasets/pastis-r (isaaccorley)
99c277e  Adding docs (calebrob6)
7ada35e  Merge branch 'main' into datasets/pastis-r (calebrob6)
7116785  Adding plotting, cleaning up some stuff (calebrob6)
42c3217  Black and isort (calebrob6)
09b58b5  Fix the datamodule import (calebrob6)
20c5932  Pyupgrade (calebrob6)
60f33fe  Fixing some docstrings (calebrob6)
ee50b0e  Flake8 (calebrob6)
c71074b  Isort (calebrob6)
012f53b  Fix docstrings in datamodules (calebrob6)
b4dbdf6  Fixing fns and docstring (calebrob6)
cacfc62  Trying to fix the docs (calebrob6)
2924fbe  Trying to fix docs (calebrob6)
c02ab27  Adding tests (calebrob6)
c7923c3  Black (calebrob6)
2695665  newline (calebrob6)
24047d3  Made the test dataset larger (calebrob6)
92381b3  Remove the datamodules (calebrob6)
a4c3294  Update docs/api/non_geo_datasets.csv (calebrob6)
c4b51a7  Update torchgeo/datasets/pastis.py (calebrob6)
0312f6e  Update torchgeo/datasets/pastis.py (calebrob6)
c4df0b7  Update torchgeo/datasets/pastis.py (calebrob6)
0c6b51c  Updating cmap (calebrob6)
31e7851  Describe the different band combinations (calebrob6)
85f98e5  Merging datasets (calebrob6)
5cd1efb  Handle the instance segmentation case in plotting (calebrob6)
7dfb4ce  Update torchgeo/datasets/pastis.py (calebrob6)
09b7ac4  Made some code prettier (calebrob6)
5c3c42c  Adding instance plotting (calebrob6)
Binary file not shown.
@@ -0,0 +1,91 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
from typing import Union

import fiona
import numpy as np

SIZE = 32
NUM_SAMPLES = 5
MAX_NUM_TIME_STEPS = 10
np.random.seed(0)

FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]

filenames: FILENAME_HIERARCHY = {
    "DATA_S2": ["S2"],
    "DATA_S1A": ["S1A"],
    "DATA_S1D": ["S1D"],
    "ANNOTATIONS": ["TARGET"],
    "INSTANCE_ANNOTATIONS": ["INSTANCES"],
}


def create_file(path: str) -> None:
    for i in range(NUM_SAMPLES):
        new_path = f"{path}_{i}.npy"
        fn = os.path.basename(new_path)
        t = np.random.randint(1, MAX_NUM_TIME_STEPS)
        if fn.startswith("S2"):
            data = np.random.randint(0, 256, size=(t, 10, SIZE, SIZE)).astype(np.int16)
        elif fn.startswith("S1A"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("S1D"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("TARGET"):
            data = np.random.randint(0, 20, size=(3, SIZE, SIZE)).astype(np.uint8)
        elif fn.startswith("INSTANCES"):
            data = np.random.randint(0, 100, size=(SIZE, SIZE)).astype(np.int64)
        np.save(new_path, data)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    if isinstance(hierarchy, dict):
        # Recursive case
        for key, value in hierarchy.items():
            path = os.path.join(directory, key)
            os.makedirs(path, exist_ok=True)
            create_directory(path, value)
    else:
        # Base case
        for value in hierarchy:
            path = os.path.join(directory, value)
            create_file(path)


if __name__ == "__main__":
    create_directory("PASTIS-R", filenames)

    schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}}
    with fiona.open(
        os.path.join("PASTIS-R", "metadata.geojson"),
        "w",
        "GeoJSON",
        crs="EPSG:4326",
        schema=schema,
    ) as f:
        for i in range(NUM_SAMPLES):
            f.write(
                {
                    "geometry": {
                        "type": "Polygon",
                        "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]],
                    },
                    "id": str(i),
                    "properties": {"Fold": i % 5, "ID_PATCH": i},
                }
            )

    filename = "PASTIS-R.zip"
    shutil.make_archive(filename.replace(".zip", ""), "zip", ".", "PASTIS-R")

    # Compute checksums
    with open(filename, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f"{filename}: {md5}")
@@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import (
    PASTIS,
    PASTISInstanceSegmentation,
    PASTISSemanticSegmentation,
)


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestPASTIS:
    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> PASTIS:
        monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url)

        md5 = "9b11ae132623a0d13f7f0775d2003703"
        monkeypatch.setattr(PASTIS, "md5", md5)
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        monkeypatch.setattr(PASTIS, "url", url)
        root = str(tmp_path)
        transforms = nn.Identity()
        return PASTIS(root, (0, 1), "s2", transforms, download=True, checksum=True)

    def test_getitem_not_implemented(self, dataset: PASTIS) -> None:
        with pytest.raises(NotImplementedError):
            dataset[0]

    def test_load_target_not_implemented(self, dataset: PASTIS) -> None:
        with pytest.raises(NotImplementedError):
            dataset._load_target(0)


class TestPASTISSemanticSegmentation:
    @pytest.fixture(
        params=[
            {"folds": (0, 1), "bands": "s2"},
            {"folds": (0, 1), "bands": "s1a"},
            {"folds": (0, 1), "bands": "s1d"},
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> PASTISSemanticSegmentation:
        monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url)

        md5 = "2084aaa69ec55da5ddb0be69e1e941fe"
        monkeypatch.setattr(PASTIS, "md5", md5)
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        monkeypatch.setattr(PASTIS, "url", url)
        root = str(tmp_path)
        folds = request.param["folds"]
        bands = request.param["bands"]
        transforms = nn.Identity()
        return PASTISSemanticSegmentation(
            root, folds, bands, transforms, download=True, checksum=True
        )

    def test_getitem(self, dataset: PASTISSemanticSegmentation) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)

    def test_len(self, dataset: PASTISSemanticSegmentation) -> None:
        assert len(dataset) == 2

    def test_add(self, dataset: PASTISSemanticSegmentation) -> None:
        ds = dataset + dataset
        assert isinstance(ds, ConcatDataset)
        assert len(ds) == 4

    def test_already_extracted(self, dataset: PASTISSemanticSegmentation) -> None:
        PASTISSemanticSegmentation(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        root = str(tmp_path)
        shutil.copy(url, root)
        PASTISSemanticSegmentation(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            PASTISSemanticSegmentation(str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        with open(os.path.join(tmp_path, "PASTIS-R.zip"), "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            PASTISSemanticSegmentation(root=str(tmp_path), checksum=True)

    def test_invalid_fold(self) -> None:
        with pytest.raises(AssertionError):
            PASTISSemanticSegmentation(folds=(6,))

    def test_plot(self, dataset: PASTISSemanticSegmentation) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x["prediction"] = x["mask"].clone()
        dataset.plot(x)
        plt.close()


class TestPASTISInstanceSegmentation:
    @pytest.fixture(
        params=[
            {"folds": (0, 1), "bands": "s2"},
            {"folds": (0, 1), "bands": "s1a"},
            {"folds": (0, 1), "bands": "s1d"},
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> PASTISInstanceSegmentation:
        monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url)

        md5 = "9b11ae132623a0d13f7f0775d2003703"
        monkeypatch.setattr(PASTIS, "md5", md5)
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        monkeypatch.setattr(PASTIS, "url", url)
        root = str(tmp_path)
        folds = request.param["folds"]
        bands = request.param["bands"]
        transforms = nn.Identity()
        return PASTISInstanceSegmentation(
            root, folds, bands, transforms, download=True, checksum=True
        )

    def test_getitem(self, dataset: PASTISInstanceSegmentation) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        assert isinstance(x["boxes"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
Review comment: If we do eventually add a datamodule (I see the PR), then we'll also want to include the uncompressed files so that the tests don't extract the zip file and create untracked git files.
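A hedged sketch of what that suggestion could look like: if the uncompressed PASTIS-R/ directory produced by the generator script is committed next to PASTIS-R.zip, a future datamodule (or its tests) could point root straight at the fixture directory, so the dataset finds the already-extracted files and never unpacks the archive into the working tree. The folds/bands keyword names below mirror the fixtures above; treat the exact call as illustrative.

import os

from torchgeo.datasets import PASTISSemanticSegmentation

# Uses the already-extracted test fixture, so nothing is downloaded or
# extracted and no untracked files appear under tests/data/pastis.
root = os.path.join("tests", "data", "pastis")
ds = PASTISSemanticSegmentation(root, folds=(0, 1), bands="s2")
print(len(ds))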