Skip to content

Commit

Permalink
Move SKIPPD to HF and add forecast task (#1548)
Browse files Browse the repository at this point in the history
* move to hf and adapt tests

* comments

* requested changes

* dim typo

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
nilsleh and adamjstewart authored Sep 24, 2023
1 parent b796bbd commit 242fa90
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 65 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
82 changes: 50 additions & 32 deletions tests/data/skippd/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
# Licensed under the MIT License.

import hashlib
import os
import shutil
import zipfile
from datetime import datetime, timedelta

import h5py
Expand All @@ -17,43 +16,62 @@
NUM_SAMPLES = 3
NUM_CHANNELS = 3
SIZE = 64
TIME_STEPS = 16

np.random.seed(0)

data_dir = "dj417rh1007"
data_file = "2017_2019_images_pv_processed.hdf5"
tasks = ["nowcast", "forecast"]
data_file = "2017_2019_images_pv_processed_{}.hdf5"
splits = ["trainval", "test"]


# Create dataset file
data = np.random.randint(
RGB_MAX, size=(NUM_SAMPLES, SIZE, SIZE, NUM_CHANNELS), dtype=np.int16
)
labels = np.random.random(size=(NUM_SAMPLES))

if __name__ == "__main__":
# Remove old data
if os.path.exists(data_dir):
shutil.rmtree(data_dir)
data = {
"nowcast": np.random.randint(
RGB_MAX, size=(NUM_SAMPLES, SIZE, SIZE, NUM_CHANNELS), dtype=np.int16
),
"forecast": np.random.randint(
RGB_MAX,
size=(NUM_SAMPLES, TIME_STEPS, SIZE, SIZE, NUM_CHANNELS),
dtype=np.int16,
),
}


labels = {
"nowcast": np.random.random(size=(NUM_SAMPLES)),
"forecast": np.random.random(size=(NUM_SAMPLES, TIME_STEPS)),
}

os.makedirs(data_dir)

with h5py.File(os.path.join(data_dir, data_file), "w") as f:
if __name__ == "__main__":
for task in tasks:
with h5py.File(data_file.format(task), "w") as f:
for split in splits:
grp = f.create_group(split)
grp.create_dataset("images_log", data=data[task])
grp.create_dataset("pv_log", data=labels[task])

# create time stamps
for split in splits:
grp = f.create_group(split)
grp.create_dataset("images_log", data=data)
grp.create_dataset("pv_log", data=labels)

# create time stamps
for split in splits:
time_stamps = np.array(
[datetime.now() - timedelta(days=i) for i in range(NUM_SAMPLES)]
)
np.save(os.path.join(data_dir, f"times_{split}.npy"), time_stamps)

# Compress data
shutil.make_archive(data_dir, "zip", ".", data_dir)

# Compute checksums
with open(data_dir + ".zip", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{data_dir}.zip: {md5}")
time_stamps = np.array(
[datetime.now() - timedelta(days=i) for i in range(NUM_SAMPLES)]
)
np.save(f"times_{split}_{task}.npy", time_stamps)

# Compress data
with zipfile.ZipFile(
data_file.format(task).replace(".hdf5", ".zip"), "w"
) as zip:
for file in [
data_file.format(task),
f"times_trainval_{task}.npy",
f"times_test_{task}.npy",
]:
zip.write(file, arcname=file)

# Compute checksums
with open(data_file.format(task).replace(".hdf5", ".zip"), "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{task}: {md5}")
Binary file removed tests/data/skippd/dj417rh1007.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/skippd/times_trainval_forecast.npy
Binary file not shown.
Binary file added tests/data/skippd/times_trainval_nowcast.npy
Binary file not shown.
41 changes: 31 additions & 10 deletions tests/datasets/test_skippd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import builtins
import os
import shutil
from itertools import product
from pathlib import Path
from typing import Any

Expand All @@ -23,21 +24,32 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:


class TestSKIPPD:
@pytest.fixture(params=["trainval", "test"])
@pytest.fixture(params=product(["nowcast", "forecast"], ["trainval", "test"]))
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> SKIPPD:
task, split = request.param

monkeypatch.setattr(torchgeo.datasets.skippd, "download_url", download_url)

md5 = "1133b2de453a9674776abd7519af5051"
md5 = {
"nowcast": "6f5e54906927278b189f9281a2f54f39",
"forecast": "f3b5d7d5c28ba238144fa1e726c46969",
}
monkeypatch.setattr(SKIPPD, "md5", md5)
url = os.path.join("tests", "data", "skippd", "dj417rh1007.zip")
url = os.path.join("tests", "data", "skippd", "{}")
monkeypatch.setattr(SKIPPD, "url", url)
monkeypatch.setattr(plt, "show", lambda *args: None)
root = str(tmp_path)
split = request.param
transforms = nn.Identity()
return SKIPPD(root, split, transforms, download=True, checksum=True)
return SKIPPD(
root=root,
task=task,
split=split,
transforms=transforms,
download=True,
checksum=True,
)

@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
Expand All @@ -62,11 +74,14 @@ def test_mock_missing_module(
def test_already_extracted(self, dataset: SKIPPD) -> None:
SKIPPD(root=dataset.root, download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "skippd", "dj417rh1007.zip")
@pytest.mark.parametrize("task", ["nowcast", "forecast"])
def test_already_downloaded(self, tmp_path: Path, task: str) -> None:
pathname = os.path.join(
"tests", "data", "skippd", f"2017_2019_images_pv_processed_{task}.zip"
)
root = str(tmp_path)
shutil.copy(pathname, root)
SKIPPD(root)
SKIPPD(root=root, task=task)

@pytest.mark.parametrize("index", [0, 1, 2])
def test_getitem(self, dataset: SKIPPD, index: int) -> None:
Expand All @@ -75,7 +90,10 @@ def test_getitem(self, dataset: SKIPPD, index: int) -> None:
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert isinstance(x["date"], str)
assert x["image"].shape == (3, 64, 64)
if dataset.task == "nowcast":
assert x["image"].shape == (3, 64, 64)
else:
assert x["image"].shape == (48, 64, 64)

def test_len(self, dataset: SKIPPD) -> None:
assert len(dataset) == 3
Expand All @@ -93,6 +111,9 @@ def test_plot(self, dataset: SKIPPD) -> None:
plt.close()

sample = dataset[0]
sample["prediction"] = sample["label"]
if dataset.task == "nowcast":
sample["prediction"] = sample["label"]
else:
sample["prediction"] = sample["label"][-1]
dataset.plot(sample)
plt.close()
81 changes: 58 additions & 23 deletions torchgeo/datasets/skippd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from matplotlib.figure import Figure
from torch import Tensor

Expand All @@ -33,31 +34,46 @@ class SKIPPD(NonGeoDataset):
* fish-eye RGB images (64x64px)
* power output measurements from 30-kW rooftop PV array
* 1-min interval across 3 years (2017-2019)
Nowcast task:
* 349,372 images under the split key *trainval*
* 14,003 images under the split key *test*
Forecast task:
* 130,412 images under the split key *trainval*
* 2,462 images under the split key *test*
* consists of a concatenated RGB time-series of 16
time-steps
If you use this dataset in your research, please cite:
* https://doi.org/10.48550/arXiv.2207.00913
.. versionadded:: 0.5
"""

url = "https://stacks.stanford.edu/object/dj417rh1007"
md5 = "b38d0f322aaeb254445e2edd8bc5d012"

img_file_name = "2017_2019_images_pv_processed.hdf5"
url = "https://huggingface.co/datasets/torchgeo/skippd/resolve/main/{}"
md5 = {
"forecast": "f4f3509ddcc83a55c433be9db2e51077",
"nowcast": "0000761d403e45bb5f86c21d3c69aa80",
}

data_dir = "dj417rh1007"
data_file_name = "2017_2019_images_pv_processed_{}.hdf5"
zipfile_name = "2017_2019_images_pv_processed_{}.zip"

valid_splits = ["trainval", "test"]

valid_tasks = ["nowcast", "forecast"]

dateformat = "%m/%d/%Y, %H:%M:%S"

def __init__(
self,
root: str = "data",
split: str = "trainval",
task: str = "nowcast",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
download: bool = False,
checksum: bool = False,
Expand All @@ -67,21 +83,27 @@ def __init__(
Args:
root: root directory where dataset can be found
split: one of "trainval", or "test"
task: one of "nowcast", or "forecast"
transforms: a function/transform that takes an input sample
and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 after downloading files (may be slow)
Raises:
AssertionError: if ``countries`` contains invalid countries
AssertionError: if ``task`` or ``split`` is invalid
ImportError: if h5py is not installed
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
assert (
split in self.valid_splits
), f"Pleas choose one of these valid data splits {self.valid_splits}."
), f"Please choose one of these valid data splits {self.valid_splits}."
self.split = split

assert (
task in self.valid_tasks
), f"Please choose one of these valid tasks {self.valid_tasks}."
self.task = task

self.root = root
self.transforms = transforms
self.download = download
Expand All @@ -105,7 +127,7 @@ def __len__(self) -> int:
import h5py

with h5py.File(
os.path.join(self.root, self.data_dir, self.img_file_name), "r"
os.path.join(self.root, self.data_file_name.format(self.task)), "r"
) as f:
num_datapoints: int = f[self.split]["pv_log"].shape[0]

Expand Down Expand Up @@ -140,12 +162,18 @@ def _load_image(self, index: int) -> Tensor:
import h5py

with h5py.File(
os.path.join(self.root, self.data_dir, self.img_file_name), "r"
os.path.join(self.root, self.data_file_name.format(self.task)), "r"
) as f:
arr = f[self.split]["images_log"][index, :, :, :]
arr = f[self.split]["images_log"][index]

# forecast has dimension [16, 64, 64, 3] but reshape to [48, 64, 64]
# https://github.com/yuhao-nie/Stanford-solar-forecasting-dataset/blob/main/models/SUNSET_forecast.ipynb
if self.task == "forecast":
arr = rearrange(arr, "t h w c-> (t c) h w")
else:
arr = rearrange(arr, "h w c -> c h w")

# put channel first
tensor = torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)
tensor = torch.from_numpy(arr).to(torch.float32)
return tensor

def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
Expand All @@ -160,14 +188,13 @@ def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
import h5py

with h5py.File(
os.path.join(self.root, self.data_dir, self.img_file_name), "r"
os.path.join(self.root, self.data_file_name.format(self.task)), "r"
) as f:
label = f[self.split]["pv_log"][index]

path = os.path.join(self.root, self.data_dir, f"times_{self.split}.npy")
path = os.path.join(self.root, f"times_{self.split}_{self.task}.npy")
datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat)

# put channel first
features: dict[str, Union[str, Tensor]] = {
"label": torch.tensor(label, dtype=torch.float32),
"date": datestring,
Expand All @@ -181,12 +208,12 @@ def _verify(self) -> None:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, self.data_dir)
pathname = os.path.join(self.root, self.data_file_name.format(self.task))
if os.path.exists(pathname):
return

# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.data_dir) + ".zip"
pathname = os.path.join(self.root, self.zipfile_name.format(self.task))
if os.path.exists(pathname):
self._extract()
return
Expand All @@ -210,16 +237,16 @@ def _download(self) -> None:
RuntimeError: if download doesn't work correctly or checksums don't match
"""
download_url(
self.url,
self.url.format(self.zipfile_name.format(self.task)),
self.root,
filename=self.data_dir,
md5=self.md5 if self.checksum else None,
filename=self.zipfile_name.format(self.task),
md5=self.md5[self.task] if self.checksum else None,
)
self._extract()

def _extract(self) -> None:
"""Extract the dataset."""
zipfile_path = os.path.join(self.root, self.data_dir) + ".zip"
zipfile_path = os.path.join(self.root, self.zipfile_name.format(self.task))
extract_archive(zipfile_path, self.root)

def plot(
Expand All @@ -230,6 +257,8 @@ def plot(
) -> Figure:
"""Plot a sample from the dataset.
In the ``forecast`` task the latest image is plotted.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
Expand All @@ -238,15 +267,21 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
"""
image, label = sample["image"], sample["label"].item()
if self.task == "nowcast":
image, label = sample["image"].permute(1, 2, 0), sample["label"].item()
else:
image, label = (
sample["image"].permute(1, 2, 0).reshape(64, 64, 3, 16)[:, :, :, -1],
sample["label"][-1].item(),
)

showing_predictions = "prediction" in sample
if showing_predictions:
prediction = sample["prediction"].item()

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

ax.imshow(image.permute(1, 2, 0) / 255)
ax.imshow(image / 255)
ax.axis("off")

if show_titles:
Expand Down

0 comments on commit 242fa90

Please sign in to comment.