From c040e6eb11b659e890805df3dbb3d5caf4510e4a Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 16 Apr 2024 18:38:00 +0200 Subject: [PATCH] Ruff: prefer single quotes over double quotes --- docs/conf.py | 112 ++--- docs/tutorials/custom_raster_dataset.ipynb | 36 +- docs/tutorials/getting_started.ipynb | 18 +- docs/tutorials/indices.ipynb | 16 +- docs/tutorials/pretrained_weights.ipynb | 14 +- docs/tutorials/trainers.ipynb | 16 +- docs/tutorials/transforms.ipynb | 62 +-- experiments/ssl4eo/class_imbalance.py | 26 +- experiments/ssl4eo/compress_dataset.py | 26 +- experiments/ssl4eo/compute_dataset_pca.py | 12 +- .../ssl4eo/compute_dataset_statistics.py | 20 +- experiments/ssl4eo/delete_excess.py | 10 +- experiments/ssl4eo/delete_mismatch.py | 30 +- experiments/ssl4eo/download_ssl4eo.py | 162 +++---- experiments/ssl4eo/flops.py | 6 +- .../ssl4eo/landsat/chip_landsat_benchmark.py | 32 +- .../ssl4eo/landsat/plot_landsat_bands.py | 74 +-- .../ssl4eo/landsat/plot_landsat_timeline.py | 52 +-- .../ssl4eo/plot_example_predictions.py | 42 +- experiments/ssl4eo/sample_conus.py | 44 +- experiments/ssl4eo/sample_ssl4eo.py | 42 +- experiments/torchgeo/benchmark.py | 160 +++---- .../torchgeo/find_optimal_hyperparams.py | 16 +- experiments/torchgeo/plot_bar_chart.py | 46 +- .../torchgeo/plot_dataloader_benchmark.py | 30 +- .../torchgeo/plot_percentage_benchmark.py | 42 +- .../torchgeo/run_benchmarks_experiments.py | 32 +- .../run_chesapeake_cvpr_experiments.py | 56 +-- .../torchgeo/run_chesapeakecvpr_models.py | 94 ++-- experiments/torchgeo/run_cowc_experiments.py | 42 +- .../torchgeo/run_cowc_seed_experiments.py | 44 +- .../torchgeo/run_landcoverai_experiments.py | 52 +-- .../run_landcoverai_seed_experiments.py | 52 +-- .../torchgeo/run_resisc45_experiments.py | 48 +- .../torchgeo/run_so2sat_byol_experiments.py | 50 +-- .../torchgeo/run_so2sat_experiments.py | 48 +- .../torchgeo/run_so2sat_seed_experiments.py | 50 +-- hubconf.py | 4 +- pyproject.toml | 1 + tests/data/agb_live_woody_density/data.py | 70 +-- tests/data/agrifieldnet/data.py | 64 +-- tests/data/airphen/data.py | 20 +- tests/data/astergdem/data.py | 40 +- tests/data/biomassters/data.py | 114 ++--- tests/data/cbf/data.py | 36 +- tests/data/cdl/data.py | 40 +- tests/data/chabud/data.py | 50 +-- tests/data/chesapeake/BAYWIDE/data.py | 28 +- tests/data/cms_mangrove_canopy/data.py | 44 +- tests/data/cowc_counting/data.py | 50 +-- tests/data/cowc_detection/data.py | 50 +-- tests/data/cropharvest/data.py | 136 +++--- tests/data/deepglobelandcover/data.py | 26 +- tests/data/dfc2022/data.py | 84 ++-- tests/data/eddmaps/data.py | 166 +++---- tests/data/enviroatlas/data.py | 388 ++++++++-------- tests/data/esri2020/data.py | 40 +- tests/data/etci2021/data.py | 64 +-- tests/data/eudem/data.py | 40 +- tests/data/eurocrops/data.py | 42 +- tests/data/fire_risk/data.py | 70 +-- tests/data/forestdamage/data.py | 58 +-- tests/data/gbif/data.py | 104 ++--- tests/data/globbiomass/data.py | 38 +- tests/data/inaturalist/data.py | 80 ++-- tests/data/inria/data.py | 28 +- tests/data/l7irish/data.py | 56 +-- tests/data/l8biome/data.py | 78 ++-- tests/data/landcoverai/data.py | 62 +-- tests/data/landcoverai/split.py | 10 +- tests/data/levircd/levircd/data.py | 24 +- tests/data/levircd/levircdplus/data.py | 30 +- tests/data/mapinwild/data.py | 144 +++--- tests/data/millionaid/data.py | 22 +- tests/data/naip/data.py | 34 +- tests/data/nccm/data.py | 30 +- tests/data/nlcd/data.py | 36 +- tests/data/openbuildings/data.py | 50 +-- tests/data/oscd/data.py | 62 +-- 
tests/data/pastis/data.py | 56 +-- tests/data/prisma/data.py | 28 +- tests/data/raster/data.py | 48 +- .../data.py | 102 ++--- tests/data/reforestree/data.py | 34 +- tests/data/rwanda_field_boundary/data.py | 94 ++-- tests/data/seasonet/data.py | 96 ++-- tests/data/sen12ms/data.py | 54 +-- tests/data/sentinel1/data.py | 36 +- tests/data/sentinel2/data.py | 258 +++++------ tests/data/skippd/data.py | 34 +- tests/data/so2sat/data.py | 22 +- tests/data/south_africa_crop_type/data.py | 62 +-- tests/data/south_america_soybean/data.py | 34 +- tests/data/spacenet/data.py | 160 +++---- tests/data/ssl4eo/l/data.py | 162 +++---- tests/data/ssl4eo/s12/data.py | 154 +++---- tests/data/ssl4eo_benchmark_landsat/data.py | 190 ++++---- tests/data/sustainbench_crop_yield/data.py | 26 +- tests/data/usavars/data.py | 82 ++-- tests/data/vector/data.py | 38 +- tests/data/vhr10/data.py | 52 +-- .../western_usa_live_fuel_moisture/data.py | 352 +++++++-------- tests/data/zuericrop/data.py | 12 +- tests/datamodules/test_chesapeake.py | 10 +- tests/datamodules/test_fair1m.py | 16 +- tests/datamodules/test_geo.py | 82 ++-- tests/datamodules/test_levircd.py | 140 +++--- tests/datamodules/test_oscd.py | 58 +-- tests/datamodules/test_usavars.py | 18 +- tests/datamodules/test_utils.py | 10 +- tests/datamodules/test_xview2.py | 10 +- tests/datasets/test_advance.py | 46 +- tests/datasets/test_agb_live_woody_density.py | 24 +- tests/datasets/test_agrifieldnet.py | 24 +- tests/datasets/test_airphen.py | 18 +- tests/datasets/test_astergdem.py | 18 +- tests/datasets/test_benin_cashews.py | 46 +- tests/datasets/test_bigearthnet.py | 104 ++--- tests/datasets/test_biomassters.py | 16 +- tests/datasets/test_cbf.py | 26 +- tests/datasets/test_cdl.py | 36 +- tests/datasets/test_chabud.py | 48 +- tests/datasets/test_chesapeake.py | 110 ++--- tests/datasets/test_cloud_cover.py | 52 +-- tests/datasets/test_cms_mangrove_canopy.py | 46 +- tests/datasets/test_cowc.py | 76 ++-- tests/datasets/test_cropharvest.py | 40 +- tests/datasets/test_cv4a_kenya_crop_type.py | 56 +-- tests/datasets/test_cyclone.py | 50 +-- tests/datasets/test_deepglobelandcover.py | 30 +- tests/datasets/test_dfc2022.py | 56 +-- tests/datasets/test_eddmaps.py | 8 +- tests/datasets/test_enviroatlas.py | 38 +- tests/datasets/test_esri2020.py | 46 +- tests/datasets/test_etci2021.py | 60 +-- tests/datasets/test_eudem.py | 30 +- tests/datasets/test_eurocrops.py | 32 +- tests/datasets/test_eurosat.py | 50 +-- tests/datasets/test_fair1m.py | 86 ++-- tests/datasets/test_fire_risk.py | 24 +- tests/datasets/test_forestdamage.py | 38 +- tests/datasets/test_gbif.py | 8 +- tests/datasets/test_geo.py | 282 ++++++------ tests/datasets/test_gid15.py | 38 +- tests/datasets/test_globbiomass.py | 30 +- tests/datasets/test_idtrees.py | 90 ++-- tests/datasets/test_inaturalist.py | 8 +- tests/datasets/test_inria.py | 38 +- tests/datasets/test_l7irish.py | 36 +- tests/datasets/test_l8biome.py | 36 +- tests/datasets/test_landcoverai.py | 62 +-- tests/datasets/test_landsat.py | 20 +- tests/datasets/test_levircd.py | 84 ++-- tests/datasets/test_loveda.py | 52 +-- tests/datasets/test_mapinwild.py | 86 ++-- tests/datasets/test_millionaid.py | 26 +- tests/datasets/test_naip.py | 12 +- tests/datasets/test_nasa_marine_debris.py | 40 +- tests/datasets/test_nccm.py | 32 +- tests/datasets/test_nlcd.py | 36 +- tests/datasets/test_openbuildings.py | 44 +- tests/datasets/test_oscd.py | 84 ++-- tests/datasets/test_pastis.py | 56 +-- tests/datasets/test_patternnet.py | 26 +- tests/datasets/test_potsdam.py 
| 40 +- tests/datasets/test_prisma.py | 12 +- tests/datasets/test_reforestree.py | 42 +- tests/datasets/test_resisc45.py | 44 +- tests/datasets/test_rwanda_field_boundary.py | 80 ++-- tests/datasets/test_seasonet.py | 88 ++-- tests/datasets/test_seco.py | 52 +-- tests/datasets/test_sen12ms.py | 66 +-- tests/datasets/test_sentinel.py | 52 +-- tests/datasets/test_skippd.py | 54 +-- tests/datasets/test_so2sat.py | 42 +- tests/datasets/test_south_africa_crop_type.py | 24 +- tests/datasets/test_south_america_soybean.py | 38 +- tests/datasets/test_spacenet.py | 242 +++++----- tests/datasets/test_splits.py | 70 +-- tests/datasets/test_ssl4eo.py | 84 ++-- tests/datasets/test_ssl4eo_benchmark.py | 70 +-- .../datasets/test_sustainbench_crop_yield.py | 38 +- tests/datasets/test_ucmerced.py | 40 +- tests/datasets/test_usavars.py | 102 ++--- tests/datasets/test_utils.py | 242 +++++----- tests/datasets/test_vaihingen.py | 46 +- tests/datasets/test_vhr10.py | 78 ++-- .../test_western_usa_live_fuel_moisture.py | 38 +- tests/datasets/test_xview2.py | 52 +-- tests/datasets/test_zuericrop.py | 58 +-- tests/models/test_api.py | 6 +- tests/models/test_changestar.py | 46 +- tests/models/test_dofa.py | 22 +- tests/models/test_farseg.py | 14 +- tests/models/test_fcn.py | 2 +- tests/models/test_fcsiam.py | 16 +- tests/models/test_rcf.py | 24 +- tests/models/test_resnet.py | 28 +- tests/models/test_swin.py | 12 +- tests/models/test_vit.py | 14 +- tests/samplers/test_batch.py | 14 +- tests/samplers/test_single.py | 26 +- tests/samplers/test_utils.py | 2 +- tests/test_main.py | 2 +- tests/trainers/conftest.py | 20 +- tests/trainers/test_byol.py | 72 +-- tests/trainers/test_classification.py | 160 +++---- tests/trainers/test_detection.py | 82 ++-- tests/trainers/test_moco.py | 90 ++-- tests/trainers/test_regression.py | 164 +++---- tests/trainers/test_segmentation.py | 146 +++--- tests/trainers/test_simclr.py | 88 ++-- tests/trainers/test_utils.py | 26 +- tests/transforms/test_color.py | 28 +- tests/transforms/test_indices.py | 42 +- tests/transforms/test_transforms.py | 106 ++--- torchgeo/__init__.py | 4 +- torchgeo/datamodules/__init__.py | 94 ++-- torchgeo/datamodules/agrifieldnet.py | 10 +- torchgeo/datamodules/bigearthnet.py | 6 +- torchgeo/datamodules/chabud.py | 12 +- torchgeo/datamodules/chesapeake.py | 34 +- torchgeo/datamodules/cowc.py | 4 +- torchgeo/datamodules/cyclone.py | 10 +- torchgeo/datamodules/deepglobelandcover.py | 10 +- torchgeo/datamodules/etci2021.py | 14 +- torchgeo/datamodules/eurosat.py | 56 +-- torchgeo/datamodules/fair1m.py | 22 +- torchgeo/datamodules/fire_risk.py | 10 +- torchgeo/datamodules/geo.py | 76 ++-- torchgeo/datamodules/gid15.py | 12 +- torchgeo/datamodules/inria.py | 18 +- torchgeo/datamodules/l7irish.py | 10 +- torchgeo/datamodules/l8biome.py | 10 +- torchgeo/datamodules/landcoverai.py | 4 +- torchgeo/datamodules/levircd.py | 20 +- torchgeo/datamodules/loveda.py | 12 +- torchgeo/datamodules/naip.py | 12 +- torchgeo/datamodules/nasa_marine_debris.py | 2 +- torchgeo/datamodules/oscd.py | 64 +-- torchgeo/datamodules/potsdam.py | 10 +- torchgeo/datamodules/resisc45.py | 2 +- torchgeo/datamodules/seco.py | 14 +- torchgeo/datamodules/sen12ms.py | 24 +- torchgeo/datamodules/sentinel2_cdl.py | 16 +- torchgeo/datamodules/sentinel2_eurocrops.py | 16 +- torchgeo/datamodules/sentinel2_nccm.py | 16 +- .../sentinel2_south_america_soybean.py | 16 +- torchgeo/datamodules/skippd.py | 8 +- torchgeo/datamodules/so2sat.py | 52 +-- torchgeo/datamodules/spacenet.py | 6 +- 
torchgeo/datamodules/ssl4eo_benchmark.py | 8 +- .../datamodules/sustainbench_crop_yield.py | 12 +- torchgeo/datamodules/ucmerced.py | 2 +- torchgeo/datamodules/utils.py | 50 +-- torchgeo/datamodules/vaihingen.py | 10 +- torchgeo/datamodules/vhr10.py | 4 +- torchgeo/datamodules/xview.py | 8 +- torchgeo/datasets/__init__.py | 266 +++++------ torchgeo/datasets/advance.py | 72 +-- torchgeo/datasets/agb_live_woody_density.py | 32 +- torchgeo/datasets/agrifieldnet.py | 76 ++-- torchgeo/datasets/airphen.py | 10 +- torchgeo/datasets/astergdem.py | 22 +- torchgeo/datasets/benin_cashews.py | 266 +++++------ torchgeo/datasets/bigearthnet.py | 268 +++++------ torchgeo/datasets/biomassters.py | 86 ++-- torchgeo/datasets/cbf.py | 80 ++-- torchgeo/datasets/cdl.py | 74 +-- torchgeo/datasets/chabud.py | 82 ++-- torchgeo/datasets/chesapeake.py | 320 ++++++------- torchgeo/datasets/cloud_cover.py | 122 ++--- torchgeo/datasets/cms_mangrove_canopy.py | 282 ++++++------ torchgeo/datasets/cowc.py | 104 ++--- torchgeo/datasets/cropharvest.py | 134 +++--- torchgeo/datasets/cv4a_kenya_crop_type.py | 160 +++---- torchgeo/datasets/cyclone.py | 76 ++-- torchgeo/datasets/deepglobelandcover.py | 62 +-- torchgeo/datasets/dfc2022.py | 174 ++++---- torchgeo/datasets/eddmaps.py | 12 +- torchgeo/datasets/enviroatlas.py | 252 +++++------ torchgeo/datasets/esri2020.py | 30 +- torchgeo/datasets/etci2021.py | 110 ++--- torchgeo/datasets/eudem.py | 82 ++-- torchgeo/datasets/eurocrops.py | 92 ++-- torchgeo/datasets/eurosat.py | 100 ++--- torchgeo/datasets/fair1m.py | 230 +++++----- torchgeo/datasets/fire_risk.py | 42 +- torchgeo/datasets/forestdamage.py | 74 +-- torchgeo/datasets/gbif.py | 12 +- torchgeo/datasets/geo.py | 84 ++-- torchgeo/datasets/gid15.py | 98 ++-- torchgeo/datasets/globbiomass.py | 166 +++---- torchgeo/datasets/idtrees.py | 226 +++++----- torchgeo/datasets/inaturalist.py | 16 +- torchgeo/datasets/inria.py | 72 +-- torchgeo/datasets/l7irish.py | 72 +-- torchgeo/datasets/l8biome.py | 66 +-- torchgeo/datasets/landcoverai.py | 84 ++-- torchgeo/datasets/landsat.py | 48 +- torchgeo/datasets/levircd.py | 104 ++--- torchgeo/datasets/loveda.py | 92 ++-- torchgeo/datasets/mapinwild.py | 140 +++--- torchgeo/datasets/millionaid.py | 302 ++++++------- torchgeo/datasets/naip.py | 12 +- torchgeo/datasets/nasa_marine_debris.py | 62 +-- torchgeo/datasets/nccm.py | 50 +-- torchgeo/datasets/nlcd.py | 58 +-- torchgeo/datasets/openbuildings.py | 328 +++++++------- torchgeo/datasets/oscd.py | 114 ++--- torchgeo/datasets/pastis.py | 138 +++--- torchgeo/datasets/patternnet.py | 22 +- torchgeo/datasets/potsdam.py | 132 +++--- torchgeo/datasets/prisma.py | 10 +- torchgeo/datasets/reforestree.py | 54 +-- torchgeo/datasets/resisc45.py | 44 +- torchgeo/datasets/rwanda_field_boundary.py | 94 ++-- torchgeo/datasets/seasonet.py | 202 ++++----- torchgeo/datasets/seco.py | 76 ++-- torchgeo/datasets/sen12ms.py | 208 ++++----- torchgeo/datasets/sentinel.py | 66 +-- torchgeo/datasets/skippd.py | 72 +-- torchgeo/datasets/so2sat.py | 168 +++---- torchgeo/datasets/south_africa_crop_type.py | 90 ++-- torchgeo/datasets/south_america_soybean.py | 70 +-- torchgeo/datasets/spacenet.py | 418 ++++++++--------- torchgeo/datasets/splits.py | 26 +- torchgeo/datasets/ssl4eo.py | 232 +++++----- torchgeo/datasets/ssl4eo_benchmark.py | 136 +++--- torchgeo/datasets/sustainbench_crop_yield.py | 64 +-- torchgeo/datasets/ucmerced.py | 44 +- torchgeo/datasets/usavars.py | 90 ++-- torchgeo/datasets/utils.py | 124 +++--- torchgeo/datasets/vaihingen.py | 122 ++--- 
torchgeo/datasets/vhr10.py | 202 ++++----- .../western_usa_live_fuel_moisture.py | 304 ++++++------- torchgeo/datasets/xview.py | 88 ++-- torchgeo/datasets/zuericrop.py | 54 +-- torchgeo/losses/__init__.py | 2 +- torchgeo/losses/qr.py | 4 +- torchgeo/main.py | 12 +- torchgeo/models/__init__.py | 54 +-- torchgeo/models/api.py | 16 +- torchgeo/models/changestar.py | 34 +- torchgeo/models/dofa.py | 46 +- torchgeo/models/farseg.py | 16 +- torchgeo/models/fcsiam.py | 10 +- torchgeo/models/rcf.py | 24 +- torchgeo/models/resnet.py | 420 +++++++++--------- torchgeo/models/swin.py | 90 ++-- torchgeo/models/vit.py | 176 ++++---- torchgeo/samplers/__init__.py | 18 +- torchgeo/trainers/__init__.py | 20 +- torchgeo/trainers/base.py | 14 +- torchgeo/trainers/byol.py | 16 +- torchgeo/trainers/classification.py | 138 +++--- torchgeo/trainers/detection.py | 124 +++--- torchgeo/trainers/moco.py | 90 ++-- torchgeo/trainers/regression.py | 104 ++--- torchgeo/trainers/segmentation.py | 102 ++--- torchgeo/trainers/simclr.py | 66 +-- torchgeo/trainers/utils.py | 52 +-- torchgeo/transforms/__init__.py | 32 +- torchgeo/transforms/color.py | 4 +- torchgeo/transforms/indices.py | 14 +- torchgeo/transforms/transforms.py | 70 +-- 366 files changed, 12190 insertions(+), 12189 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d92bd228940..36af7d3d275 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,16 +17,16 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath("..")) +sys.path.insert(0, os.path.abspath('..')) import torchgeo # noqa: E402 # -- Project information ----------------------------------------------------- -project = "torchgeo" -copyright = "2021, Microsoft Corporation" +project = 'torchgeo' +copyright = '2021, Microsoft Corporation' author = torchgeo.__author__ -version = ".".join(torchgeo.__version__.split(".")[:2]) +version = '.'.join(torchgeo.__version__.split('.')[:2]) release = torchgeo.__version__ @@ -36,38 +36,38 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.intersphinx", - "sphinx.ext.mathjax", - "sphinx.ext.napoleon", - "sphinx.ext.todo", - "sphinx.ext.viewcode", - "nbsphinx", + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.todo', + 'sphinx.ext.viewcode', + 'nbsphinx', ] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
-exclude_patterns = ["_build"] +exclude_patterns = ['_build'] # Sphinx 3.0+ required for: # autodoc_typehints_description_target = "documented" -needs_sphinx = "4.0" +needs_sphinx = '4.0' nitpicky = True nitpick_ignore = [ # Undocumented classes - ("py:class", "fiona.model.Feature"), - ("py:class", "kornia.augmentation._2d.intensity.base.IntensityAugmentationBase2D"), - ("py:class", "kornia.augmentation.base._AugmentationBase"), - ("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig"), - ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"), - ("py:class", "timm.models.resnet.ResNet"), - ("py:class", "timm.models.vision_transformer.VisionTransformer"), - ("py:class", "torch.optim.lr_scheduler.LRScheduler"), - ("py:class", "torchvision.models._api.WeightsEnum"), - ("py:class", "torchvision.models.resnet.ResNet"), - ("py:class", "torchvision.models.swin_transformer.SwinTransformer"), + ('py:class', 'fiona.model.Feature'), + ('py:class', 'kornia.augmentation._2d.intensity.base.IntensityAugmentationBase2D'), + ('py:class', 'kornia.augmentation.base._AugmentationBase'), + ('py:class', 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig'), + ('py:class', 'segmentation_models_pytorch.base.model.SegmentationModel'), + ('py:class', 'timm.models.resnet.ResNet'), + ('py:class', 'timm.models.vision_transformer.VisionTransformer'), + ('py:class', 'torch.optim.lr_scheduler.LRScheduler'), + ('py:class', 'torchvision.models._api.WeightsEnum'), + ('py:class', 'torchvision.models.resnet.ResNet'), + ('py:class', 'torchvision.models.swin_transformer.SwinTransformer'), ] @@ -75,58 +75,58 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = "pytorch_sphinx_theme" +html_theme = 'pytorch_sphinx_theme' html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. 
html_theme_options = { - "collapse_navigation": False, - "display_version": True, - "logo_only": True, - "pytorch_project": "docs", - "navigation_with_keys": True, - "analytics_id": "UA-209075005-1", + 'collapse_navigation': False, + 'display_version': True, + 'logo_only': True, + 'pytorch_project': 'docs', + 'navigation_with_keys': True, + 'analytics_id': 'UA-209075005-1', } -html_favicon = os.path.join("..", "logo", "favicon.ico") +html_favicon = os.path.join('..', 'logo', 'favicon.ico') -html_static_path = ["_static"] -html_css_files = ["button-width.css", "notebook-prompt.css", "table-scroll.css"] +html_static_path = ['_static'] +html_css_files = ['button-width.css', 'notebook-prompt.css', 'table-scroll.css'] # -- Extension configuration ------------------------------------------------- # sphinx.ext.autodoc autodoc_default_options = { - "members": True, - "special-members": True, - "show-inheritance": True, + 'members': True, + 'special-members': True, + 'show-inheritance': True, } -autodoc_member_order = "bysource" -autodoc_typehints = "description" -autodoc_typehints_description_target = "documented" +autodoc_member_order = 'bysource' +autodoc_typehints = 'description' +autodoc_typehints_description_target = 'documented' # sphinx.ext.intersphinx intersphinx_mapping = { - "kornia": ("https://kornia.readthedocs.io/en/stable/", None), - "matplotlib": ("https://matplotlib.org/stable/", None), - "numpy": ("https://numpy.org/doc/stable/", None), - "python": ("https://docs.python.org/3", None), - "lightning": ("https://lightning.ai/docs/pytorch/stable/", None), - "pyvista": ("https://docs.pyvista.org/version/stable/", None), - "rasterio": ("https://rasterio.readthedocs.io/en/stable/", None), - "rtree": ("https://rtree.readthedocs.io/en/stable/", None), - "segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None), - "sklearn": ("https://scikit-learn.org/stable/", None), - "timm": ("https://huggingface.co/docs/timm/main/en/", None), - "torch": ("https://pytorch.org/docs/stable", None), - "torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None), - "torchvision": ("https://pytorch.org/vision/stable", None), + 'kornia': ('https://kornia.readthedocs.io/en/stable/', None), + 'matplotlib': ('https://matplotlib.org/stable/', None), + 'numpy': ('https://numpy.org/doc/stable/', None), + 'python': ('https://docs.python.org/3', None), + 'lightning': ('https://lightning.ai/docs/pytorch/stable/', None), + 'pyvista': ('https://docs.pyvista.org/version/stable/', None), + 'rasterio': ('https://rasterio.readthedocs.io/en/stable/', None), + 'rtree': ('https://rtree.readthedocs.io/en/stable/', None), + 'segmentation_models_pytorch': ('https://smp.readthedocs.io/en/stable/', None), + 'sklearn': ('https://scikit-learn.org/stable/', None), + 'timm': ('https://huggingface.co/docs/timm/main/en/', None), + 'torch': ('https://pytorch.org/docs/stable', None), + 'torchmetrics': ('https://lightning.ai/docs/torchmetrics/stable/', None), + 'torchvision': ('https://pytorch.org/vision/stable', None), } # nbsphinx -nbsphinx_execute = "never" +nbsphinx_execute = 'never' # TODO: branch/tag should change depending on which version of docs you look at # TODO: width option of image directive is broken, see: # https://github.com/pytorch/pytorch_sphinx_theme/issues/140 @@ -165,4 +165,4 @@ # Disables requirejs in nbsphinx to enable compatibility with the pytorch_sphinx_theme # See more information here https://github.com/spatialaudio/nbsphinx/issues/599 # NOTE: This will likely break nbsphinx 
widgets -nbsphinx_requirejs_path = "" +nbsphinx_requirejs_path = '' diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index e994897fe26..77b44898e6d 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -131,7 +131,7 @@ "from torchgeo.samplers import RandomGeoSampler\n", "\n", "%matplotlib inline\n", - "plt.rcParams[\"figure.figsize\"] = (12, 12)" + "plt.rcParams['figure.figsize'] = (12, 12)" ] }, { @@ -248,18 +248,18 @@ }, "outputs": [], "source": [ - "root = os.path.join(tempfile.gettempdir(), \"sentinel\")\n", + "root = os.path.join(tempfile.gettempdir(), 'sentinel')\n", "item_urls = [\n", - " \"https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220902T090559_R050_T40XDH_20220902T181115\",\n", - " \"https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220718T084609_R107_T40XEJ_20220718T175008\",\n", + " 'https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220902T090559_R050_T40XDH_20220902T181115',\n", + " 'https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220718T084609_R107_T40XEJ_20220718T175008',\n", "]\n", "\n", "for item_url in item_urls:\n", " item = pystac.Item.from_file(item_url)\n", " signed_item = planetary_computer.sign(item)\n", - " for band in [\"B02\", \"B03\", \"B04\", \"B08\"]:\n", + " for band in ['B02', 'B03', 'B04', 'B08']:\n", " asset_href = signed_item.assets[band].href\n", - " filename = urlparse(asset_href).path.split(\"/\")[-1]\n", + " filename = urlparse(asset_href).path.split('/')[-1]\n", " download_url(asset_href, root, filename)" ] }, @@ -360,13 +360,13 @@ "outputs": [], "source": [ "class Sentinel2(RasterDataset):\n", - " filename_glob = \"T*_B02_10m.tif\"\n", - " filename_regex = r\"^.{6}_(?P\\d{8}T\\d{6})_(?PB0[\\d])\"\n", - " date_format = \"%Y%m%dT%H%M%S\"\n", + " filename_glob = 'T*_B02_10m.tif'\n", + " filename_regex = r'^.{6}_(?P\\d{8}T\\d{6})_(?PB0[\\d])'\n", + " date_format = '%Y%m%dT%H%M%S'\n", " is_image = True\n", " separate_files = True\n", - " all_bands = [\"B02\", \"B03\", \"B04\", \"B08\"]\n", - " rgb_bands = [\"B04\", \"B03\", \"B02\"]" + " all_bands = ['B02', 'B03', 'B04', 'B08']\n", + " rgb_bands = ['B04', 'B03', 'B02']" ] }, { @@ -423,13 +423,13 @@ "outputs": [], "source": [ "class Sentinel2(RasterDataset):\n", - " filename_glob = \"T*_B02_10m.tif\"\n", - " filename_regex = r\"^.{6}_(?P\\d{8}T\\d{6})_(?PB0[\\d])\"\n", - " date_format = \"%Y%m%dT%H%M%S\"\n", + " filename_glob = 'T*_B02_10m.tif'\n", + " filename_regex = r'^.{6}_(?P\\d{8}T\\d{6})_(?PB0[\\d])'\n", + " date_format = '%Y%m%dT%H%M%S'\n", " is_image = True\n", " separate_files = True\n", - " all_bands = [\"B02\", \"B03\", \"B04\", \"B08\"]\n", - " rgb_bands = [\"B04\", \"B03\", \"B02\"]\n", + " all_bands = ['B02', 'B03', 'B04', 'B08']\n", + " rgb_bands = ['B04', 'B03', 'B02']\n", "\n", " def plot(self, sample):\n", " # Find the correct band index order\n", @@ -438,7 +438,7 @@ " rgb_indices.append(self.all_bands.index(band))\n", "\n", " # Reorder and rescale the image\n", - " image = sample[\"image\"][rgb_indices].permute(1, 2, 0)\n", + " image = sample['image'][rgb_indices].permute(1, 2, 0)\n", " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n", "\n", " # Plot the image\n", @@ -479,7 +479,7 @@ "for batch in dataloader:\n", " sample = unbind_samples(batch)[0]\n", " 
dataset.plot(sample)\n", - " plt.axis(\"off\")\n", + " plt.axis('off')\n", " plt.show()" ] }, diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb index bfae62e29a7..1b1982711b3 100644 --- a/docs/tutorials/getting_started.ipynb +++ b/docs/tutorials/getting_started.ipynb @@ -149,15 +149,15 @@ }, "outputs": [], "source": [ - "naip_root = os.path.join(tempfile.gettempdir(), \"naip\")\n", + "naip_root = os.path.join(tempfile.gettempdir(), 'naip')\n", "naip_url = (\n", - " \"https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/\"\n", + " 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/'\n", ")\n", "tiles = [\n", - " \"m_3807511_ne_18_060_20181104.tif\",\n", - " \"m_3807511_se_18_060_20181104.tif\",\n", - " \"m_3807512_nw_18_060_20180815.tif\",\n", - " \"m_3807512_sw_18_060_20180815.tif\",\n", + " 'm_3807511_ne_18_060_20181104.tif',\n", + " 'm_3807511_se_18_060_20181104.tif',\n", + " 'm_3807512_nw_18_060_20180815.tif',\n", + " 'm_3807512_sw_18_060_20180815.tif',\n", "]\n", "for tile in tiles:\n", " download_url(naip_url + tile, naip_root)\n", @@ -188,7 +188,7 @@ }, "outputs": [], "source": [ - "chesapeake_root = os.path.join(tempfile.gettempdir(), \"chesapeake\")\n", + "chesapeake_root = os.path.join(tempfile.gettempdir(), 'chesapeake')\n", "os.makedirs(chesapeake_root, exist_ok=True)\n", "chesapeake = ChesapeakeDE(chesapeake_root, crs=naip.crs, res=naip.res, download=True)" ] @@ -285,8 +285,8 @@ "outputs": [], "source": [ "for sample in dataloader:\n", - " image = sample[\"image\"]\n", - " target = sample[\"mask\"]" + " image = sample['image']\n", + " target = sample['mask']" ] } ], diff --git a/docs/tutorials/indices.ipynb b/docs/tutorials/indices.ipynb index 3b2c1c410f5..30576609ac8 100644 --- a/docs/tutorials/indices.ipynb +++ b/docs/tutorials/indices.ipynb @@ -171,7 +171,7 @@ }, "outputs": [], "source": [ - "root = os.path.join(tempfile.gettempdir(), \"eurosat100\")\n", + "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", "ds = EuroSAT100(root, download=True)\n", "sample = ds[21]" ] @@ -247,14 +247,14 @@ "source": [ "# NDVI is appended to channel dimension (dim=0)\n", "index = AppendNDVI(index_nir=7, index_red=3)\n", - "image = sample[\"image\"]\n", + "image = sample['image']\n", "image = index(image)[0]\n", "\n", "# Normalize from [-1, 1] -> [0, 1] for visualization\n", "image[-1] = (image[-1] + 1) / 2\n", "\n", - "plt.imshow(image[-1], cmap=\"RdYlGn\")\n", - "plt.axis(\"off\")\n", + "plt.imshow(image[-1], cmap='RdYlGn')\n", + "plt.axis('off')\n", "plt.show()\n", "plt.close()" ] @@ -299,8 +299,8 @@ "# Normalize from [-1, 1] -> [0, 1] for visualization\n", "image[-1] = (image[-1] + 1) / 2\n", "\n", - "plt.imshow(image[-1], cmap=\"BrBG\")\n", - "plt.axis(\"off\")\n", + "plt.imshow(image[-1], cmap='BrBG')\n", + "plt.axis('off')\n", "plt.show()\n", "plt.close()" ] @@ -345,8 +345,8 @@ "# Normalize from [-1, 1] -> [0, 1] for visualization\n", "image[-1] = (image[-1] + 1) / 2\n", "\n", - "plt.imshow(image[-1], cmap=\"terrain\")\n", - "plt.axis(\"off\")\n", + "plt.imshow(image[-1], cmap='terrain')\n", + "plt.axis('off')\n", "plt.show()\n", "plt.close()" ] diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index c69a7fd3cb1..28b354efee6 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -133,7 +133,7 @@ }, "outputs": [], "source": [ - "root = os.path.join(tempfile.gettempdir(), \"eurosat100\")\n", + 
"root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", "datamodule = EuroSAT100DataModule(\n", " root=root, batch_size=batch_size, num_workers=num_workers, download=True\n", ")" @@ -199,8 +199,8 @@ "outputs": [], "source": [ "task = ClassificationTask(\n", - " model=\"resnet18\",\n", - " loss=\"ce\",\n", + " model='resnet18',\n", + " loss='ce',\n", " weights=weights,\n", " in_channels=13,\n", " num_classes=10,\n", @@ -226,8 +226,8 @@ }, "outputs": [], "source": [ - "in_chans = weights.meta[\"in_chans\"]\n", - "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n", + "in_chans = weights.meta['in_chans']\n", + "model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)\n", "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" ] }, @@ -250,8 +250,8 @@ }, "outputs": [], "source": [ - "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"\n", - "default_root_dir = os.path.join(tempfile.gettempdir(), \"experiments\")" + "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", + "default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments')" ] }, { diff --git a/docs/tutorials/trainers.ipynb b/docs/tutorials/trainers.ipynb index 3e05930ffd2..5de3937a026 100644 --- a/docs/tutorials/trainers.ipynb +++ b/docs/tutorials/trainers.ipynb @@ -143,7 +143,7 @@ }, "outputs": [], "source": [ - "root = os.path.join(tempfile.gettempdir(), \"eurosat100\")\n", + "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", "datamodule = EuroSAT100DataModule(\n", " root=root, batch_size=batch_size, num_workers=num_workers, download=True\n", ")" @@ -169,8 +169,8 @@ "outputs": [], "source": [ "task = ClassificationTask(\n", - " loss=\"ce\",\n", - " model=\"resnet18\",\n", + " loss='ce',\n", + " model='resnet18',\n", " weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,\n", " in_channels=13,\n", " num_classes=10,\n", @@ -200,13 +200,13 @@ }, "outputs": [], "source": [ - "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"\n", - "default_root_dir = os.path.join(tempfile.gettempdir(), \"experiments\")\n", + "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", + "default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments')\n", "checkpoint_callback = ModelCheckpoint(\n", - " monitor=\"val_loss\", dirpath=default_root_dir, save_top_k=1, save_last=True\n", + " monitor='val_loss', dirpath=default_root_dir, save_top_k=1, save_last=True\n", ")\n", - "early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n", - "logger = TensorBoardLogger(save_dir=default_root_dir, name=\"tutorial_logs\")" + "early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0.00, patience=10)\n", + "logger = TensorBoardLogger(save_dir=default_root_dir, name='tutorial_logs')" ] }, { diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index 64246924aed..a7de9f32c69 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -127,7 +127,7 @@ "\n", " def __init__(self, mins: Tensor, maxs: Tensor) -> None:\n", " super().__init__(p=1)\n", - " self.flags = {\"mins\": mins.view(1, -1, 1, 1), \"maxs\": maxs.view(1, -1, 1, 1)}\n", + " self.flags = {'mins': mins.view(1, -1, 1, 1), 'maxs': maxs.view(1, -1, 1, 1)}\n", "\n", " def apply_transform(\n", " self,\n", @@ -136,7 +136,7 @@ " flags: dict[str, int],\n", " transform: Tensor | None = None,\n", " ) -> Tensor:\n", - " return (input - flags[\"mins\"]) / (flags[\"maxs\"] - flags[\"mins\"] + 
1e-10)" + " return (input - flags['mins']) / (flags['maxs'] - flags['mins'] + 1e-10)" ] }, { @@ -200,19 +200,19 @@ " ]\n", ")\n", "bands = {\n", - " \"B01\": \"Coastal Aerosol\",\n", - " \"B02\": \"Blue\",\n", - " \"B03\": \"Green\",\n", - " \"B04\": \"Red\",\n", - " \"B05\": \"Vegetation Red Edge 1\",\n", - " \"B06\": \"Vegetation Red Edge 2\",\n", - " \"B07\": \"Vegetation Red Edge 3\",\n", - " \"B08\": \"NIR 1\",\n", - " \"B8A\": \"NIR 2\",\n", - " \"B09\": \"Water Vapour\",\n", - " \"B10\": \"SWIR 1\",\n", - " \"B11\": \"SWIR 2\",\n", - " \"B12\": \"SWIR 3\",\n", + " 'B01': 'Coastal Aerosol',\n", + " 'B02': 'Blue',\n", + " 'B03': 'Green',\n", + " 'B04': 'Red',\n", + " 'B05': 'Vegetation Red Edge 1',\n", + " 'B06': 'Vegetation Red Edge 2',\n", + " 'B07': 'Vegetation Red Edge 3',\n", + " 'B08': 'NIR 1',\n", + " 'B8A': 'NIR 2',\n", + " 'B09': 'Water Vapour',\n", + " 'B10': 'SWIR 1',\n", + " 'B11': 'SWIR 2',\n", + " 'B12': 'SWIR 3',\n", "}" ] }, @@ -320,14 +320,14 @@ }, "outputs": [], "source": [ - "root = os.path.join(tempfile.gettempdir(), \"eurosat100\")\n", + "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", "dataset = EuroSAT100(root, download=True)\n", "dataloader = DataLoader(\n", " dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers\n", ")\n", "dataloader = iter(dataloader)\n", - "print(f\"Number of images in dataset: {len(dataset)}\")\n", - "print(f\"Dataset Classes: {dataset.classes}\")" + "print(f'Number of images in dataset: {len(dataset)}')\n", + "print(f'Dataset Classes: {dataset.classes}')" ] }, { @@ -361,7 +361,7 @@ "outputs": [], "source": [ "sample = dataset[0]\n", - "x, y = sample[\"image\"], sample[\"label\"]\n", + "x, y = sample['image'], sample['label']\n", "print(x.shape, x.dtype, x.min(), x.max())\n", "print(y, dataset.classes[y])" ] @@ -388,7 +388,7 @@ "outputs": [], "source": [ "batch = next(dataloader)\n", - "x, y = batch[\"image\"], batch[\"label\"]\n", + "x, y = batch['image'], batch['label']\n", "print(x.shape, x.dtype, x.min(), x.max())\n", "print(y, [dataset.classes[i] for i in y])" ] @@ -452,7 +452,7 @@ "source": [ "transform = indices.AppendNDVI(index_nir=7, index_red=3)\n", "batch = next(dataloader)\n", - "x = batch[\"image\"]\n", + "x = batch['image']\n", "print(x.shape)\n", "x = transform(x)\n", "print(x.shape)" @@ -488,7 +488,7 @@ ")\n", "\n", "batch = next(dataloader)\n", - "x = batch[\"image\"]\n", + "x = batch['image']\n", "print(x.shape)\n", "x = transforms(x)\n", "print(x.shape)" @@ -523,13 +523,13 @@ " indices.AppendNDWI(index_green=2, index_nir=7),\n", " K.RandomHorizontalFlip(p=0.5),\n", " K.RandomVerticalFlip(p=0.5),\n", - " data_keys=[\"image\"],\n", + " data_keys=['image'],\n", ")\n", "\n", "batch = next(dataloader)\n", - "print(batch[\"image\"].shape)\n", + "print(batch['image'].shape)\n", "batch = transforms(batch)\n", - "print(batch[\"image\"].shape)" + "print(batch['image'].shape)" ] }, { @@ -567,7 +567,7 @@ }, "outputs": [], "source": [ - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "transforms = AugmentationSequential(\n", " MinMaxNormalize(mins, maxs),\n", @@ -580,7 +580,7 @@ " K.RandomAffine(degrees=(0, 90), p=0.25),\n", " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=[\"image\"],\n", + " data_keys=['image'],\n", ")\n", "\n", "transforms_gpu = AugmentationSequential(\n", @@ -594,12 +594,12 @@ " 
K.RandomAffine(degrees=(0, 90), p=0.25),\n", " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=[\"image\"],\n", + " data_keys=['image'],\n", ").to(device)\n", "\n", "\n", "def get_batch_cpu():\n", - " return dict(image=torch.randn(64, 13, 512, 512).to(\"cpu\"))\n", + " return dict(image=torch.randn(64, 13, 512, 512).to('cpu'))\n", "\n", "\n", "def get_batch_gpu():\n", @@ -664,7 +664,7 @@ }, "outputs": [], "source": [ - "transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=[\"image\"])\n", + "transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])\n", "dataset = EuroSAT100(root, transforms=transforms)" ] }, @@ -684,7 +684,7 @@ "# @title EuroSat Multispectral (MS) Browser { run: \"auto\", vertical-output: true }\n", "idx = 21 # @param {type:\"slider\", min:0, max:59, step:1}\n", "sample = dataset[idx]\n", - "rgb = sample[\"image\"][0, 1:4]\n", + "rgb = sample['image'][0, 1:4]\n", "image = T.ToPILImage()(rgb)\n", "print(f\"Class Label: {dataset.classes[sample['label']]}\")\n", "image.resize((256, 256), resample=Image.BILINEAR)" diff --git a/experiments/ssl4eo/class_imbalance.py b/experiments/ssl4eo/class_imbalance.py index 64004b733df..faa280432b7 100755 --- a/experiments/ssl4eo/class_imbalance.py +++ b/experiments/ssl4eo/class_imbalance.py @@ -12,30 +12,30 @@ from tqdm import tqdm from tqdm.contrib.concurrent import thread_map -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("roots", nargs="+", help="directories to search for files") - parser.add_argument("--suffix", default=".tif", help="file suffix") - parser.add_argument("--sort", action="store_true", help="sort by class frequency") + parser.add_argument('roots', nargs='+', help='directories to search for files') + parser.add_argument('--suffix', default='.tif', help='file suffix') + parser.add_argument('--sort', action='store_true', help='sort by class frequency') parser.add_argument( - "--weights", action="store_true", help="print weights instead of ratios" + '--weights', action='store_true', help='print weights instead of ratios' ) parser.add_argument( - "--total-classes", type=int, default=256, help="total number of classes" + '--total-classes', type=int, default=256, help='total number of classes' ) parser.add_argument( - "--keep-classes", + '--keep-classes', type=float, default=1, - help="keep classes with percentage higher than this", + help='keep classes with percentage higher than this', ) parser.add_argument( - "--ignore-index", type=int, default=0, help="fill value to ignore" + '--ignore-index', type=int, default=0, help='fill value to ignore' ) - parser.add_argument("--num-workers", type=int, default=10, help="number of threads") + parser.add_argument('--num-workers', type=int, default=10, help='number of threads') args = parser.parse_args() - def class_counts(path: str) -> "np.typing.NDArray[np.float64]": + def class_counts(path: str) -> 'np.typing.NDArray[np.float64]': """Calculate the number of values in each class. 
Args: @@ -47,7 +47,7 @@ def class_counts(path: str) -> "np.typing.NDArray[np.float64]": global args counts = np.zeros(args.total_classes) - with rio.open(path, "r") as src: + with rio.open(path, 'r') as src: x = src.read() unique, unique_counts = np.unique(x, return_counts=True) counts[unique] = unique_counts @@ -57,7 +57,7 @@ def class_counts(path: str) -> "np.typing.NDArray[np.float64]": paths = [] for root in args.roots: paths.extend( - glob.glob(os.path.join(root, "**", f"*{args.suffix}"), recursive=True) + glob.glob(os.path.join(root, '**', f'*{args.suffix}'), recursive=True) ) if args.num_workers > 0: diff --git a/experiments/ssl4eo/compress_dataset.py b/experiments/ssl4eo/compress_dataset.py index 31146eb2943..5655ab41713 100755 --- a/experiments/ssl4eo/compress_dataset.py +++ b/experiments/ssl4eo/compress_dataset.py @@ -12,20 +12,20 @@ from tqdm import tqdm from tqdm.contrib.concurrent import thread_map -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() # Can be same directory for in-place compression - parser.add_argument("src_dir", help="directory to recursively search for files") - parser.add_argument("dst_dir", help="directory to save compressed files in") - parser.add_argument("--suffix", default=".tif", help="file suffix") + parser.add_argument('src_dir', help='directory to recursively search for files') + parser.add_argument('dst_dir', help='directory to save compressed files in') + parser.add_argument('--suffix', default='.tif', help='file suffix') # Could be min/max, 2%/98%, mean ± 2 * std, etc. parser.add_argument( - "--min", nargs="+", type=float, required=True, help="minimum range" + '--min', nargs='+', type=float, required=True, help='minimum range' ) parser.add_argument( - "--max", nargs="+", type=float, required=True, help="maximum range" + '--max', nargs='+', type=float, required=True, help='maximum range' ) - parser.add_argument("--num-workers", type=int, default=10, help="number of threads") + parser.add_argument('--num-workers', type=int, default=10, help='number of threads') args = parser.parse_args() args.min = np.array(args.min)[:, np.newaxis, np.newaxis] @@ -41,7 +41,7 @@ def compress(src_path: str) -> None: dst_path = src_path.replace(args.src_dir, args.dst_dir) dst_dir = os.path.dirname(dst_path) os.makedirs(dst_dir, exist_ok=True) - with rio.open(src_path, "r") as src: + with rio.open(src_path, 'r') as src: x = src.read() x = (x - args.min) / (args.max - args.min) @@ -50,15 +50,15 @@ def compress(src_path: str) -> None: x = np.clip(x * 255, 0, 255).astype(np.uint8) profile = src.profile - profile["dtype"] = "uint8" - profile["compress"] = "lzw" - profile["predictor"] = 2 - with rio.open(dst_path, "w", **profile) as dst: + profile['dtype'] = 'uint8' + profile['compress'] = 'lzw' + profile['predictor'] = 2 + with rio.open(dst_path, 'w', **profile) as dst: for i, band in enumerate(dst.indexes): dst.write(x[i], band) paths = glob.glob( - os.path.join(args.src_dir, "**", f"*{args.suffix}"), recursive=True + os.path.join(args.src_dir, '**', f'*{args.suffix}'), recursive=True ) if args.num_workers > 0: diff --git a/experiments/ssl4eo/compute_dataset_pca.py b/experiments/ssl4eo/compute_dataset_pca.py index 018786bbbfe..32b1007cc99 100755 --- a/experiments/ssl4eo/compute_dataset_pca.py +++ b/experiments/ssl4eo/compute_dataset_pca.py @@ -11,16 +11,16 @@ import rasterio as rio from sklearn.decomposition import IncrementalPCA -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() - 
parser.add_argument("directory", help="directory to recursively search for files") - parser.add_argument("--ext", default="tif", help="file extension") - parser.add_argument("--scale", default=255, type=float, help="scale factor") + parser.add_argument('directory', help='directory to recursively search for files') + parser.add_argument('--ext', default='tif', help='file extension') + parser.add_argument('--scale', default=255, type=float, help='scale factor') args = parser.parse_args() transformer = IncrementalPCA(n_components=1) for path in glob.iglob( - os.path.join(args.directory, "**", f"*.{args.ext}"), recursive=True + os.path.join(args.directory, '**', f'*.{args.ext}'), recursive=True ): with rio.open(path) as f: x = f.read().astype(np.float32) @@ -29,4 +29,4 @@ x = x.reshape((-1, x.shape[-1])) transformer.partial_fit(x) - print("pca:", transformer.components_) + print('pca:', transformer.components_) diff --git a/experiments/ssl4eo/compute_dataset_statistics.py b/experiments/ssl4eo/compute_dataset_statistics.py index efb3cf46bb2..05ff623f1e0 100755 --- a/experiments/ssl4eo/compute_dataset_statistics.py +++ b/experiments/ssl4eo/compute_dataset_statistics.py @@ -12,14 +12,14 @@ from tqdm import tqdm from tqdm.contrib.concurrent import thread_map -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("directory", help="directory to recursively search for files") - parser.add_argument("--suffix", default=".tif", help="file suffix") - parser.add_argument("--num-workers", type=int, default=10, help="number of threads") + parser.add_argument('directory', help='directory to recursively search for files') + parser.add_argument('--suffix', default='.tif', help='file suffix') + parser.add_argument('--num-workers', type=int, default=10, help='number of threads') args = parser.parse_args() - def compute(path: str) -> tuple["np.typing.NDArray[np.float32]", int]: + def compute(path: str) -> tuple['np.typing.NDArray[np.float32]', int]: """Compute the min, max, mean, and std dev of a single image. 
Args: @@ -36,7 +36,7 @@ def compute(path: str) -> tuple["np.typing.NDArray[np.float32]", int]: return out, f.width * f.height paths = glob.glob( - os.path.join(args.directory, "**", f"*{args.suffix}"), recursive=True + os.path.join(args.directory, '**', f'*{args.suffix}'), recursive=True ) if args.num_workers > 0: @@ -73,7 +73,7 @@ def compute(path: str) -> tuple["np.typing.NDArray[np.float32]", int]: ) np.set_printoptions(linewidth=2**8) - print("min:", repr(minimum)) - print("max:", repr(maximum)) - print("mean:", repr(mu)) - print("std:", repr(sigma)) + print('min:', repr(minimum)) + print('max:', repr(maximum)) + print('mean:', repr(mu)) + print('std:', repr(sigma)) diff --git a/experiments/ssl4eo/delete_excess.py b/experiments/ssl4eo/delete_excess.py index ca7c1a34627..0dd26bd8209 100755 --- a/experiments/ssl4eo/delete_excess.py +++ b/experiments/ssl4eo/delete_excess.py @@ -11,16 +11,16 @@ from tqdm import tqdm from tqdm.contrib.concurrent import thread_map -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("root", help="directory to search for scenes") - parser.add_argument("--num-workers", type=int, default=10, help="number of threads") + parser.add_argument('root', help='directory to search for scenes') + parser.add_argument('--num-workers', type=int, default=10, help='number of threads') parser.add_argument( - "--length", type=int, default=250000, help="number of scenes to keep" + '--length', type=int, default=250000, help='number of scenes to keep' ) args = parser.parse_args() - paths = sorted(glob.glob(os.path.join(args.root, "*"))) + paths = sorted(glob.glob(os.path.join(args.root, '*'))) paths = paths[args.length :] if args.num_workers > 0: diff --git a/experiments/ssl4eo/delete_mismatch.py b/experiments/ssl4eo/delete_mismatch.py index 31968ffb7d0..3956ef43185 100755 --- a/experiments/ssl4eo/delete_mismatch.py +++ b/experiments/ssl4eo/delete_mismatch.py @@ -15,40 +15,40 @@ def delete_scene(directories: list[str], scene_id: str) -> None: directories: directories to check scene_id: scene to delete """ - print(f"Removing {scene_id}") + print(f'Removing {scene_id}') for directory in directories: scene = os.path.join(directory, scene_id) if os.path.exists(scene): shutil.rmtree(scene) -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("directories", nargs="+", help="directories to compare") + parser.add_argument('directories', nargs='+', help='directories to compare') parser.add_argument( - "--delete-different-locations", - action="store_true", - help="delete scene locations that do not match", + '--delete-different-locations', + action='store_true', + help='delete scene locations that do not match', ) parser.add_argument( - "--delete-different-dates", - action="store_true", - help="delete scene dates that do not match (must be same satellite)", + '--delete-different-dates', + action='store_true', + help='delete scene dates that do not match (must be same satellite)', ) args = parser.parse_args() - print("Computing sets...") + print('Computing sets...') scene_sets = [set(os.listdir(directory)) for directory in args.directories] - print("Computing union...") + print('Computing union...') union = set.union(*scene_sets) total = len(union) - print("Computing intersection...") + print('Computing intersection...') intersection = set.intersection(*scene_sets) remaining = len(intersection) - print("Computing difference...") + print('Computing difference...') difference 
= union - intersection delete_locations = len(difference) @@ -72,6 +72,6 @@ def delete_scene(directories: list[str], scene_id: str) -> None: remaining -= delete_times delete = delete_locations + delete_times if not (args.delete_different_locations or args.delete_different_dates): - print(f"Would delete {delete} scenes, leaving {remaining} remaining scenes.") + print(f'Would delete {delete} scenes, leaving {remaining} remaining scenes.') else: - print(f"Deleted {delete} scenes, leaving {remaining} remaining scenes.") + print(f'Deleted {delete} scenes, leaving {remaining} remaining scenes.') diff --git a/experiments/ssl4eo/download_ssl4eo.py b/experiments/ssl4eo/download_ssl4eo.py index 3ca4fa71ce0..d93413bdd25 100755 --- a/experiments/ssl4eo/download_ssl4eo.py +++ b/experiments/ssl4eo/download_ssl4eo.py @@ -65,7 +65,7 @@ def date2str(date: date) -> str: - return date.strftime("%Y-%m-%d") + return date.strftime('%Y-%m-%d') def get_period(date: date, days: int = 5) -> tuple[str, str, str, str]: @@ -125,19 +125,19 @@ def filter_collection( if filtered.size().getInfo() == 0: raise ee.EEException( - f"ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}." # noqa: E501 + f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' # noqa: E501 ) return filtered def center_crop( - img: "np.typing.NDArray[np.float32]", out_size: int -) -> "np.typing.NDArray[np.float32]": + img: 'np.typing.NDArray[np.float32]', out_size: int +) -> 'np.typing.NDArray[np.float32]': image_height, image_width = img.shape[:2] crop_height = crop_width = out_size pad_height = max(crop_height - image_height, 0) pad_width = max(crop_width - image_width, 0) - img = np.pad(img, ((pad_height, 0), (pad_width, 0), (0, 0)), mode="edge") + img = np.pad(img, ((pad_height, 0), (pad_width, 0), (0, 0)), mode='edge') crop_top = (image_height - crop_height + 1) // 2 crop_left = (image_width - crop_width + 1) // 2 return img[crop_top : crop_top + crop_height, crop_left : crop_left + crop_width] @@ -166,8 +166,8 @@ def get_patch( bands: list[str], original_resolutions: list[int], new_resolutions: list[int], - dtype: str = "float32", - meta_cloud_name: str = "CLOUD_COVER", + dtype: str = 'float32', + meta_cloud_name: str = 'CLOUD_COVER', default_value: float | None = None, ) -> dict[str, Any]: image = collection.sort(meta_cloud_name).first() @@ -188,22 +188,22 @@ def get_patch( patch = patch.sampleRectangle(region, defaultValue=default_value) features = patch.getInfo() for i, band in zip(indices, bands_group): - x = features["properties"][band] + x = features['properties'][band] x = np.atleast_3d(x) x = center_crop(x, out_size=int(2 * radius // new_res)) raster[i] = x.astype(dtype) # Compute coordinates after cropping - coords0 = np.array(features["geometry"]["coordinates"][0]) + coords0 = np.array(features['geometry']['coordinates'][0]) coords = [ [coords0[:, 0].min(), coords0[:, 1].max()], [coords0[:, 0].max(), coords0[:, 1].min()], ] - old_size = (len(features["properties"][band]), len(features["properties"][band][0])) + old_size = (len(features['properties'][band]), len(features['properties'][band][0])) new_size = raster[0].shape[:2] coords = adjust_coords(coords, old_size, new_size) - return {"raster": raster, "coords": coords, "metadata": image.getInfo()} + return {'raster': raster, 'coords': coords, 'metadata': image.getInfo()} def get_random_patches_match( @@ -253,7 +253,7 @@ def 
get_random_patches_match( def save_geotiff( - img: "np.typing.NDArray[np.float32]", + img: 'np.typing.NDArray[np.float32]', coords: list[tuple[float, float]], filename: str, ) -> None: @@ -264,40 +264,40 @@ def save_geotiff( coords[0][0] - xres / 2, coords[0][1] + yres / 2 ) * Affine.scale(xres, -yres) profile = { - "driver": "GTiff", - "width": width, - "height": height, - "count": channels, - "crs": "+proj=latlong", - "transform": transform, - "dtype": img.dtype, - "compress": "None", + 'driver': 'GTiff', + 'width': width, + 'height': height, + 'count': channels, + 'crs': '+proj=latlong', + 'transform': transform, + 'dtype': img.dtype, + 'compress': 'None', } - with rasterio.open(filename, "w", **profile) as f: + with rasterio.open(filename, 'w', **profile) as f: f.write(img.transpose(2, 0, 1)) def save_patch( - raster: dict[int, "np.typing.NDArray[np.float32]"], + raster: dict[int, 'np.typing.NDArray[np.float32]'], coords: list[tuple[float, float]], metadata: dict[str, Any], bands: list[str], new_resolutions: list[int], path: str, ) -> None: - patch_id = metadata["properties"]["system:index"] + patch_id = metadata['properties']['system:index'] patch_path = os.path.join(path, patch_id) os.makedirs(patch_path, exist_ok=True) if len(set(new_resolutions)) == 1: img_all = np.concatenate([raster[i] for i in range(len(raster))], axis=2) - save_geotiff(img_all, coords, os.path.join(patch_path, "all_bands.tif")) + save_geotiff(img_all, coords, os.path.join(patch_path, 'all_bands.tif')) else: for i, band in enumerate(bands): img = raster[i] - save_geotiff(img, coords, os.path.join(patch_path, f"{band}.tif")) + save_geotiff(img, coords, os.path.join(patch_path, f'{band}.tif')) - with open(os.path.join(patch_path, "metadata.json"), "w") as f: + with open(os.path.join(patch_path, 'metadata.json'), 'w') as f: json.dump(metadata, f) @@ -312,60 +312,60 @@ def update(self, delta: int = 1) -> int: return self.value -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--save-path", type=str, default="./data/", help="dir to save data" + '--save-path', type=str, default='./data/', help='dir to save data' ) # collection properties parser.add_argument( - "--collection", type=str, default="COPERNICUS/S2", help="GEE collection name" + '--collection', type=str, default='COPERNICUS/S2', help='GEE collection name' ) - parser.add_argument("--qa-band", type=str, default="QA60", help="qa band name") + parser.add_argument('--qa-band', type=str, default='QA60', help='qa band name') parser.add_argument( - "--qa-cloud-bit", type=int, default=10, help="qa band cloud bit" + '--qa-cloud-bit', type=int, default=10, help='qa band cloud bit' ) parser.add_argument( - "--meta-cloud-name", + '--meta-cloud-name', type=str, - default="CLOUDY_PIXEL_PERCENTAGE", - help="meta data cloud percentage name", + default='CLOUDY_PIXEL_PERCENTAGE', + help='meta data cloud percentage name', ) parser.add_argument( - "--cloud-pct", type=int, default=20, help="cloud percentage threshold" + '--cloud-pct', type=int, default=20, help='cloud percentage threshold' ) # patch properties parser.add_argument( - "--dates", + '--dates', type=str, - nargs="+", + nargs='+', # https://www.weather.gov/media/ind/seasons.pdf - default=["2021-12-21", "2021-09-23", "2021-06-21", "2021-03-20"], - help="reference dates", + default=['2021-12-21', '2021-09-23', '2021-06-21', '2021-03-20'], + help='reference dates', ) parser.add_argument( - "--radius", type=int, default=1320, help="patch radius in meters" + 
'--radius', type=int, default=1320, help='patch radius in meters' ) parser.add_argument( - "--bands", + '--bands', type=str, - nargs="+", + nargs='+', default=[ - "B1", - "B2", - "B3", - "B4", - "B5", - "B6", - "B7", - "B8", - "B8A", - "B9", - "B10", - "B11", - "B12", + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8', + 'B8A', + 'B9', + 'B10', + 'B11', + 'B12', ], - help="bands to download", + help='bands to download', ) # Reprojection options # @@ -374,47 +374,47 @@ def update(self, delta: int = 1) -> int: # to after reprojection. All of these options should either be a single value # or the same length as the bands flag. parser.add_argument( - "--original-resolutions", + '--original-resolutions', type=int, - nargs="+", + nargs='+', default=[60, 10, 10, 10, 20, 20, 20, 10, 20, 60, 60, 20, 20], - help="original band resolutions in meters", + help='original band resolutions in meters', ) parser.add_argument( - "--new-resolutions", + '--new-resolutions', type=int, - nargs="+", + nargs='+', default=[10], - help="new band resolutions in meters", + help='new band resolutions in meters', ) - parser.add_argument("--dtype", type=str, default="float32", help="data type") + parser.add_argument('--dtype', type=str, default='float32', help='data type') # If None, don't download patches with nodata pixels parser.add_argument( - "--default-value", type=float, default=None, help="default fill value" + '--default-value', type=float, default=None, help='default fill value' ) # download settings - parser.add_argument("--num-workers", type=int, default=8, help="number of workers") - parser.add_argument("--log-freq", type=int, default=10, help="print frequency") + parser.add_argument('--num-workers', type=int, default=8, help='number of workers') + parser.add_argument('--log-freq', type=int, default=10, help='print frequency') parser.add_argument( - "--resume", type=str, default=None, help="resume from a previous run" + '--resume', type=str, default=None, help='resume from a previous run' ) # sampler options parser.add_argument( - "--match-file", + '--match-file', type=str, required=True, - help="match pre-sampled coordinates and indexes", + help='match pre-sampled coordinates and indexes', ) # number of locations to download parser.add_argument( - "--indices-range", + '--indices-range', type=int, nargs=2, default=[0, 250000], - help="indices to download", + help='indices to download', ) # debug - parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument('--debug', action='store_true', help='debug mode') args = parser.parse_args() os.makedirs(args.save_path, exist_ok=True) @@ -458,7 +458,7 @@ def update(self, delta: int = 1) -> int: ext_coords[key] = (val1, val2) # lon, lat ext_flags[key] = int(row[3]) # success or not else: - ext_path = os.path.join(args.save_path, "checked_locations.csv") + ext_path = os.path.join(args.save_path, 'checked_locations.csv') # match from pre-sampled coords match_coords = {} @@ -480,7 +480,7 @@ def worker(idx: int) -> None: # Skip if idx is not in pre-sampled coordinates if idx not in match_coords.keys(): - warnings.warn(f"{idx} not found in {args.match_file}, skipping.") + warnings.warn(f'{idx} not found in {args.match_file}, skipping.') return worker_start = time.time() @@ -500,13 +500,13 @@ def worker(idx: int) -> None: ) if patches: - location_path = os.path.join(args.save_path, "imgs", f"{idx:07d}") + location_path = os.path.join(args.save_path, 'imgs', f'{idx:07d}') os.makedirs(location_path, exist_ok=True) for patch in 
patches:
                save_patch(
-                    patch["raster"],
-                    patch["coords"],
-                    patch["metadata"],
+                    patch['raster'],
+                    patch['coords'],
+                    patch['metadata'],
                     bands,
                     new_resolutions,
                     location_path,
@@ -514,13 +514,13 @@ def worker(idx: int) -> None:

             count = counter.update(1)
             if count % args.log_freq == 0:
-                print(f"Downloaded {count} images in {time.time() - start_time:.3f}s.")
+                print(f'Downloaded {count} images in {time.time() - start_time:.3f}s.')
         else:
             if args.debug:
-                print("no suitable image for location %d." % (idx))
+                print('no suitable image for location %d.' % (idx))

         # add to existing checked locations
-        with open(ext_path, "a") as f:
+        with open(ext_path, 'a') as f:
             writer = csv.writer(f)
             if patches:
                 success = 1
diff --git a/experiments/ssl4eo/flops.py b/experiments/ssl4eo/flops.py
index 6bffb1835f9..5c7a77b78d5 100755
--- a/experiments/ssl4eo/flops.py
+++ b/experiments/ssl4eo/flops.py
@@ -7,7 +7,7 @@
 from deepspeed.accelerator import get_accelerator
 from deepspeed.profiling.flops_profiler import get_model_profile

-models = ["resnet18", "resnet50", "vit_small_patch16_224"]
+models = ['resnet18', 'resnet50', 'vit_small_patch16_224']
 num_classes = 14
 in_channels = 11
 batch_size = 64
@@ -15,7 +15,7 @@
 input_shape = (batch_size, in_channels, patch_size, patch_size)

 for model in models:
-    print(f"Model: {model}")
+    print(f'Model: {model}')

     m = timm.create_model(model, num_classes=num_classes, in_chans=in_channels)

@@ -23,7 +23,7 @@
     mem_params = sum([p.nelement() * p.element_size() for p in m.parameters()])
     mem_bufs = sum([b.nelement() * b.element_size() for b in m.buffers()])
     mem = (mem_params + mem_bufs) / 1000000
-    print(f"Memory: {mem:.2f} MB")
+    print(f'Memory: {mem:.2f} MB')

     with get_accelerator().device(0):
         get_model_profile(
diff --git a/experiments/ssl4eo/landsat/chip_landsat_benchmark.py b/experiments/ssl4eo/landsat/chip_landsat_benchmark.py
index 77bfaa0f057..93fd16b0209 100755
--- a/experiments/ssl4eo/landsat/chip_landsat_benchmark.py
+++ b/experiments/ssl4eo/landsat/chip_landsat_benchmark.py
@@ -17,7 +17,7 @@

 def retrieve_mask_chip(
     img_src: DatasetReader, mask_src: DatasetReader
-) -> "np.typing.NDArray[np.uint8]":
+) -> 'np.typing.NDArray[np.uint8]':
     """Retrieve the mask for a given landsat image.
Args: @@ -29,40 +29,40 @@ def retrieve_mask_chip( mask array """ out_shape = (1, *img_src.shape) - mask_chip: "np.typing.NDArray[np.uint8]" = mask_src.read( + mask_chip: 'np.typing.NDArray[np.uint8]' = mask_src.read( out_shape=out_shape, window=from_bounds(*img_src.bounds, mask_src.transform) ) # Copy nodata pixels from image to mask (Landsat 7 ETM+ SLC-off only) - if "LE07" in img_src.files[0]: + if 'LE07' in img_src.files[0]: img_chip = img_src.read(1) mask_chip[0][img_chip == 0] = 0 return mask_chip -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--landsat-dir", help="directory to recursively search for files", required=True + '--landsat-dir', help='directory to recursively search for files', required=True ) parser.add_argument( - "--mask-path", help="path to downstream task mask to chip", required=True + '--mask-path', help='path to downstream task mask to chip', required=True ) parser.add_argument( - "--save-dir", help="directory where to save masks", required=True + '--save-dir', help='directory where to save masks', required=True ) - parser.add_argument("--suffix", default=".tif", help="file suffix") + parser.add_argument('--suffix', default='.tif', help='file suffix') args = parser.parse_args() paths = glob.glob( - os.path.join(args.landsat_dir, "**", f"all_bands{args.suffix}"), recursive=True + os.path.join(args.landsat_dir, '**', f'all_bands{args.suffix}'), recursive=True ) - if "nlcd" in args.mask_path: - layer_name = "nlcd" + if 'nlcd' in args.mask_path: + layer_name = 'nlcd' else: - layer_name = "cdl" + layer_name = 'cdl' for img_path in tqdm(paths): with ( @@ -77,17 +77,17 @@ def retrieve_mask_chip( # directory structure mask <7-digit id>//_.tif digit_id, scene_id = img_path.split(os.sep)[-3:-1] - year = scene_id.split("_")[-1][:4] + year = scene_id.split('_')[-1][:4] mask_dir = os.path.join(args.save_dir, digit_id, scene_id) os.makedirs(mask_dir, exist_ok=True) # write mask tif profile = img_src.profile - profile["count"] = 1 - profile["dtype"] = mask_src.profile["dtype"] + profile['count'] = 1 + profile['dtype'] = mask_src.profile['dtype'] with rasterio.open( - os.path.join(mask_dir, f"{layer_name}_{year}.tif"), "w", **profile + os.path.join(mask_dir, f'{layer_name}_{year}.tif'), 'w', **profile ) as dst: dst.write(mask) dst.write_colormap(1, mask_src.colormap(1)) diff --git a/experiments/ssl4eo/landsat/plot_landsat_bands.py b/experiments/ssl4eo/landsat/plot_landsat_bands.py index ddc81fdf9cd..2c72e7f887f 100755 --- a/experiments/ssl4eo/landsat/plot_landsat_bands.py +++ b/experiments/ssl4eo/landsat/plot_landsat_bands.py @@ -14,39 +14,39 @@ # Match NeurIPS template plt.rcParams.update( { - "font.family": "Times New Roman", - "font.size": 10, - "axes.labelsize": 10, - "text.usetex": True, + 'font.family': 'Times New Roman', + 'font.size': 10, + 'axes.labelsize': 10, + 'text.usetex': True, } ) parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__ ) -parser.add_argument("skip", nargs="*", help="sensors to skip", metavar="SENSOR") +parser.add_argument('skip', nargs='*', help='sensors to skip', metavar='SENSOR') parser.add_argument( - "--fig-height", default=5, type=float, help="height of figure in inches" + '--fig-height', default=5, type=float, help='height of figure in inches' ) -parser.add_argument("--bar-start", default=1, type=float, help="height of first bar") -parser.add_argument("--bar-height", default=3, type=float, help="height of each bar") 
+parser.add_argument('--bar-start', default=1, type=float, help='height of first bar')
+parser.add_argument('--bar-height', default=3, type=float, help='height of each bar')
 parser.add_argument(
-    "--bar-sep", default=3.5, type=float, help="separation between bars"
+    '--bar-sep', default=3.5, type=float, help='separation between bars'
 )
 parser.add_argument(
-    "--bar-jump", default=2.6, type=float, help="additional height for narrow bars"
+    '--bar-jump', default=2.6, type=float, help='additional height for narrow bars'
 )
 parser.add_argument(
-    "--sensor-sep", default=2, type=float, help="separation between sensors"
+    '--sensor-sep', default=2, type=float, help='separation between sensors'
 )
 args = parser.parse_args()

 # https://www.usgs.gov/landsat-missions/landsat-satellite-missions
-df = pd.read_csv("band_data.csv", skip_blank_lines=True)
+df = pd.read_csv('band_data.csv', skip_blank_lines=True)
 df = df.iloc[::-1]

 fig, ax = plt.subplots(figsize=(5.5, args.fig_height))
-ax1, ax2 = fig.subplots(nrows=1, ncols=2, gridspec_kw={"width_ratios": [3, 1]})  # type: ignore[misc]
+ax1, ax2 = fig.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [3, 1]})  # type: ignore[misc]

 sensor_names: list[str] = []
 sensor_ylocs: list[float] = []
@@ -55,16 +55,16 @@
 bar_min = args.bar_start

 # For each satellite/sensor
-for (satellite, sensor), group1 in df.groupby(["Satellite", "Sensor"], sort=False):
+for (satellite, sensor), group1 in df.groupby(['Satellite', 'Sensor'], sort=False):
     if sensor in args.skip:
         continue

-    sensor_names.append(f"{satellite}\n({sensor})")
+    sensor_names.append(f'{satellite}\n({sensor})')
     sensor_yloc = 0.0
     res_count = 0

     # For each resolution
-    for res, group2 in group1.groupby("Resolution (m)"):
+    for res, group2 in group1.groupby('Resolution (m)'):
         res_names.append(res)
         res_ylocs.append(bar_min)

@@ -75,10 +75,10 @@
         # For each band
         for i in range(group2.shape[0]):
             row = group2.iloc[i]
-            wavelength_start = row["Wavelength Start (μm)"]
-            wavelength_width = row["Wavelength Width"]
-            color = row["Color"]
-            band = row["Band"]
+            wavelength_start = row['Wavelength Start (μm)']
+            wavelength_width = row['Wavelength Width']
+            color = row['Color']
+            band = row['Band']

             # We've split the plot into two parts as the thermal bands are > 10μm
             # while the other bands are < 3μm
@@ -90,7 +90,7 @@
                 ax1.broken_barh(
                     [[wavelength_start, wavelength_width]],
                     [bar_min, args.bar_height],
-                    edgecolor="k",
+                    edgecolor='k',
                     facecolors=color,
                     linewidth=0.5,
                     alpha=0.8,
@@ -99,14 +99,14 @@
                     wavelength_start + (wavelength_width / 2),
                     y,
                     band,
-                    horizontalalignment="center",
-                    verticalalignment="center_baseline",
+                    horizontalalignment='center',
+                    verticalalignment='center_baseline',
                 )
             else:
                 ax2.broken_barh(
                     [[wavelength_start, wavelength_width]],
                     [bar_min, args.bar_height],
-                    edgecolor="k",
+                    edgecolor='k',
                     facecolors=color,
                     linewidth=0.5,
                     alpha=0.8,
@@ -115,8 +115,8 @@
                     wavelength_start + (wavelength_width / 2),
                     y,
                     band,
-                    horizontalalignment="center",
-                    verticalalignment="center_baseline",
+                    horizontalalignment='center',
+                    verticalalignment='center_baseline',
                 )
         bar_min += args.bar_sep
     bar_min += args.sensor_sep
@@ -124,35 +124,35 @@
     sensor_ylocs.append(sensor_yloc)

 # Labels
-ax.set_xlabel(r"Wavelength (\textmu m)")
+ax.set_xlabel(r'Wavelength (\textmu m)')
 ax.set_xticks([0], labels=[0])
 ax.set_yticks([0], labels=[0])
-ax.tick_params(colors="w")
-ax.spines[["bottom", "left", "top", "right"]].set_visible(False)
+ax.tick_params(colors='w')
+ax.spines[['bottom', 'left', 'top', 'right']].set_visible(False)
ax1.set_yticks(np.array(sensor_ylocs) + args.bar_height / 2) ax1.set_yticklabels(sensor_names) ax1.set_ylim(0, max(res_ylocs) + args.bar_height + args.bar_start) -ax1.spines[["left", "top", "right"]].set_visible(False) -ax1.tick_params(axis="both", which="both", left=False) +ax1.spines[['left', 'top', 'right']].set_visible(False) +ax1.tick_params(axis='both', which='both', left=False) ax2.set_xlim(10.1, 12.8) ax2.set_ylim(0, max(res_ylocs) + args.bar_height + args.bar_start) -ax2.yaxis.set_label_position("right") +ax2.yaxis.set_label_position('right') ax2.yaxis.tick_right() -ax2.set_ylabel("Resolution (m)") +ax2.set_ylabel('Resolution (m)') ax2.set_yticks(np.array(res_ylocs) + args.bar_height / 2) ax2.set_yticklabels(res_names) -ax2.spines[["left", "top"]].set_visible(False) +ax2.spines[['left', 'top']].set_visible(False) # Draw axis break symbol d = 1.5 kwargs = dict( marker=[(-1, -d), (1, d)], markersize=12, - linestyle="none", - color="k", - mec="k", + linestyle='none', + color='k', + mec='k', mew=0.75, clip_on=False, ) diff --git a/experiments/ssl4eo/landsat/plot_landsat_timeline.py b/experiments/ssl4eo/landsat/plot_landsat_timeline.py index 94eb40dbe0e..ef855f01eee 100755 --- a/experiments/ssl4eo/landsat/plot_landsat_timeline.py +++ b/experiments/ssl4eo/landsat/plot_landsat_timeline.py @@ -11,20 +11,20 @@ # Match NeurIPS template plt.rcParams.update( { - "font.family": "Times New Roman", - "font.size": 10, - "axes.labelsize": 10, - "text.usetex": True, - "hatch.linewidth": 0.5, + 'font.family': 'Times New Roman', + 'font.size': 10, + 'axes.labelsize': 10, + 'text.usetex': True, + 'hatch.linewidth': 0.5, } ) parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__ ) -parser.add_argument("--bar-start", default=1, type=float, help="height of first bar") -parser.add_argument("--bar-height", default=3, type=float, help="height of each bar") -parser.add_argument("--bar-sep", default=2, type=float, help="separation between bars") +parser.add_argument('--bar-start', default=1, type=float, help='height of first bar') +parser.add_argument('--bar-height', default=3, type=float, help='height of each bar') +parser.add_argument('--bar-sep', default=2, type=float, help='separation between bars') args = parser.parse_args() working: dict[int, list[tuple[date, date]]] = { @@ -80,18 +80,18 @@ for satellite in range(9, 0, -1): # Bar plot kwargs = { - "yrange": (ymin, args.bar_height), - "alpha": 0.8, - "color": next(cmap), - "edgecolor": (0, 0, 0, 0.8), - "linewidth": 0.5, + 'yrange': (ymin, args.bar_height), + 'alpha': 0.8, + 'color': next(cmap), + 'edgecolor': (0, 0, 0, 0.8), + 'linewidth': 0.5, } xranges = [(start, end - start) for start, end in working[satellite]] ax.broken_barh(xranges, hatch=None, **kwargs) xranges = [(start, end - start) for start, end in failing[satellite]] - ax.broken_barh(xranges, hatch="////", **kwargs) + ax.broken_barh(xranges, hatch='////', **kwargs) # Label xmin = global_xmax @@ -106,25 +106,25 @@ if (xmin - global_xmin) > (global_xmax - xmax): # Left side label x = xmin - timedelta(weeks=52) - horizontalalignment = "right" + horizontalalignment = 'right' else: # Right side label x = xmax + timedelta(weeks=52) - horizontalalignment = "left" + horizontalalignment = 'left' - start = f"{xmin:%b %Y}" - end = f"{xmax:%b %Y}" + start = f'{xmin:%b %Y}' + end = f'{xmax:%b %Y}' if xmax == date.today(): - end = "Present" + end = 'Present' if start == end: s = start else: - s = f"{start}--{end}" + s = f'{start}--{end}' kwargs = { - 
"y": ymin + args.bar_height / 2, - "s": s, - "verticalalignment": "center_baseline", + 'y': ymin + args.bar_height / 2, + 's': s, + 'verticalalignment': 'center_baseline', } ax.text(x, horizontalalignment=horizontalalignment, **kwargs) @@ -134,11 +134,11 @@ ax.xaxis_date() ax.set_xlim(global_xmin, global_xmax) -ax.set_ylabel("Landsat Mission") +ax.set_ylabel('Landsat Mission') ax.set_yticks(yticks) ax.set_yticklabels(range(9, 0, -1)) -ax.tick_params(axis="both", which="both", top=False, right=False) -ax.spines[["top", "right"]].set_visible(False) +ax.tick_params(axis='both', which='both', top=False, right=False) +ax.spines[['top', 'right']].set_visible(False) plt.tight_layout() plt.show() diff --git a/experiments/ssl4eo/plot_example_predictions.py b/experiments/ssl4eo/plot_example_predictions.py index 161a70a741d..596c8ea304d 100755 --- a/experiments/ssl4eo/plot_example_predictions.py +++ b/experiments/ssl4eo/plot_example_predictions.py @@ -13,28 +13,28 @@ from torchgeo.datamodules import L7IrishDataModule from torchgeo.datasets import unbind_samples -device = torch.device("cpu") +device = torch.device('cpu') # Load weights -path = "data/l7irish/checkpoint-epoch=26-val_loss=0.68.ckpt" -state_dict = torch.load(path, map_location=device)["state_dict"] -state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()} +path = 'data/l7irish/checkpoint-epoch=26-val_loss=0.68.ckpt' +state_dict = torch.load(path, map_location=device)['state_dict'] +state_dict = {key.replace('model.', ''): value for key, value in state_dict.items()} # Initialize model -model = smp.Unet(encoder_name="resnet18", in_channels=9, classes=5) +model = smp.Unet(encoder_name='resnet18', in_channels=9, classes=5) model.to(device) model.load_state_dict(state_dict) # Initialize data loaders datamodule = L7IrishDataModule( - root="data/l7irish", crs="epsg:3857", download=True, batch_size=1, patch_size=224 + root='data/l7irish', crs='epsg:3857', download=True, batch_size=1, patch_size=224 ) -datamodule.setup("test") +datamodule.setup('test') i = 0 for batch in datamodule.test_dataloader(): - image = batch["image"] - mask = batch["mask"] + image = batch['image'] + mask = batch['mask'] image.to(device) # Skip nodata pixels @@ -48,28 +48,28 @@ # Make a prediction prediction = model(image) prediction = prediction.argmax(dim=1) - prediction.detach().to("cpu") + prediction.detach().to('cpu') - batch["prediction"] = prediction + batch['prediction'] = prediction for sample in unbind_samples(batch): # Plot # datamodule.test_dataset.plot(sample) # plt.show() - path = f"data/l7irish_predictions/{i}" - print(f"Saving {path}...") + path = f'data/l7irish_predictions/{i}' + print(f'Saving {path}...') os.makedirs(path, exist_ok=True) - for key in ["image", "mask", "prediction"]: + for key in ['image', 'mask', 'prediction']: data = sample[key] - if key == "image": - data = data[[2, 1, 0]].permute(1, 2, 0).numpy().astype("uint8") - Image.fromarray(data, "RGB").save( # type: ignore[no-untyped-call] - f"{path}/{key}.png" + if key == 'image': + data = data[[2, 1, 0]].permute(1, 2, 0).numpy().astype('uint8') + Image.fromarray(data, 'RGB').save( # type: ignore[no-untyped-call] + f'{path}/{key}.png' ) else: data = data * 255 / 4 - data = data.numpy().astype("uint8").squeeze() - Image.fromarray(data, "L").save( # type: ignore[no-untyped-call] - f"{path}/{key}.png" + data = data.numpy().astype('uint8').squeeze() + Image.fromarray(data, 'L').save( # type: ignore[no-untyped-call] + f'{path}/{key}.png' ) i += 1 diff --git 
a/experiments/ssl4eo/sample_conus.py b/experiments/ssl4eo/sample_conus.py index 518f956a82f..3f303ebcede 100755 --- a/experiments/ssl4eo/sample_conus.py +++ b/experiments/ssl4eo/sample_conus.py @@ -26,49 +26,49 @@ def retrieve_rois_polygons(download_root: str) -> MultiPolygon: Returns: MultiPolygon of CONUS """ - state_url = "https://www2.census.gov/geo/tiger/GENZ2018/shp/cb_2018_us_state_5m.zip" - state_filename = "cb_2018_us_state_5m.shp" + state_url = 'https://www2.census.gov/geo/tiger/GENZ2018/shp/cb_2018_us_state_5m.zip' + state_filename = 'cb_2018_us_state_5m.shp' download_and_extract_archive(state_url, download_root) excluded_states = [ - "United States Virgin Islands", - "Commonwealth of the Northern Mariana Islands", - "Puerto Rico", - "Alaska", - "Hawaii", - "American Samoa", - "Guam", + 'United States Virgin Islands', + 'Commonwealth of the Northern Mariana Islands', + 'Puerto Rico', + 'Alaska', + 'Hawaii', + 'American Samoa', + 'Guam', ] conus = [] - with fiona.open(os.path.join(download_root, state_filename), "r") as shapefile: + with fiona.open(os.path.join(download_root, state_filename), 'r') as shapefile: for feature in shapefile: - name = feature["properties"]["NAME"] + name = feature['properties']['NAME'] if name in excluded_states: continue else: - conus.append(shape(feature["geometry"])) + conus.append(shape(feature['geometry'])) conus = unary_union(conus) return conus -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--save-path", type=str, default="./data/", help="dir to save data" + '--save-path', type=str, default='./data/', help='dir to save data' ) parser.add_argument( - "--size", type=float, default=1320, help="half patch size in meters" + '--size', type=float, default=1320, help='half patch size in meters' ) parser.add_argument( - "--indices-range", + '--indices-range', type=int, nargs=2, default=[0, 500], - help="indices to download", + help='indices to download', ) parser.add_argument( - "--resume", action="store_true", help="resume from a previous run" + '--resume', action='store_true', help='resume from a previous run' ) args = parser.parse_args() @@ -77,13 +77,13 @@ def retrieve_rois_polygons(download_root: str) -> MultiPolygon: bbox_size = args.size / 1000 # no overlap between adjacent patches bbox_size_degree = km2deg(bbox_size) - root = os.path.join(args.save_path, "conus") - csv_path = os.path.join(args.save_path, "sampled_locations.csv") + root = os.path.join(args.save_path, 'conus') + csv_path = os.path.join(args.save_path, 'sampled_locations.csv') # Populate R-tree if resuming rtree_coords = index.Index() if args.resume: - print("Loading existing locations...") + print('Loading existing locations...') with open(csv_path) as csv_file: reader = csv.reader(csv_file) for i, row in enumerate(reader): @@ -101,7 +101,7 @@ def retrieve_rois_polygons(download_root: str) -> MultiPolygon: conus = retrieve_rois_polygons(root) x_min, y_min, x_max, y_max = conus.bounds - with open(csv_path, "a") as f: + with open(csv_path, 'a') as f: writer = csv.writer(f) for idx in tqdm(range(*args.indices_range)): count = 0 diff --git a/experiments/ssl4eo/sample_ssl4eo.py b/experiments/ssl4eo/sample_ssl4eo.py index 313978851e9..68d1056df55 100755 --- a/experiments/ssl4eo/sample_ssl4eo.py +++ b/experiments/ssl4eo/sample_ssl4eo.py @@ -45,15 +45,15 @@ def get_world_cities( - download_root: str = "world_cities", size: int = 10000 + download_root: str = 'world_cities', size: int = 10000 ) -> pd.DataFrame: - url = 
"https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip" # noqa: E501 - filename = "worldcities.csv" + url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' # noqa: E501 + filename = 'worldcities.csv' download_and_extract_archive(url, download_root) - cols = ["city", "lat", "lng", "population"] + cols = ['city', 'lat', 'lng', 'population'] cities = pd.read_csv(os.path.join(download_root, filename), usecols=cols) - cities.at[8436, "population"] = 50789 # fix one bug (Tecax) in the csv file - cities = cities.nlargest(size, "population") + cities.at[8436, 'population'] = 50789 # fix one bug (Tecax) in the csv file + cities = cities.nlargest(size, 'population') return cities @@ -63,7 +63,7 @@ def km2deg(kms: float, radius: float = 6371) -> float: def sample_point(cities: pd.DataFrame, std: float) -> tuple[float, float]: city = cities.sample() - point = (float(city["lng"]), float(city["lat"])) + point = (float(city['lng']), float(city['lat'])) std = km2deg(std) lon, lat = np.random.normal(loc=point, scale=[std, std]) return (lon, lat) @@ -82,35 +82,35 @@ def create_bbox( return bbox -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--save-path", type=str, default="./data/", help="dir to save data" + '--save-path', type=str, default='./data/', help='dir to save data' ) parser.add_argument( - "--size", type=float, default=1320, help="half patch size in meters" + '--size', type=float, default=1320, help='half patch size in meters' ) parser.add_argument( - "--num-cities", type=int, default=10000, help="number of cities to sample" + '--num-cities', type=int, default=10000, help='number of cities to sample' ) parser.add_argument( - "--std", type=int, default=50, help="std dev of gaussian distribution" + '--std', type=int, default=50, help='std dev of gaussian distribution' ) parser.add_argument( - "--resume", action="store_true", help="resume from a previous run" + '--resume', action='store_true', help='resume from a previous run' ) parser.add_argument( - "--indices-range", + '--indices-range', type=int, nargs=2, default=[0, 250000], - help="indices to sample", + help='indices to sample', ) args = parser.parse_args() os.makedirs(args.save_path, exist_ok=True) - path = os.path.join(args.save_path, "sampled_locations.csv") - root = os.path.join(args.save_path, "world_cities") + path = os.path.join(args.save_path, 'sampled_locations.csv') + root = os.path.join(args.save_path, 'world_cities') cities = get_world_cities(download_root=root, size=args.num_cities) bbox_size = args.size / 1000 # no overlap between adjacent patches bbox_size_degree = km2deg(bbox_size) @@ -118,7 +118,7 @@ def create_bbox( # Populate R-tree if resuming rtree_coords = index.Index() if args.resume: - print("Loading existing locations...") + print('Loading existing locations...') with open(path) as csv_file: reader = csv.reader(csv_file) for i, row in enumerate(reader): @@ -133,9 +133,9 @@ def create_bbox( os.remove(path) # Sample locations and save to file - print("Sampling new locations...") + print('Sampling new locations...') start_time = time.time() - with open(path, "a") as f: + with open(path, 'a') as f: writer = csv.writer(f) for i in tqdm(range(*args.indices_range)): # Sample new coord and check overlap @@ -152,4 +152,4 @@ def create_bbox( f.flush() elapsed = time.time() - start_time - print(f"Sampled locations saved to {path} in {elapsed:.2f} seconds.") + print(f'Sampled 
locations saved to {path} in {elapsed:.2f} seconds.') diff --git a/experiments/torchgeo/benchmark.py b/experiments/torchgeo/benchmark.py index 7b1c98d2611..f5d0fd8eba6 100755 --- a/experiments/torchgeo/benchmark.py +++ b/experiments/torchgeo/benchmark.py @@ -32,85 +32,85 @@ def set_up_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--landsat-root", - default=os.path.join("data", "landsat"), - help="directory containing Landsat data", - metavar="ROOT", + '--landsat-root', + default=os.path.join('data', 'landsat'), + help='directory containing Landsat data', + metavar='ROOT', ) parser.add_argument( - "--cdl-root", - default=os.path.join("data", "cdl"), - help="directory containing CDL data", - metavar="ROOT", + '--cdl-root', + default=os.path.join('data', 'cdl'), + help='directory containing CDL data', + metavar='ROOT', ) parser.add_argument( - "-d", "--device", default=0, type=int, help="CPU/GPU ID to use", metavar="ID" + '-d', '--device', default=0, type=int, help='CPU/GPU ID to use', metavar='ID' ) parser.add_argument( - "-c", - "--cache", - action="store_true", - help="cache file handles during data loading", + '-c', + '--cache', + action='store_true', + help='cache file handles during data loading', ) parser.add_argument( - "-b", - "--batch-size", + '-b', + '--batch-size', default=2**4, type=int, - help="number of samples in each mini-batch", - metavar="SIZE", + help='number of samples in each mini-batch', + metavar='SIZE', ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( - "-n", - "--num-batches", + '-n', + '--num-batches', type=int, - help="number of batches to load", - metavar="SIZE", + help='number of batches to load', + metavar='SIZE', ) group.add_argument( - "-e", - "--epoch-size", + '-e', + '--epoch-size', type=int, - help="number of samples to load, should be evenly divisible by batch size", - metavar="SIZE", + help='number of samples to load, should be evenly divisible by batch size', + metavar='SIZE', ) parser.add_argument( - "-p", - "--patch-size", + '-p', + '--patch-size', default=224, type=int, - help="height/width of each patch in pixels", - metavar="PIXELS", + help='height/width of each patch in pixels', + metavar='PIXELS', ) parser.add_argument( - "-s", - "--stride", + '-s', + '--stride', default=112, type=int, - help="sampling stride for GridGeoSampler in pixels", - metavar="PIXELS", + help='sampling stride for GridGeoSampler in pixels', + metavar='PIXELS', ) parser.add_argument( - "-w", - "--num-workers", + '-w', + '--num-workers', default=0, type=int, - help="number of workers for parallel data loading", - metavar="NUM", + help='number of workers for parallel data loading', + metavar='NUM', ) parser.add_argument( - "--seed", default=0, type=int, help="random seed for reproducibility" + '--seed', default=0, type=int, help='random seed for reproducibility' ) parser.add_argument( - "--output-fn", - default="benchmark-results.csv", + '--output-fn', + default='benchmark-results.csv', type=str, - help="path to the CSV file to write results", - metavar="FILE", + help='path to the CSV file to write results', + metavar='FILE', ) parser.add_argument( - "-v", "--verbose", action="store_true", help="print results to stdout" + '-v', '--verbose', action='store_true', help='print results to stdout' ) return parser @@ -124,7 +124,7 @@ def main(args: argparse.Namespace) -> None: Args: args: command-line arguments """ - bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"] + bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'] # Benchmark 
samplers @@ -154,7 +154,7 @@ def main(args: argparse.Namespace) -> None: results_rows = [] for sampler in samplers: if args.verbose: - print(f"\n{sampler.__class__.__name__}:") + print(f'\n{sampler.__class__.__name__}:') if isinstance(sampler, RandomBatchGeoSampler): dataloader = DataLoader( @@ -183,9 +183,9 @@ def main(args: argparse.Namespace) -> None: duration = toc - tic if args.verbose: - print(f" duration: {duration:.3f} sec") - print(f" count: {num_total_patches} patches") - print(f" rate: {num_total_patches / duration:.3f} patches/sec") + print(f' duration: {duration:.3f} sec') + print(f' count: {num_total_patches} patches') + print(f' rate: {num_total_patches / duration:.3f} patches/sec') if args.cache: if args.verbose: @@ -197,14 +197,14 @@ def main(args: argparse.Namespace) -> None: results_rows.append( { - "cached": args.cache, - "seed": args.seed, - "duration": duration, - "count": num_total_patches, - "rate": num_total_patches / duration, - "sampler": sampler.__class__.__name__, - "batch_size": args.batch_size, - "num_workers": args.num_workers, + 'cached': args.cache, + 'seed': args.seed, + 'duration': duration, + 'count': num_total_patches, + 'rate': num_total_patches / duration, + 'sampler': sampler.__class__.__name__, + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, } ) @@ -219,7 +219,7 @@ def main(args: argparse.Namespace) -> None: params = model.parameters() optimizer = optim.SGD(params, lr=0.0001) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu", args.device) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', args.device) model = model.to(device) tic = time.time() @@ -242,44 +242,44 @@ def main(args: argparse.Namespace) -> None: duration = toc - tic if args.verbose: - print("\nResNet-34:") - print(f" duration: {duration:.3f} sec") - print(f" count: {num_total_patches} patches") - print(f" rate: {num_total_patches / duration:.3f} patches/sec") + print('\nResNet-34:') + print(f' duration: {duration:.3f} sec') + print(f' count: {num_total_patches} patches') + print(f' rate: {num_total_patches / duration:.3f} patches/sec') results_rows.append( { - "cached": args.cache, - "seed": args.seed, - "duration": duration, - "count": num_total_patches, - "rate": num_total_patches / duration, - "sampler": "ResNet-34", - "batch_size": args.batch_size, - "num_workers": args.num_workers, + 'cached': args.cache, + 'seed': args.seed, + 'duration': duration, + 'count': num_total_patches, + 'rate': num_total_patches / duration, + 'sampler': 'ResNet-34', + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, } ) fieldnames = [ - "cached", - "seed", - "duration", - "count", - "rate", - "sampler", - "batch_size", - "num_workers", + 'cached', + 'seed', + 'duration', + 'count', + 'rate', + 'sampler', + 'batch_size', + 'num_workers', ] if not os.path.exists(args.output_fn): - with open(args.output_fn, "w") as f: + with open(args.output_fn, 'w') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() - with open(args.output_fn, "a") as f: + with open(args.output_fn, 'a') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writerows(results_rows) -if __name__ == "__main__": +if __name__ == '__main__': parser = set_up_parser() args = parser.parse_args() diff --git a/experiments/torchgeo/find_optimal_hyperparams.py b/experiments/torchgeo/find_optimal_hyperparams.py index 03f92c99d53..50fa96b9045 100755 --- a/experiments/torchgeo/find_optimal_hyperparams.py +++ 
b/experiments/torchgeo/find_optimal_hyperparams.py @@ -12,7 +12,7 @@ from tbparse import SummaryReader -OUTPUT_DIR = "" +OUTPUT_DIR = '' # mypy does not yet support recursive type hints @@ -25,22 +25,22 @@ def nested_dict() -> defaultdict[str, defaultdict]: # type: ignore[type-arg] return defaultdict(nested_dict) -if __name__ == "__main__": +if __name__ == '__main__': metrics = nested_dict() - logs = os.path.join(OUTPUT_DIR, "logs", "*", "version_*", "events*") + logs = os.path.join(OUTPUT_DIR, 'logs', '*', 'version_*', 'events*') for log in glob.iglob(logs): hyperparams = log.split(os.sep)[-3] reader = SummaryReader(log) df = reader.scalars # Some event logs are for train/val, others are for test - for split in ["train", "val", "test"]: - rmse = df.loc[df["tag"] == f"{split}_RMSE"] - mae = df.loc[df["tag"] == f"{split}_MAE"] + for split in ['train', 'val', 'test']: + rmse = df.loc[df['tag'] == f'{split}_RMSE'] + mae = df.loc[df['tag'] == f'{split}_MAE'] if len(rmse): - metrics[hyperparams][split]["RMSE"] = rmse.iloc[-1]["value"] + metrics[hyperparams][split]['RMSE'] = rmse.iloc[-1]['value'] if len(mae): - metrics[hyperparams][split]["MAE"] = mae.iloc[-1]["value"] + metrics[hyperparams][split]['MAE'] = mae.iloc[-1]['value'] print(json.dumps(metrics, sort_keys=True, indent=4)) diff --git a/experiments/torchgeo/plot_bar_chart.py b/experiments/torchgeo/plot_bar_chart.py index 9f421183789..f4c9fbb7e7e 100755 --- a/experiments/torchgeo/plot_bar_chart.py +++ b/experiments/torchgeo/plot_bar_chart.py @@ -8,33 +8,33 @@ import pandas as pd import seaborn as sns -df1 = pd.read_csv("original-benchmark-results.csv") -df2 = pd.read_csv("warped-benchmark-results.csv") +df1 = pd.read_csv('original-benchmark-results.csv') +df2 = pd.read_csv('warped-benchmark-results.csv') -mean1 = df1.groupby("sampler").mean() -mean2 = df2.groupby("sampler").mean() +mean1 = df1.groupby('sampler').mean() +mean2 = df2.groupby('sampler').mean() cached1 = ( - df1[(df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean() + df1[(df1['cached']) & (df1['sampler'] != 'resnet18')].groupby('sampler').mean() ) cached2 = ( - df2[(df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean() + df2[(df2['cached']) & (df2['sampler'] != 'resnet18')].groupby('sampler').mean() ) not_cached1 = ( - df1[(~df1["cached"]) & (df1["sampler"] != "resnet18")].groupby("sampler").mean() + df1[(~df1['cached']) & (df1['sampler'] != 'resnet18')].groupby('sampler').mean() ) not_cached2 = ( - df2[(~df2["cached"]) & (df2["sampler"] != "resnet18")].groupby("sampler").mean() + df2[(~df2['cached']) & (df2['sampler'] != 'resnet18')].groupby('sampler').mean() ) -print("cached, original\n", cached1) -print("cached, warped\n", cached2) -print("not cached, original\n", not_cached1) -print("not cached, warped\n", not_cached2) +print('cached, original\n', cached1) +print('cached, warped\n', cached2) +print('not cached, original\n', not_cached1) +print('not cached, warped\n', not_cached2) cmap = sns.color_palette() -labels = ["GridGeoSampler", "RandomBatchGeoSampler", "RandomGeoSampler"] +labels = ['GridGeoSampler', 'RandomBatchGeoSampler', 'RandomGeoSampler'] fig, ax = plt.subplots() x = np.arange(3) @@ -42,34 +42,34 @@ rects1 = ax.bar( x - width * 3 / 2, - not_cached1["rate"], + not_cached1['rate'], width, - label="Raw Data, Not Cached", + label='Raw Data, Not Cached', color=cmap[0], ) rects2 = ax.bar( x - width * 1 / 2, - not_cached2["rate"], + not_cached2['rate'], width, - label="Preprocessed, Not Cached", + label='Preprocessed, Not 
Cached', color=cmap[1], ) rects2 = ax.bar( - x + width * 1 / 2, cached1["rate"], width, label="Raw Data, Cached", color=cmap[2] + x + width * 1 / 2, cached1['rate'], width, label='Raw Data, Cached', color=cmap[2] ) rects3 = ax.bar( x + width * 3 / 2, - cached2["rate"], + cached2['rate'], width, - label="Preprocessed, Cached", + label='Preprocessed, Cached', color=cmap[3], ) -ax.set_ylabel("sampling rate (patches/sec)", fontsize=12) +ax.set_ylabel('sampling rate (patches/sec)', fontsize=12) ax.set_xticks(x) ax.set_xticklabels(labels, fontsize=12) -ax.tick_params(axis="x", labelrotation=10) -ax.legend(fontsize="large") +ax.tick_params(axis='x', labelrotation=10) +ax.legend(fontsize='large') plt.gca().spines.right.set_visible(False) plt.gca().spines.top.set_visible(False) diff --git a/experiments/torchgeo/plot_dataloader_benchmark.py b/experiments/torchgeo/plot_dataloader_benchmark.py index 4c313ca174f..9804fb384e4 100755 --- a/experiments/torchgeo/plot_dataloader_benchmark.py +++ b/experiments/torchgeo/plot_dataloader_benchmark.py @@ -7,15 +7,15 @@ import pandas as pd import seaborn as sns -df = pd.read_csv("warped-benchmark-results.csv") +df = pd.read_csv('warped-benchmark-results.csv') -random_cached = df[(df["sampler"] == "RandomGeoSampler") & (df["cached"])] -random_batch_cached = df[(df["sampler"] == "RandomBatchGeoSampler") & (df["cached"])] -grid_cached = df[(df["sampler"] == "GridGeoSampler") & (df["cached"])] +random_cached = df[(df['sampler'] == 'RandomGeoSampler') & (df['cached'])] +random_batch_cached = df[(df['sampler'] == 'RandomBatchGeoSampler') & (df['cached'])] +grid_cached = df[(df['sampler'] == 'GridGeoSampler') & (df['cached'])] other = [ - ("RandomGeoSampler", random_cached), - ("RandomBatchGeoSampler", random_batch_cached), - ("GridGeoSampler", grid_cached), + ('RandomGeoSampler', random_cached), + ('RandomBatchGeoSampler', random_batch_cached), + ('GridGeoSampler', grid_cached), ] cmap = sns.color_palette() @@ -23,19 +23,19 @@ ax = plt.gca() for i, (label, df) in enumerate(other): - df = df.groupby("batch_size") - ax.plot(df.mean().index, df.mean()["rate"], color=cmap[i], label=label) + df = df.groupby('batch_size') + ax.plot(df.mean().index, df.mean()['rate'], color=cmap[i], label=label) ax.fill_between( - df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2 + df.mean().index, df.min()['rate'], df.max()['rate'], color=cmap[i], alpha=0.2 ) -ax.set_xscale("log") +ax.set_xscale('log') ax.set_xticks([16, 32, 64, 128, 256]) -ax.set_xticklabels(["16", "32", "64", "128", "256"], fontsize=12) -ax.set_xlabel("batch size", fontsize=12) -ax.set_ylabel("sampling rate (patches/sec)", fontsize=12) -ax.legend(loc="center right", fontsize="large") +ax.set_xticklabels(['16', '32', '64', '128', '256'], fontsize=12) +ax.set_xlabel('batch size', fontsize=12) +ax.set_ylabel('sampling rate (patches/sec)', fontsize=12) +ax.legend(loc='center right', fontsize='large') plt.gca().spines.right.set_visible(False) plt.gca().spines.top.set_visible(False) diff --git a/experiments/torchgeo/plot_percentage_benchmark.py b/experiments/torchgeo/plot_percentage_benchmark.py index e0f2aa3b0e8..ac6b6aca9ec 100755 --- a/experiments/torchgeo/plot_percentage_benchmark.py +++ b/experiments/torchgeo/plot_percentage_benchmark.py @@ -7,32 +7,32 @@ import pandas as pd import seaborn as sns -df1 = pd.read_csv("original-benchmark-results.csv") -df2 = pd.read_csv("warped-benchmark-results.csv") +df1 = pd.read_csv('original-benchmark-results.csv') +df2 = 
pd.read_csv('warped-benchmark-results.csv') -random_cached1 = df1[(df1["sampler"] == "RandomGeoSampler") & (df1["cached"])] -random_cached2 = df2[(df2["sampler"] == "RandomGeoSampler") & (df2["cached"])] +random_cached1 = df1[(df1['sampler'] == 'RandomGeoSampler') & (df1['cached'])] +random_cached2 = df2[(df2['sampler'] == 'RandomGeoSampler') & (df2['cached'])] random_cachedp = random_cached1 -random_cachedp["rate"] /= random_cached2["rate"] +random_cachedp['rate'] /= random_cached2['rate'] random_batch_cached1 = df1[ - (df1["sampler"] == "RandomBatchGeoSampler") & (df1["cached"]) + (df1['sampler'] == 'RandomBatchGeoSampler') & (df1['cached']) ] random_batch_cached2 = df2[ - (df2["sampler"] == "RandomBatchGeoSampler") & (df2["cached"]) + (df2['sampler'] == 'RandomBatchGeoSampler') & (df2['cached']) ] random_batch_cachedp = random_batch_cached1 -random_batch_cachedp["rate"] /= random_batch_cached2["rate"] +random_batch_cachedp['rate'] /= random_batch_cached2['rate'] -grid_cached1 = df1[(df1["sampler"] == "GridGeoSampler") & (df1["cached"])] -grid_cached2 = df2[(df2["sampler"] == "GridGeoSampler") & (df2["cached"])] +grid_cached1 = df1[(df1['sampler'] == 'GridGeoSampler') & (df1['cached'])] +grid_cached2 = df2[(df2['sampler'] == 'GridGeoSampler') & (df2['cached'])] grid_cachedp = grid_cached1 -grid_cachedp["rate"] /= grid_cached2["rate"] +grid_cachedp['rate'] /= grid_cached2['rate'] other = [ - ("RandomGeoSampler (cached)", random_cachedp), - ("RandomBatchGeoSampler (cached)", random_batch_cachedp), - ("GridGeoSampler (cached)", grid_cachedp), + ('RandomGeoSampler (cached)', random_cachedp), + ('RandomBatchGeoSampler (cached)', random_batch_cachedp), + ('GridGeoSampler (cached)', grid_cachedp), ] cmap = sns.color_palette() @@ -40,17 +40,17 @@ ax = plt.gca() for i, (label, df) in enumerate(other): - df = df.groupby("batch_size") - ax.plot([16, 32, 64, 128, 256], df.mean()["rate"], color=cmap[i], label=label) + df = df.groupby('batch_size') + ax.plot([16, 32, 64, 128, 256], df.mean()['rate'], color=cmap[i], label=label) ax.fill_between( - df.mean().index, df.min()["rate"], df.max()["rate"], color=cmap[i], alpha=0.2 + df.mean().index, df.min()['rate'], df.max()['rate'], color=cmap[i], alpha=0.2 ) -ax.set_xscale("log") +ax.set_xscale('log') ax.set_xticks([16, 32, 64, 128, 256]) -ax.set_xticklabels(["16", "32", "64", "128", "256"]) -ax.set_xlabel("batch size") -ax.set_ylabel("% sampling rate (patches/sec)") +ax.set_xticklabels(['16', '32', '64', '128', '256']) +ax.set_xlabel('batch size') +ax.set_ylabel('% sampling rate (patches/sec)') ax.legend() plt.show() diff --git a/experiments/torchgeo/run_benchmarks_experiments.py b/experiments/torchgeo/run_benchmarks_experiments.py index 8f10c402e44..363f32034f0 100755 --- a/experiments/torchgeo/run_benchmarks_experiments.py +++ b/experiments/torchgeo/run_benchmarks_experiments.py @@ -16,42 +16,42 @@ BATCH_SIZE_OPTIONS = [16, 32, 64, 128, 256, 512] # path to a directory containing Landsat 8 GeoTIFFs -LANDSAT_DATA_ROOT = "" +LANDSAT_DATA_ROOT = '' # path to a directory containing CDL GeoTIFF(s) -CDL_DATA_ROOT = "" +CDL_DATA_ROOT = '' total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS) -if __name__ == "__main__": +if __name__ == '__main__': # With 6 workers, this will use ~60% of available RAM - os.environ["GDAL_CACHEMAX"] = "10%" + os.environ['GDAL_CACHEMAX'] = '10%' tic = time.time() for i, (cache, batch_size, seed) in enumerate( itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS) ): - 
print(f"\n{i}/{total_num_experiments} -- {time.time() - tic}") + print(f'\n{i}/{total_num_experiments} -- {time.time() - tic}') tic = time.time() command: list[str] = [ - "python", - "benchmark.py", - "--landsat-root", + 'python', + 'benchmark.py', + '--landsat-root', LANDSAT_DATA_ROOT, - "--cdl-root", + '--cdl-root', CDL_DATA_ROOT, - "--num-workers", - "6", - "--batch-size", + '--num-workers', + '6', + '--batch-size', str(batch_size), - "--epoch-size", + '--epoch-size', str(EPOCH_SIZE), - "--seed", + '--seed', str(seed), - "--verbose", + '--verbose', ] if cache: - command.append("--cache") + command.append('--cache') subprocess.call(command) diff --git a/experiments/torchgeo/run_chesapeake_cvpr_experiments.py b/experiments/torchgeo/run_chesapeake_cvpr_experiments.py index 7e3f3b305a3..d837a20c923 100755 --- a/experiments/torchgeo/run_chesapeake_cvpr_experiments.py +++ b/experiments/torchgeo/run_chesapeake_cvpr_experiments.py @@ -12,30 +12,30 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0] DRY_RUN = False # if False then print out the commands to be run, if True then run -DATA_DIR = "" # path to the ChesapeakeCVPR data directory +DATA_DIR = '' # path to the ChesapeakeCVPR data directory # Hyperparameter options -training_set_options = ["de"] -model_options = ["unet"] -backbone_options = ["resnet18", "resnet50"] +training_set_options = ['de'] +model_options = ['unet'] +backbone_options = ['resnet18', 'resnet50'] lr_options = [1e-2, 1e-3, 1e-4] -loss_options = ["ce", "jaccard"] -weight_init_options = ["null", "imagenet"] +loss_options = ['ce', 'jaccard'] +weight_init_options = ['null', 'imagenet'] -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for train_state, model, backbone, lr, loss, weight_init in itertools.product( training_set_options, @@ -45,30 +45,30 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: loss_options, weight_init_options, ): - experiment_name = f"{train_state}_{model}_{backbone}_{lr}_{loss}_{weight_init}" + experiment_name = f'{train_state}_{model}_{backbone}_{lr}_{loss}_{weight_init}' - output_dir = os.path.join("output", "chesapeake-cvpr_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "chesapeake_cvpr.yaml") + output_dir = os.path.join('output', 'chesapeake-cvpr_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'chesapeake_cvpr.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.backbone={backbone}" - + f" experiment.module.weights={weight_init}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.loss={loss}" - + " experiment.module.class_set=7" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' 
experiment.module.backbone={backbone}' + + f' experiment.module.weights={weight_init}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.loss={loss}' + + ' experiment.module.class_set=7' + f" experiment.datamodule.train_splits=['{train_state}-train']" + f" experiment.datamodule.val_splits=['{train_state}-val']" + f" experiment.datamodule.test_splits=['{train_state}-test']" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + " trainer.gpus=[GPU]" + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_chesapeakecvpr_models.py b/experiments/torchgeo/run_chesapeakecvpr_models.py index 382f594af67..2a3ae06c730 100755 --- a/experiments/torchgeo/run_chesapeakecvpr_models.py +++ b/experiments/torchgeo/run_chesapeakecvpr_models.py @@ -13,7 +13,7 @@ from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers.chesapeake import SemanticSegmentationTask -ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]] +ALL_TEST_SPLITS = [['de-val'], ['pa-test'], ['ny-test'], ['pa-test', 'ny-test']] def set_up_parser() -> argparse.ArgumentParser: @@ -27,33 +27,33 @@ def set_up_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--input-dir", + '--input-dir', required=True, type=str, - help="directory containing the experiment run directories", - metavar="ROOT", + help='directory containing the experiment run directories', + metavar='ROOT', ) parser.add_argument( - "--chesapeakecvpr-root", + '--chesapeakecvpr-root', required=True, type=str, - help="directory containing the ChesapeakeCVPR dataset", - metavar="ROOT", + help='directory containing the ChesapeakeCVPR dataset', + metavar='ROOT', ) parser.add_argument( - "--output-fn", - default="chesapeakecvpr-results.csv", + '--output-fn', + default='chesapeakecvpr-results.csv', type=str, - help="path to the CSV file to write results", - metavar="FILE", + help='path to the CSV file to write results', + metavar='FILE', ) parser.add_argument( - "-d", - "--device", + '-d', + '--device', default=0, type=int, - help="GPU ID to use, ignored if no GPUs are available", - metavar="ID", + help='GPU ID to use, ignored if no GPUs are available', + metavar='ID', ) return parser @@ -66,27 +66,27 @@ def main(args: argparse.Namespace) -> None: args: command-line arguments """ if os.path.exists(args.output_fn): - print(f"The output file {args.output_fn} already exists, exiting...") + print(f'The output file {args.output_fn} already exists, exiting...') return # Set up the result file fieldnames = [ - "train-state", - "model", - "learning-rate", - "initialization", - "loss", - "test-state", - "acc", - "iou", + 'train-state', + 'model', + 'learning-rate', + 'initialization', + 'loss', + 'test-state', + 'acc', + 'iou', ] - with open(args.output_fn, "w") as f: + with open(args.output_fn, 'w') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() # Test loop trainer = Trainer( - accelerator="auto", + accelerator='auto', devices=[args.device], logger=False, enable_progress_bar=False, @@ -96,13 +96,13 @@ def main(args: argparse.Namespace) -> None: for experiment_dir in os.listdir(args.input_dir): checkpoint_fn = None for fn in os.listdir(os.path.join(args.input_dir, experiment_dir)): - if fn.startswith("epoch") and fn.endswith(".ckpt"): + if fn.startswith('epoch') and 
fn.endswith('.ckpt'): checkpoint_fn = fn break if checkpoint_fn is None: print( - f"Skipping {os.path.join(args.input_dir, experiment_dir)} as we are not" - + " able to find a checkpoint file" + f'Skipping {os.path.join(args.input_dir, experiment_dir)} as we are not' + + ' able to find a checkpoint file' ) continue checkpoint_fn = os.path.join(args.input_dir, experiment_dir, checkpoint_fn) @@ -113,22 +113,22 @@ def main(args: argparse.Namespace) -> None: model.eval() except KeyError: print( - f"Skipping {experiment_dir} as we are not able to load a valid" - + f" SemanticSegmentationTask from {checkpoint_fn}" + f'Skipping {experiment_dir} as we are not able to load a valid' + + f' SemanticSegmentationTask from {checkpoint_fn}' ) continue try: - experiment_dir_parts = experiment_dir.split("_") + experiment_dir_parts = experiment_dir.split('_') train_state = experiment_dir_parts[0] model_name = experiment_dir_parts[1] learning_rate = experiment_dir_parts[2] loss = experiment_dir_parts[3] - initialization = "random" if len(experiment_dir_parts) == 5 else "imagenet" + initialization = 'random' if len(experiment_dir_parts) == 5 else 'imagenet' except IndexError: print( - f"Skipping {experiment_dir} as the directory name is not in the" - + " expected format" + f'Skipping {experiment_dir} as the directory name is not in the' + + ' expected format' ) continue @@ -136,8 +136,8 @@ def main(args: argparse.Namespace) -> None: for test_splits in ALL_TEST_SPLITS: dm = ChesapeakeCVPRDataModule( root=args.chesapeakecvpr_root, - train_splits=["de-train"], - val_splits=["de-val"], + train_splits=['de-train'], + val_splits=['de-val'], test_splits=test_splits, batch_size=32, num_workers=8, @@ -147,21 +147,21 @@ def main(args: argparse.Namespace) -> None: print(experiment_dir, test_splits, results[0]) row = { - "train-state": train_state, - "model": model_name, - "learning-rate": learning_rate, - "initialization": initialization, - "loss": loss, - "test-state": "_".join(test_splits), - "acc": results[0]["test_Accuracy"], - "iou": results[0]["test_IoU"], + 'train-state': train_state, + 'model': model_name, + 'learning-rate': learning_rate, + 'initialization': initialization, + 'loss': loss, + 'test-state': '_'.join(test_splits), + 'acc': results[0]['test_Accuracy'], + 'iou': results[0]['test_IoU'], } - with open(args.output_fn, "a") as f: + with open(args.output_fn, 'a') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writerow(row) -if __name__ == "__main__": +if __name__ == '__main__': parser = set_up_parser() args = parser.parse_args() diff --git a/experiments/torchgeo/run_cowc_experiments.py b/experiments/torchgeo/run_cowc_experiments.py index 3d6a4bef26b..343526a434e 100755 --- a/experiments/torchgeo/run_cowc_experiments.py +++ b/experiments/torchgeo/run_cowc_experiments.py @@ -13,49 +13,49 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = range(0) DRY_RUN = True # if True then print out the commands to be run, if False then run -DATA_DIR = "" # path to the COWC data directory +DATA_DIR = '' # path to the COWC data directory # Hyperparameter options -model_options = ["resnet18", "resnet50"] +model_options = ['resnet18', 'resnet50'] pretrained_options = [True, False] lr_options = [1e-2, 1e-3, 1e-4] -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + 
experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, lr, pretrained in itertools.product( model_options, lr_options, pretrained_options ): - experiment_name = f"{model}_{lr}_{pretrained}" + experiment_name = f'{model}_{lr}_{pretrained}' - output_dir = os.path.join("output", "cowc_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "cowc_counting.yaml") + output_dir = os.path.join('output', 'cowc_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'cowc_counting.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.pretrained={pretrained}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + " trainer.gpus=[GPU]" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.pretrained={pretrained}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_cowc_seed_experiments.py b/experiments/torchgeo/run_cowc_seed_experiments.py index 0c2adcc136f..ffef876818c 100755 --- a/experiments/torchgeo/run_cowc_seed_experiments.py +++ b/experiments/torchgeo/run_cowc_seed_experiments.py @@ -13,51 +13,51 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = range(1) DRY_RUN = True # if True then print out the commands to be run, if False then run -DATA_DIR = "" # path to the COWC data directory +DATA_DIR = '' # path to the COWC data directory # Hyperparameter options -model_options = ["resnet18", "resnet50"] +model_options = ['resnet18', 'resnet50'] pretrained_options = [True] lr_options = [1e-4] seeds = range(10) -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, lr, pretrained, seed in itertools.product( model_options, lr_options, pretrained_options, seeds ): - experiment_name = f"{model}_{lr}_{pretrained}_{seed}" + experiment_name = f'{model}_{lr}_{pretrained}_{seed}' - output_dir = os.path.join("output", "cowc_seed_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "cowc_counting.yaml") + output_dir = os.path.join('output', 'cowc_seed_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'cowc_counting.yaml') if not 
os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.pretrained={pretrained}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + f" program.seed={seed}" - + " trainer.gpus=[GPU]" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.pretrained={pretrained}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + f' program.seed={seed}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_landcoverai_experiments.py b/experiments/torchgeo/run_landcoverai_experiments.py index 64749b5de5d..621f8d2c57c 100755 --- a/experiments/torchgeo/run_landcoverai_experiments.py +++ b/experiments/torchgeo/run_landcoverai_experiments.py @@ -12,53 +12,53 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0] DRY_RUN = False # if False then print out the commands to be run, if True then run -DATA_DIR = "" # path to the LandcoverAI data directory +DATA_DIR = '' # path to the LandcoverAI data directory # Hyperparameter options -model_options = ["unet"] -backbone_options = ["resnet18", "resnet50"] +model_options = ['unet'] +backbone_options = ['resnet18', 'resnet50'] lr_options = [1e-2, 1e-3, 1e-4] -loss_options = ["ce", "jaccard"] -weight_init_options = ["null", "imagenet"] +loss_options = ['ce', 'jaccard'] +weight_init_options = ['null', 'imagenet'] -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, backbone, lr, loss, weight_init in itertools.product( model_options, backbone_options, lr_options, loss_options, weight_init_options ): - experiment_name = f"{model}_{backbone}_{lr}_{loss}_{weight_init}" + experiment_name = f'{model}_{backbone}_{lr}_{loss}_{weight_init}' - output_dir = os.path.join("output", "landcoverai_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "landcoverai.yaml") + output_dir = os.path.join('output', 'landcoverai_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'landcoverai.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.segmentation_model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.loss={loss}" - + f" experiment.module.backbone={backbone}" - + f" experiment.module.weights={weight_init}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + " trainer.gpus=[GPU]" + 'python 
train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.segmentation_model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.loss={loss}' + + f' experiment.module.backbone={backbone}' + + f' experiment.module.weights={weight_init}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_landcoverai_seed_experiments.py b/experiments/torchgeo/run_landcoverai_seed_experiments.py index a6778edb66d..6421377db3d 100755 --- a/experiments/torchgeo/run_landcoverai_seed_experiments.py +++ b/experiments/torchgeo/run_landcoverai_seed_experiments.py @@ -12,53 +12,53 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0] DRY_RUN = False # if False then print out the commands to be run, if True then run -DATA_DIR = "" # path to the LandcoverAI data directory +DATA_DIR = '' # path to the LandcoverAI data directory # Hyperparameter options -model_options = ["unet"] -backbone_options = ["resnet18", "resnet50"] +model_options = ['unet'] +backbone_options = ['resnet18', 'resnet50'] lr_options = [1e-2, 1e-3, 1e-4] -loss_options = ["ce", "jaccard"] -weight_init_options = ["null", "imagenet"] +loss_options = ['ce', 'jaccard'] +weight_init_options = ['null', 'imagenet'] -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, backbone, lr, loss, weight_init in itertools.product( model_options, backbone_options, lr_options, loss_options, weight_init_options ): - experiment_name = f"{model}_{backbone}_{lr}_{loss}_{weight_init}" + experiment_name = f'{model}_{backbone}_{lr}_{loss}_{weight_init}' - output_dir = os.path.join("output", "landcoverai_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "landcoverai.yaml") + output_dir = os.path.join('output', 'landcoverai_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'landcoverai.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.loss={loss}" - + f" experiment.module.backbone={backbone}" - + f" experiment.module.weights={weight_init}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + " trainer.gpus=[GPU]" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.loss={loss}' + + f' experiment.module.backbone={backbone}' + + f' experiment.module.weights={weight_init}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' 
program.data_dir={DATA_DIR}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_resisc45_experiments.py b/experiments/torchgeo/run_resisc45_experiments.py index b043da8400d..6897ea12772 100755 --- a/experiments/torchgeo/run_resisc45_experiments.py +++ b/experiments/torchgeo/run_resisc45_experiments.py @@ -12,52 +12,52 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0] DRY_RUN = False # if False then print out the commands to be run, if True then run -DATA_DIR = "" # path to the RESISC45 data directory +DATA_DIR = '' # path to the RESISC45 data directory # Hyperparameter options -model_options = ["resnet18", "resnet50"] +model_options = ['resnet18', 'resnet50'] lr_options = [1e-2, 1e-3, 1e-4] -loss_options = ["ce"] -weight_options = ["imagenet_only", "random"] +loss_options = ['ce'] +weight_options = ['imagenet_only', 'random'] -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" - output_dir = os.path.join("output", "resisc45_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "resisc45.yaml") + output_dir = os.path.join('output', 'resisc45_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'resisc45.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.loss={loss}" - + f" experiment.module.weights={weights}" - + f" experiment.datamodule.weights={weights}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + " trainer.gpus=[GPU]" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.loss={loss}' + + f' experiment.module.weights={weights}' + + f' experiment.datamodule.weights={weights}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_so2sat_byol_experiments.py b/experiments/torchgeo/run_so2sat_byol_experiments.py index 01d67ee5cb1..169a010cef8 100755 --- a/experiments/torchgeo/run_so2sat_byol_experiments.py +++ b/experiments/torchgeo/run_so2sat_byol_experiments.py @@ -12,54 +12,54 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0, 1, 2, 3, 3] DRY_RUN = False # if False then print out the commands to be run, if True then run -DATA_DIR = "" # path to the So2Sat data directory +DATA_DIR = '' # path to the 
So2Sat data directory # Hyperparameter options -model_options = ["resnet50"] +model_options = ['resnet50'] lr_options = [1e-4] -loss_options = ["ce"] +loss_options = ['ce'] weight_options: list[str] = [] # set paths to checkpoint files -bands_options = ["s2"] +bands_options = ['s2'] -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, lr, loss, weights, bands in itertools.product( model_options, lr_options, loss_options, weight_options, bands_options ): experiment_name = f"{model}_{lr}_{loss}_byol_{bands}-{weights.split('/')[-2]}" - output_dir = os.path.join("output", "so2sat_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "so2sat.yaml") + output_dir = os.path.join('output', 'so2sat_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'so2sat.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.loss={loss}" - + f" experiment.module.weights={weights}" - + " experiment.module.in_channels=10" - + f" experiment.datamodule.bands={bands}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + " trainer.gpus=[GPU]" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.loss={loss}' + + f' experiment.module.weights={weights}' + + ' experiment.module.in_channels=10' + + f' experiment.datamodule.bands={bands}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_so2sat_experiments.py b/experiments/torchgeo/run_so2sat_experiments.py index 82f3e511b43..41e2fc04b5f 100755 --- a/experiments/torchgeo/run_so2sat_experiments.py +++ b/experiments/torchgeo/run_so2sat_experiments.py @@ -12,52 +12,52 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0] DRY_RUN = False # if False then print out the commands to be run, if True then run -DATA_DIR = "" # path to the So2Sat data directory +DATA_DIR = '' # path to the So2Sat data directory # Hyperparameter options -model_options = ["resnet18", "resnet50"] +model_options = ['resnet18', 'resnet50'] lr_options = [1e-2, 1e-3, 1e-4] -loss_options = ["ce"] -weight_options = ["imagenet_only", "random", "imagenet_and_random", "random_rgb"] +loss_options = ['ce'] +weight_options = ['imagenet_only', 'random', 'imagenet_and_random', 'random_rgb'] -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = 
experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" - output_dir = os.path.join("output", "so2sat_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", "so2sat.yaml") + output_dir = os.path.join('output', 'so2sat_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'so2sat.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.loss={loss}" - + f" experiment.module.weights={weights}" - + f" experiment.datamodule.weights={weights}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + " trainer.gpus=[GPU]" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.loss={loss}' + + f' experiment.module.weights={weights}' + + f' experiment.datamodule.weights={weights}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/experiments/torchgeo/run_so2sat_seed_experiments.py b/experiments/torchgeo/run_so2sat_seed_experiments.py index 5b5e1e3d697..2d2efe1e248 100755 --- a/experiments/torchgeo/run_so2sat_seed_experiments.py +++ b/experiments/torchgeo/run_so2sat_seed_experiments.py @@ -12,54 +12,54 @@ # list of GPU IDs that we want to use, one job will be started for every ID in the list GPUS = [0] DRY_RUN = False # if False then print out the commands to be run, if True then run -DATA_DIR = "" # path to the So2Sat data directory +DATA_DIR = '' # path to the So2Sat data directory # Hyperparameter options -model_options = ["resnet18"] +model_options = ['resnet18'] lr_options = [1e-3] -loss_options = ["ce"] -weight_options = ["random"] +loss_options = ['ce'] +weight_options = ['random'] seeds = list(range(32)) -def do_work(work: "Queue[str]", gpu_idx: int) -> bool: +def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: """Process for each ID in GPUS.""" while not work.empty(): experiment = work.get() - experiment = experiment.replace("GPU", str(gpu_idx)) + experiment = experiment.replace('GPU', str(gpu_idx)) print(experiment) if not DRY_RUN: - subprocess.call(experiment.split(" ")) + subprocess.call(experiment.split(' ')) return True -if __name__ == "__main__": - work: "Queue[str]" = Queue() +if __name__ == '__main__': + work: 'Queue[str]' = Queue() for model, lr, loss, weights, seed in itertools.product( model_options, lr_options, loss_options, weight_options, seeds ): experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}_{seed}" - output_dir = os.path.join("output", "so2sat_seed_experiments") - log_dir = os.path.join(output_dir, "logs") - config_file = os.path.join("conf", 
"so2sat.yaml") + output_dir = os.path.join('output', 'so2sat_seed_experiments') + log_dir = os.path.join(output_dir, 'logs') + config_file = os.path.join('conf', 'so2sat.yaml') if not os.path.exists(os.path.join(output_dir, experiment_name)): command = ( - "python train.py" - + f" config_file={config_file}" - + f" experiment.name={experiment_name}" - + f" experiment.module.model={model}" - + f" experiment.module.learning_rate={lr}" - + f" experiment.module.loss={loss}" - + f" experiment.module.weights={weights}" - + f" experiment.datamodule.weights={weights}" - + f" program.output_dir={output_dir}" - + f" program.log_dir={log_dir}" - + f" program.data_dir={DATA_DIR}" - + f" program.seed={seed}" - + " trainer.gpus=[GPU]" + 'python train.py' + + f' config_file={config_file}' + + f' experiment.name={experiment_name}' + + f' experiment.module.model={model}' + + f' experiment.module.learning_rate={lr}' + + f' experiment.module.loss={loss}' + + f' experiment.module.weights={weights}' + + f' experiment.datamodule.weights={weights}' + + f' program.output_dir={output_dir}' + + f' program.log_dir={log_dir}' + + f' program.data_dir={DATA_DIR}' + + f' program.seed={seed}' + + ' trainer.gpus=[GPU]' ) command = command.strip() diff --git a/hubconf.py b/hubconf.py index 3e17c4ad738..2e2fe6b1a77 100644 --- a/hubconf.py +++ b/hubconf.py @@ -9,6 +9,6 @@ from torchgeo.models import resnet18, resnet50, vit_small_patch16_224 -__all__ = ("resnet18", "resnet50", "vit_small_patch16_224") +__all__ = ('resnet18', 'resnet50', 'vit_small_patch16_224') -dependencies = ["timm"] +dependencies = ['timm'] diff --git a/pyproject.toml b/pyproject.toml index b8a5089835a..9707f472d8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -262,6 +262,7 @@ extend-include = ["*.ipynb"] fix = true [tool.ruff.format] +quote-style = "single" skip-magic-trailing-comma = true [tool.ruff.lint] diff --git a/tests/data/agb_live_woody_density/data.py b/tests/data/agb_live_woody_density/data.py index 115a9772fba..68512ba6ca1 100755 --- a/tests/data/agb_live_woody_density/data.py +++ b/tests/data/agb_live_woody_density/data.py @@ -15,24 +15,24 @@ base_file = { - "type": "FeatureCollection", - "name": "Aboveground_Live_Woody_Biomass_Density", - "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, - "features": [ + 'type': 'FeatureCollection', + 'name': 'Aboveground_Live_Woody_Biomass_Density', + 'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}}, + 'features': [ { - "type": "Feature", - "properties": { - "tile_id": "00N_000E", - "Mg_px_1_download": os.path.join( - "tests", "data", "agb_live_woody_density", "00N_000E.tif" + 'type': 'Feature', + 'properties': { + 'tile_id': '00N_000E', + 'Mg_px_1_download': os.path.join( + 'tests', 'data', 'agb_live_woody_density', '00N_000E.tif' ), - "ObjectId": 1, - "Shape__Area": 1245542622548.8701, - "Shape__Length": 4464169.7655813899, + 'ObjectId': 1, + 'Shape__Area': 1245542622548.8701, + 'Shape__Length': 4464169.7655813899, }, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [10.0, 0.0], [10.0, -10.0], [0.0, -10.0], [0.0, 0.0]] ], }, @@ -43,36 +43,36 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = 
SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 - if "float" in profile["dtype"]: - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + if 'float' in profile['dtype']: + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) else: Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) -if __name__ == "__main__": - base_file_name = "Aboveground_Live_Woody_Biomass_Density.geojson" +if __name__ == '__main__': + base_file_name = 'Aboveground_Live_Woody_Biomass_Density.geojson' if os.path.exists(base_file_name): os.remove(base_file_name) - with open(base_file_name, "w") as f: + with open(base_file_name, 'w') as f: json.dump(base_file, f) - for i in base_file["features"]: - filepath = os.path.basename(i["properties"]["Mg_px_1_download"]) - create_file(path=filepath, dtype="int32", num_channels=1) + for i in base_file['features']: + filepath = os.path.basename(i['properties']['Mg_px_1_download']) + create_file(path=filepath, dtype='int32', num_channels=1) diff --git a/tests/data/agrifieldnet/data.py b/tests/data/agrifieldnet/data.py index e0b4d0e256c..388c61f6d56 100644 --- a/tests/data/agrifieldnet/data.py +++ b/tests/data/agrifieldnet/data.py @@ -29,80 +29,80 @@ def generate_test_data(paths: str) -> str: np.random.seed(0) bands = ( - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', ) profile = { - "dtype": dtype, - "width": SIZE, - "height": SIZE, - "count": 1, - "crs": CRS.from_epsg(32644), - "transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0), + 'dtype': dtype, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(32644), + 'transform': Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0), } - source_dir = os.path.join(paths, "source") - train_mask_dir = os.path.join(paths, "train_labels") - test_field_dir = os.path.join(paths, "test_labels") + source_dir = os.path.join(paths, 'source') + train_mask_dir = os.path.join(paths, 'train_labels') + test_field_dir = os.path.join(paths, 'test_labels') os.makedirs(source_dir, exist_ok=True) os.makedirs(train_mask_dir, exist_ok=True) os.makedirs(test_field_dir, exist_ok=True) - source_unique_folder_ids = ["32407", "8641e", "a419f", "eac11", "ff450"] + source_unique_folder_ids = ['32407', '8641e', 'a419f', 'eac11', 'ff450'] train_folder_ids = source_unique_folder_ids[0:5] test_folder_ids = source_unique_folder_ids[3:5] for id in source_unique_folder_ids: directory = os.path.join( - source_dir, "ref_agrifieldnet_competition_v1_source_" + id + source_dir, 'ref_agrifieldnet_competition_v1_source_' + id ) os.makedirs(directory, exist_ok=True) for band in bands: train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype) path = os.path.join( - directory, 
f"ref_agrifieldnet_competition_v1_source_{id}_{band}_10m.tif" + directory, f'ref_agrifieldnet_competition_v1_source_{id}_{band}_10m.tif' ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(train_arr, 1) for id in train_folder_ids: train_mask_arr = np.random.randint(size=(SIZE, SIZE), low=0, high=6) path = os.path.join( - train_mask_dir, f"ref_agrifieldnet_competition_v1_labels_train_{id}.tif" + train_mask_dir, f'ref_agrifieldnet_competition_v1_labels_train_{id}.tif' ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(train_mask_arr, 1) train_field_arr = np.random.randint(20, size=(SIZE, SIZE), dtype=np.uint16) path = os.path.join( train_mask_dir, - f"ref_agrifieldnet_competition_v1_labels_train_{id}_field_ids.tif", + f'ref_agrifieldnet_competition_v1_labels_train_{id}_field_ids.tif', ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(train_field_arr, 1) for id in test_folder_ids: test_field_arr = np.random.randint(10, 30, size=(SIZE, SIZE), dtype=np.uint16) path = os.path.join( test_field_dir, - f"ref_agrifieldnet_competition_v1_labels_test_{id}_field_ids.tif", + f'ref_agrifieldnet_competition_v1_labels_test_{id}_field_ids.tif', ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(test_field_arr, 1) -if __name__ == "__main__": +if __name__ == '__main__': generate_test_data(os.getcwd()) diff --git a/tests/data/airphen/data.py b/tests/data/airphen/data.py index 40a5fd6f97f..50dc521420a 100755 --- a/tests/data/airphen/data.py +++ b/tests/data/airphen/data.py @@ -14,13 +14,13 @@ profile = { - "driver": "GTiff", - "dtype": "uint16", - "width": SIZE, - "height": SIZE, - "count": 6, - "crs": CRS.from_epsg(4326), - "transform": Affine( + 'driver': 'GTiff', + 'dtype': 'uint16', + 'width': SIZE, + 'height': SIZE, + 'count': 6, + 'crs': CRS.from_epsg(4326), + 'transform': Affine( 4.497249999999613e-07, 0.0, 12.567765446921205, @@ -31,9 +31,9 @@ } Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), dtype=profile['dtype'] ) -with rasterio.open("zoneA_B_R_NIR.tif", "w", **profile) as src: - for i in range(profile["count"]): +with rasterio.open('zoneA_B_R_NIR.tif', 'w', **profile) as src: + for i in range(profile['count']): src.write(Z, i + 1) diff --git a/tests/data/astergdem/data.py b/tests/data/astergdem/data.py index b041e70c9e4..c41855b3612 100755 --- a/tests/data/astergdem/data.py +++ b/tests/data/astergdem/data.py @@ -13,49 +13,49 @@ SIZE = 64 files = [ - {"image": "ASTGTMV003_N000000_dem.tif"}, - {"image": "ASTGTMV003_N000010_dem.tif"}, + {'image': 'ASTGTMV003_N000000_dem.tif'}, + {'image': 'ASTGTMV003_N000010_dem.tif'}, ] def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + 
profile['compress'] = 'lzw' + profile['predictor'] = 2 Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(1, SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z) -if __name__ == "__main__": - zipfilename = "astergdem.zip" +if __name__ == '__main__': + zipfilename = 'astergdem.zip' files_to_zip = [] for file_dict in files: - path = file_dict["image"] + path = file_dict['image'] # remove old data if os.path.exists(path): os.remove(path) # Create mask file - create_file(path, dtype="int32", num_channels=1) + create_file(path, dtype='int32', num_channels=1) files_to_zip.append(path) # Compress data - with zipfile.ZipFile(zipfilename, "w") as zip: + with zipfile.ZipFile(zipfilename, 'w') as zip: for file in files_to_zip: zip.write(file, arcname=file) # Compute checksums - with open(zipfilename, "rb") as f: + with open(zipfilename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{zipfilename}: {md5}") + print(f'{zipfilename}: {md5}') diff --git a/tests/data/biomassters/data.py b/tests/data/biomassters/data.py index 648ae9b94f4..5ba79905684 100644 --- a/tests/data/biomassters/data.py +++ b/tests/data/biomassters/data.py @@ -11,31 +11,31 @@ import numpy as np import rasterio -metadata_train = "The_BioMassters_-_features_metadata.csv.csv" +metadata_train = 'The_BioMassters_-_features_metadata.csv.csv' csv_columns = [ - "filename", - "chip_id", - "satellite", - "split", - "month", - "size", - "cksum", - "s3path_us", - "s3path_eu", - "s3path_as", - "corresponding_agbm", + 'filename', + 'chip_id', + 'satellite', + 'split', + 'month', + 'size', + 'cksum', + 's3path_us', + 's3path_eu', + 's3path_as', + 'corresponding_agbm', ] -targets = "train_agbm.zip" +targets = 'train_agbm.zip' -splits = ["train", "test"] +splits = ['train', 'test'] -sample_ids = ["0003d2eb", "000aa810"] +sample_ids = ['0003d2eb', '000aa810'] -months = ["September", "October", "November"] +months = ['September', 'October', 'November'] -satellite = ["S1", "S2"] +satellite = ['S1', 'S2'] SIZE = 32 @@ -49,43 +49,43 @@ def create_tif_file(path: str, num_channels: int, dtype: str) -> None: dtype: uint16 for image data and float 32 for target """ profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 - - if "float" in profile["dtype"]: - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 + + if 'float' in profile['dtype']: + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) else: Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) # 
filename,chip_id,satellite,split,month,size,cksum,s3path_us,s3path_eu,s3path_as,corresponding_agbm -if __name__ == "__main__": +if __name__ == '__main__': csv_rows = [] for split in splits: - os.makedirs(f"{split}_features", exist_ok=True) - if split == "train": - os.makedirs("train_agbm", exist_ok=True) + os.makedirs(f'{split}_features', exist_ok=True) + if split == 'train': + os.makedirs('train_agbm', exist_ok=True) for id in sample_ids: for sat in satellite: - path = id + "_" + str(sat) + path = id + '_' + str(sat) for idx, month in enumerate(months): # S2 data is not present for every month - if sat == "S2" and idx == 1: + if sat == 'S2' and idx == 1: continue - file_path = path + "_" + f"{idx:02d}" + ".tif" + file_path = path + '_' + f'{idx:02d}' + '.tif' csv_rows.append( [ @@ -94,44 +94,44 @@ def create_tif_file(path: str, num_channels: int, dtype: str) -> None: sat, split, month, - "0", - "0", - "path", - "path", - "path", - id + "_agbm.tif", + '0', + '0', + 'path', + 'path', + 'path', + id + '_agbm.tif', ] ) # file path to save - file_path = os.path.join(f"{split}_features", file_path) + file_path = os.path.join(f'{split}_features', file_path) - if sat == "S1": - create_tif_file(file_path, num_channels=4, dtype="uint16") + if sat == 'S1': + create_tif_file(file_path, num_channels=4, dtype='uint16') else: - create_tif_file(file_path, num_channels=11, dtype="uint16") + create_tif_file(file_path, num_channels=11, dtype='uint16') # create target data one per id - if split == "train": + if split == 'train': create_tif_file( - os.path.join(f"{split}_agbm", id + "_agbm.tif"), + os.path.join(f'{split}_agbm', id + '_agbm.tif'), num_channels=1, - dtype="float32", + dtype='float32', ) # write out metadata - with open(metadata_train, "w") as csv_file: + with open(metadata_train, 'w') as csv_file: wr = csv.writer(csv_file) wr.writerow(csv_columns) for row in csv_rows: wr.writerow(row) # zip up feature and target folders - zip_dirs = ["train_features", "test_features", "train_agbm"] + zip_dirs = ['train_features', 'test_features', 'train_agbm'] for dir in zip_dirs: - shutil.make_archive(dir, "zip", dir) + shutil.make_archive(dir, 'zip', dir) # Compute checksums - with open(dir + ".zip", "rb") as f: + with open(dir + '.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{dir}: {md5}") + print(f'{dir}: {md5}') diff --git a/tests/data/cbf/data.py b/tests/data/cbf/data.py index 28563ce3968..6dc5c457f47 100755 --- a/tests/data/cbf/data.py +++ b/tests/data/cbf/data.py @@ -11,18 +11,18 @@ def create_geojson(): geojson = { - "type": "FeatureCollection", - "crs": { - "type": "name", - "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}, + 'type': 'FeatureCollection', + 'crs': { + 'type': 'name', + 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}, }, - "features": [ + 'features': [ { - "type": "Feature", - "properties": {}, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'type': 'Feature', + 'properties': {}, + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] ], }, @@ -32,22 +32,22 @@ def create_geojson(): return geojson -if __name__ == "__main__": - filename = "Alberta.zip" +if __name__ == '__main__': + filename = 'Alberta.zip' geojson = create_geojson() - with open(filename.replace(".zip", ".geojson"), "w") as f: + with open(filename.replace('.zip', '.geojson'), 'w') as f: json.dump(geojson, f) # compress single file directly with no directory shutil.make_archive( - filename.replace(".zip", ""), - 
"zip", + filename.replace('.zip', ''), + 'zip', os.getcwd(), - filename.replace(".zip", ".geojson"), + filename.replace('.zip', '.geojson'), ) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/cdl/data.py b/tests/data/cdl/data.py index 8d8a443b138..8b29a4ebf1d 100755 --- a/tests/data/cdl/data.py +++ b/tests/data/cdl/data.py @@ -20,15 +20,15 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:32616" - profile["transform"] = Affine(30, 0.0, 399960.0, 0.0, -30, 4500000.0) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:32616' + profile['transform'] = Affine(30, 0.0, 399960.0, 0.0, -30, 4500000.0) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 cmap = { 0: (0, 0, 0, 0), 1: (255, 211, 0, 255), @@ -43,20 +43,20 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: Z = np.random.randint(size=(SIZE, SIZE), low=0, high=8) - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) src.write_colormap(1, cmap) -directories = ["2023_30m_cdls", "2022_30m_cdls"] -raster_extensions = [".tif", ".tif.ovr"] +directories = ['2023_30m_cdls', '2022_30m_cdls'] +raster_extensions = ['.tif', '.tif.ovr'] -if __name__ == "__main__": +if __name__ == '__main__': for dir in directories: - filename = dir + ".zip" + filename = dir + '.zip' # Remove old data if os.path.isdir(dir): @@ -66,15 +66,15 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: for e in raster_extensions: create_file( - os.path.join(dir, filename.replace(".zip", e)), - dtype="int8", + os.path.join(dir, filename.replace('.zip', e)), + dtype='int8', num_channels=1, ) # Compress data - shutil.make_archive(filename.replace(".zip", ""), "zip", ".", dir) + shutil.make_archive(filename.replace('.zip', ''), 'zip', '.', dir) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/chabud/data.py b/tests/data/chabud/data.py index f6595c68ecc..b8b78e7c8be 100644 --- a/tests/data/chabud/data.py +++ b/tests/data/chabud/data.py @@ -19,28 +19,28 @@ np.random.seed(0) -filename = "train_eval.hdf5" -fold_mapping = {"train": [1, 2, 3, 4], "val": [0]} +filename = 'train_eval.hdf5' +fold_mapping = {'train': [1, 2, 3, 4], 'val': [0]} uris = [ - "feb08801-64b1-4d11-a3fc-0efaad1f4274_0", - "e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1", - "9fc8c1f4-1858-47c3-953e-1dc8b179a", - "3a1358a2-6155-445a-a269-13bebd9741a8_0", - "2f8e659c-f457-4527-a57f-bffc3bbe0baa_0", - "299ee670-19b1-4a76-bef3-34fd55580711_1", - "05cfef86-3e27-42be-a0cb-a61fe2f89e40_0", - "0328d12a-4ad8-4504-8ac5-70089db10b4e_1", + 'feb08801-64b1-4d11-a3fc-0efaad1f4274_0', + 'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1', + '9fc8c1f4-1858-47c3-953e-1dc8b179a', + '3a1358a2-6155-445a-a269-13bebd9741a8_0', + '2f8e659c-f457-4527-a57f-bffc3bbe0baa_0', + 
'299ee670-19b1-4a76-bef3-34fd55580711_1', + '05cfef86-3e27-42be-a0cb-a61fe2f89e40_0', + '0328d12a-4ad8-4504-8ac5-70089db10b4e_1', ] folds = [ - random.sample(fold_mapping["train"], 1)[0], - random.sample(fold_mapping["train"], 1)[0], - random.sample(fold_mapping["train"], 1)[0], - random.sample(fold_mapping["train"], 1)[0], - random.sample(fold_mapping["val"], 1)[0], - random.sample(fold_mapping["val"], 1)[0], - random.sample(fold_mapping["val"], 1)[0], - random.sample(fold_mapping["val"], 1)[0], + random.sample(fold_mapping['train'], 1)[0], + random.sample(fold_mapping['train'], 1)[0], + random.sample(fold_mapping['train'], 1)[0], + random.sample(fold_mapping['train'], 1)[0], + random.sample(fold_mapping['val'], 1)[0], + random.sample(fold_mapping['val'], 1)[0], + random.sample(fold_mapping['val'], 1)[0], + random.sample(fold_mapping['val'], 1)[0], ] # Remove old data @@ -54,16 +54,16 @@ data = data.astype(np.uint16) gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16) -with h5py.File(filename, "w") as f: +with h5py.File(filename, 'w') as f: for uri, fold in zip(uris, folds): sample = f.create_group(uri) - sample.attrs.create(name="fold", data=np.int64(fold)) + sample.attrs.create(name='fold', data=np.int64(fold)) sample.create_dataset - sample.create_dataset("pre_fire", data=data) - sample.create_dataset("post_fire", data=data) - sample.create_dataset("mask", data=gt) + sample.create_dataset('pre_fire', data=data) + sample.create_dataset('post_fire', data=data) + sample.create_dataset('mask', data=gt) # Compute checksums -with open(filename, "rb") as f: +with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"md5: {md5}") + print(f'md5: {md5}') diff --git a/tests/data/chesapeake/BAYWIDE/data.py b/tests/data/chesapeake/BAYWIDE/data.py index 3128f975157..dbf807ce0fe 100755 --- a/tests/data/chesapeake/BAYWIDE/data.py +++ b/tests/data/chesapeake/BAYWIDE/data.py @@ -17,7 +17,7 @@ np.random.seed(0) -filename = "Baywide_13Class_20132014" +filename = 'Baywide_13Class_20132014' wkt = """ PROJCS["USA_Contiguous_Albers_Equal_Area_Conic_USGS_version", GEOGCS["NAD83", @@ -60,22 +60,22 @@ meta = { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "width": SIZE, - "height": SIZE, - "count": 1, - "crs": CRS.from_wkt(wkt), - "transform": Affine(1.0, 0.0, 1303555.0000000005, 0.0, -1.0, 2535064.999999998), + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_wkt(wkt), + 'transform': Affine(1.0, 0.0, 1303555.0000000005, 0.0, -1.0, 2535064.999999998), } # Remove old data -if os.path.exists(f"{filename}.tif"): - os.remove(f"{filename}.tif") +if os.path.exists(f'{filename}.tif'): + os.remove(f'{filename}.tif') # Create raster file -with rasterio.open(f"{filename}.tif", "w", **meta) as f: +with rasterio.open(f'{filename}.tif', 'w', **meta) as f: data = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE), dtype=np.uint8) f.write(data, 1) f.write_colormap(1, cmap) @@ -84,12 +84,12 @@ # 7z required to create a zip file using the proprietary DEFLATE64 compression algorithm # https://github.com/brianhelba/zipfile-deflate64/issues/19#issuecomment-1006077294 subprocess.run( - ["7z", "a", f"{filename}.zip", "-mm=DEFLATE64", f"{filename}.tif"], + ['7z', 'a', f'{filename}.zip', '-mm=DEFLATE64', f'{filename}.tif'], capture_output=True, check=True, ) # Compute checksums -with open(f"{filename}.zip", "rb") as f: +with open(f'{filename}.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() 
print(repr(md5)) diff --git a/tests/data/cms_mangrove_canopy/data.py b/tests/data/cms_mangrove_canopy/data.py index de2a19f11ac..4465c75e173 100755 --- a/tests/data/cms_mangrove_canopy/data.py +++ b/tests/data/cms_mangrove_canopy/data.py @@ -18,51 +18,51 @@ files = [ - {"image": "Mangrove_agb_Angola.tif"}, - {"image": "Mangrove_hba95_Angola.tif"}, - {"image": "Mangrove_hmax95_Angola.tif"}, + {'image': 'Mangrove_agb_Angola.tif'}, + {'image': 'Mangrove_hba95_Angola.tif'}, + {'image': 'Mangrove_hmax95_Angola.tif'}, ] def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(1, SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z) -if __name__ == "__main__": - directory = "CMS_Global_Map_Mangrove_Canopy_1665" +if __name__ == '__main__': + directory = 'CMS_Global_Map_Mangrove_Canopy_1665' # Remove old data if os.path.isdir(directory): shutil.rmtree(directory) - os.makedirs(os.path.join(directory, "data"), exist_ok=True) + os.makedirs(os.path.join(directory, 'data'), exist_ok=True) for file_dict in files: # Create mask file - path = file_dict["image"] + path = file_dict['image'] create_file( - os.path.join(directory, "data", path), dtype="int32", num_channels=1 + os.path.join(directory, 'data', path), dtype='int32', num_channels=1 ) # Compress data - shutil.make_archive(directory.replace(".zip", ""), "zip", ".", directory) + shutil.make_archive(directory.replace('.zip', ''), 'zip', '.', directory) # Compute checksums - with open(directory + ".zip", "rb") as f: + with open(directory + '.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{directory}: {md5}") + print(f'{directory}: {md5}') diff --git a/tests/data/cowc_counting/data.py b/tests/data/cowc_counting/data.py index b612f926cec..66f5ed42e4f 100755 --- a/tests/data/cowc_counting/data.py +++ b/tests/data/cowc_counting/data.py @@ -15,71 +15,71 @@ SIZE = 64 # image width/height STOP = 20 # range of values for labels -PREFIX = "Counting" -SUFFIX = "64_class" +PREFIX = 'Counting' +SUFFIX = '64_class' random.seed(0) sites = [ - "Toronto_ISPRS", - "Selwyn_LINZ", - "Potsdam_ISPRS", - "Vaihingen_ISPRS", - "Columbus_CSUAV_AFRL", - "Utah_AGRC", + 'Toronto_ISPRS', + 'Selwyn_LINZ', + 'Potsdam_ISPRS', + 'Vaihingen_ISPRS', + 'Columbus_CSUAV_AFRL', + 'Utah_AGRC', ] # Remove old data -for filename in glob.glob("COWC_*"): +for filename in glob.glob('COWC_*'): os.remove(filename) for site in sites: if os.path.exists(site): shutil.rmtree(site) i = 1 -data_list = {"train": [], "test": []} +data_list = {'train': [], 'test': []} image_md5s = [] for site in sites: # Create images - for split in ["test", "train", "train"]: + for split in ['test', 'train', 'train']: directory = 
os.path.join(site, split) os.makedirs(directory, exist_ok=True) - filename = os.path.join(directory, f"fake_{i:02}.png") + filename = os.path.join(directory, f'fake_{i:02}.png') - img = Image.new("RGB", (SIZE, SIZE)) + img = Image.new('RGB', (SIZE, SIZE)) img.save(filename) data_list[split].append((filename, random.randrange(STOP))) - if split == "train": + if split == 'train': i += 1 # Compress images - filename = f"COWC_{PREFIX}_{site}.tbz" - bad_filename = shutil.make_archive(filename.replace(".tbz", ""), "bztar", ".", site) + filename = f'COWC_{PREFIX}_{site}.tbz' + bad_filename = shutil.make_archive(filename.replace('.tbz', ''), 'bztar', '.', site) os.rename(bad_filename, filename) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: image_md5s.append(hashlib.md5(f.read()).hexdigest()) label_md5s = [] -for split in ["train", "test"]: +for split in ['train', 'test']: # Create labels - filename = f"COWC_{split}_list_{SUFFIX}.txt" - with open(filename, "w", newline="") as csvfile: - csvwriter = csv.writer(csvfile, delimiter=" ") + filename = f'COWC_{split}_list_{SUFFIX}.txt' + with open(filename, 'w', newline='') as csvfile: + csvwriter = csv.writer(csvfile, delimiter=' ') csvwriter.writerows(data_list[split]) # Compress labels - with open(filename, "rb") as src: - with bz2.open(filename + ".bz2", "wb") as dst: + with open(filename, 'rb') as src: + with bz2.open(filename + '.bz2', 'wb') as dst: dst.write(src.read()) # Compute checksums - with open(filename + ".bz2", "rb") as f: + with open(filename + '.bz2', 'rb') as f: label_md5s.append(hashlib.md5(f.read()).hexdigest()) md5s = label_md5s + image_md5s for md5 in md5s: - print(repr(md5) + ",") + print(repr(md5) + ',') diff --git a/tests/data/cowc_detection/data.py b/tests/data/cowc_detection/data.py index b5b91c6a7df..ea8e3db757e 100755 --- a/tests/data/cowc_detection/data.py +++ b/tests/data/cowc_detection/data.py @@ -15,71 +15,71 @@ SIZE = 64 # image width/height STOP = 2 # range of values for labels -PREFIX = "Detection" -SUFFIX = "detection" +PREFIX = 'Detection' +SUFFIX = 'detection' random.seed(0) sites = [ - "Toronto_ISPRS", - "Selwyn_LINZ", - "Potsdam_ISPRS", - "Vaihingen_ISPRS", - "Columbus_CSUAV_AFRL", - "Utah_AGRC", + 'Toronto_ISPRS', + 'Selwyn_LINZ', + 'Potsdam_ISPRS', + 'Vaihingen_ISPRS', + 'Columbus_CSUAV_AFRL', + 'Utah_AGRC', ] # Remove old data -for filename in glob.glob("COWC_*"): +for filename in glob.glob('COWC_*'): os.remove(filename) for site in sites: if os.path.exists(site): shutil.rmtree(site) i = 1 -data_list = {"train": [], "test": []} +data_list = {'train': [], 'test': []} image_md5s = [] for site in sites: # Create images - for split in ["test", "train", "train"]: + for split in ['test', 'train', 'train']: directory = os.path.join(site, split) os.makedirs(directory, exist_ok=True) - filename = os.path.join(directory, f"fake_{i:02}.png") + filename = os.path.join(directory, f'fake_{i:02}.png') - img = Image.new("RGB", (SIZE, SIZE)) + img = Image.new('RGB', (SIZE, SIZE)) img.save(filename) data_list[split].append((filename, random.randrange(STOP))) - if split == "train": + if split == 'train': i += 1 # Compress images - filename = f"COWC_{PREFIX}_{site}.tbz" - bad_filename = shutil.make_archive(filename.replace(".tbz", ""), "bztar", ".", site) + filename = f'COWC_{PREFIX}_{site}.tbz' + bad_filename = shutil.make_archive(filename.replace('.tbz', ''), 'bztar', '.', site) os.rename(bad_filename, filename) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 
'rb') as f: image_md5s.append(hashlib.md5(f.read()).hexdigest()) label_md5s = [] -for split in ["train", "test"]: +for split in ['train', 'test']: # Create labels - filename = f"COWC_{split}_list_{SUFFIX}.txt" - with open(filename, "w", newline="") as csvfile: - csvwriter = csv.writer(csvfile, delimiter=" ") + filename = f'COWC_{split}_list_{SUFFIX}.txt' + with open(filename, 'w', newline='') as csvfile: + csvwriter = csv.writer(csvfile, delimiter=' ') csvwriter.writerows(data_list[split]) # Compress labels - with open(filename, "rb") as src: - with bz2.open(filename + ".bz2", "wb") as dst: + with open(filename, 'rb') as src: + with bz2.open(filename + '.bz2', 'wb') as dst: dst.write(src.read()) # Compute checksums - with open(filename + ".bz2", "rb") as f: + with open(filename + '.bz2', 'rb') as f: label_md5s.append(hashlib.md5(f.read()).hexdigest()) md5s = label_md5s + image_md5s for md5 in md5s: - print(repr(md5) + ",") + print(repr(md5) + ',') diff --git a/tests/data/cropharvest/data.py b/tests/data/cropharvest/data.py index 4b882717c83..5bf85d21f84 100755 --- a/tests/data/cropharvest/data.py +++ b/tests/data/cropharvest/data.py @@ -16,90 +16,90 @@ np.random.seed(0) PATHS = [ - os.path.join("cropharvest", "features", "arrays", "0_TestDataset1.h5"), - os.path.join("cropharvest", "features", "arrays", "1_TestDataset1.h5"), - os.path.join("cropharvest", "features", "arrays", "2_TestDataset1.h5"), - os.path.join("cropharvest", "features", "arrays", "0_TestDataset2.h5"), - os.path.join("cropharvest", "features", "arrays", "1_TestDataset2.h5"), + os.path.join('cropharvest', 'features', 'arrays', '0_TestDataset1.h5'), + os.path.join('cropharvest', 'features', 'arrays', '1_TestDataset1.h5'), + os.path.join('cropharvest', 'features', 'arrays', '2_TestDataset1.h5'), + os.path.join('cropharvest', 'features', 'arrays', '0_TestDataset2.h5'), + os.path.join('cropharvest', 'features', 'arrays', '1_TestDataset2.h5'), ] def create_geojson(): geojson = { - "type": "FeatureCollection", - "crs": {}, - "features": [ + 'type': 'FeatureCollection', + 'crs': {}, + 'features': [ { - "type": "Feature", - "properties": { - "dataset": "TestDataset1", - "index": 0, - "is_crop": 1, - "label": "soybean", + 'type': 'Feature', + 'properties': { + 'dataset': 'TestDataset1', + 'index': 0, + 'is_crop': 1, + 'label': 'soybean', }, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] ], }, }, { - "type": "Feature", - "properties": { - "dataset": "TestDataset1", - "index": 0, - "is_crop": 1, - "label": "alfalfa", + 'type': 'Feature', + 'properties': { + 'dataset': 'TestDataset1', + 'index': 0, + 'is_crop': 1, + 'label': 'alfalfa', }, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] ], }, }, { - "type": "Feature", - "properties": { - "dataset": "TestDataset1", - "index": 1, - "is_crop": 1, - "label": None, + 'type': 'Feature', + 'properties': { + 'dataset': 'TestDataset1', + 'index': 1, + 'is_crop': 1, + 'label': None, }, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] ], }, }, { - "type": "Feature", - "properties": { - "dataset": "TestDataset2", - "index": 2, - "is_crop": 1, - "label": "maize", + 'type': 'Feature', + 'properties': { + 'dataset': 'TestDataset2', 
+ 'index': 2, + 'is_crop': 1, + 'label': 'maize', }, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] ], }, }, { - "type": "Feature", - "properties": { - "dataset": "TestDataset2", - "index": 1, - "is_crop": 0, - "label": None, + 'type': 'Feature', + 'properties': { + 'dataset': 'TestDataset2', + 'index': 1, + 'is_crop': 0, + 'label': None, }, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] ], }, @@ -111,28 +111,28 @@ def create_geojson(): def create_file(path: str) -> None: Z = np.random.randint(4000, size=(12, 18), dtype=np.int64) - with h5py.File(path, "w") as f: - f.create_dataset("array", data=Z) + with h5py.File(path, 'w') as f: + f.create_dataset('array', data=Z) -if __name__ == "__main__": - directory = "cropharvest" +if __name__ == '__main__': + directory = 'cropharvest' # remove old data to_remove = [ - os.path.join(directory, "features"), - os.path.join(directory, "features.tar.gz"), - os.path.join(directory, "labels.geojson"), + os.path.join(directory, 'features'), + os.path.join(directory, 'features.tar.gz'), + os.path.join(directory, 'labels.geojson'), ] for path in to_remove: if os.path.isdir(path): shutil.rmtree(path) - label_path = os.path.join(directory, "labels.geojson") + label_path = os.path.join(directory, 'labels.geojson') geojson = create_geojson() os.makedirs(os.path.dirname(label_path), exist_ok=True) - with open(label_path, "w") as f: + with open(label_path, 'w') as f: json.dump(geojson, f) for path in PATHS: @@ -140,14 +140,14 @@ def create_file(path: str) -> None: create_file(path) # compress data - source_dir = os.path.join(directory, "features") - shutil.make_archive(source_dir, "gztar", directory, "features") + source_dir = os.path.join(directory, 'features') + shutil.make_archive(source_dir, 'gztar', directory, 'features') # compute checksum - with open(label_path, "rb") as f: + with open(label_path, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{label_path}: {md5}") + print(f'{label_path}: {md5}') - with open(os.path.join(directory, "features.tar.gz"), "rb") as f: + with open(os.path.join(directory, 'features.tar.gz'), 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"zipped features: {md5}") + print(f'zipped features: {md5}') diff --git a/tests/data/deepglobelandcover/data.py b/tests/data/deepglobelandcover/data.py index 1c8778bf8d4..44185cd04a1 100755 --- a/tests/data/deepglobelandcover/data.py +++ b/tests/data/deepglobelandcover/data.py @@ -24,12 +24,12 @@ def generate_test_data(root: str, n_samples: int = 3) -> str: dtype = np.uint8 size = 2 - folder_path = os.path.join(root, "data") + folder_path = os.path.join(root, 'data') - train_img_dir = os.path.join(folder_path, "data", "training_data", "images") - train_mask_dir = os.path.join(folder_path, "data", "training_data", "masks") - test_img_dir = os.path.join(folder_path, "data", "test_data", "images") - test_mask_dir = os.path.join(folder_path, "data", "test_data", "masks") + train_img_dir = os.path.join(folder_path, 'data', 'training_data', 'images') + train_mask_dir = os.path.join(folder_path, 'data', 'training_data', 'masks') + test_img_dir = os.path.join(folder_path, 'data', 'test_data', 'images') + test_mask_dir = os.path.join(folder_path, 'data', 'test_data', 'masks') os.makedirs(train_img_dir, exist_ok=True) 
os.makedirs(train_mask_dir, exist_ok=True) @@ -46,26 +46,26 @@ def generate_test_data(root: str, n_samples: int = 3) -> str: dtype_max = np.iinfo(dtype).max train_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype) train_img = Image.fromarray(train_arr) - train_img.save(os.path.join(train_img_dir, str(train_id) + "_sat.jpg")) + train_img.save(os.path.join(train_img_dir, str(train_id) + '_sat.jpg')) test_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype) test_img = Image.fromarray(test_arr) - test_img.save(os.path.join(test_img_dir, str(test_id) + "_sat.jpg")) + test_img.save(os.path.join(test_img_dir, str(test_id) + '_sat.jpg')) train_mask_arr = np.full((size, size, 3), (0, 255, 255), dtype=dtype) train_mask_img = Image.fromarray(train_mask_arr) - train_mask_img.save(os.path.join(train_mask_dir, str(train_id) + "_mask.png")) + train_mask_img.save(os.path.join(train_mask_dir, str(train_id) + '_mask.png')) test_mask_arr = np.full((size, size, 3), (255, 0, 255), dtype=dtype) test_mask_img = Image.fromarray(test_mask_arr) - test_mask_img.save(os.path.join(test_mask_dir, str(test_id) + "_mask.png")) + test_mask_img.save(os.path.join(test_mask_dir, str(test_id) + '_mask.png')) # Create archive - shutil.make_archive(folder_path, "zip", folder_path) + shutil.make_archive(folder_path, 'zip', folder_path) shutil.rmtree(folder_path) - return calculate_md5(f"{folder_path}.zip") + return calculate_md5(f'{folder_path}.zip') -if __name__ == "__main__": +if __name__ == '__main__': md5_hash = generate_test_data(os.getcwd(), 3) - print(md5_hash + "\n") + print(md5_hash + '\n') diff --git a/tests/data/dfc2022/data.py b/tests/data/dfc2022/data.py index 1dea719250e..39d41f5d945 100755 --- a/tests/data/dfc2022/data.py +++ b/tests/data/dfc2022/data.py @@ -19,66 +19,66 @@ train_set = [ { - "image": "labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif", # noqa: E501 - "dem": "labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif", # noqa: E501 - "target": "labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif", # noqa: E501 + 'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', # noqa: E501 + 'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', # noqa: E501 }, { - "image": "labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif", # noqa: E501 - "dem": "labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif", # noqa: E501 - "target": "labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif", # noqa: E501 + 'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', # noqa: E501 + 'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', # noqa: E501 }, ] unlabeled_set = [ { - "image": "unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif", # noqa: E501 - "dem": "unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif", # noqa: E501 + 'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', # noqa: E501 + 'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 }, { 
- "image": "unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif", # noqa: E501 - "dem": "unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif", # noqa: E501 + 'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', # noqa: E501 + 'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 }, ] val_set = [ { - "image": "val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif", # noqa: E501 - "dem": "val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif", # noqa: E501 + 'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', # noqa: E501 + 'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 }, { - "image": "val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif", # noqa: E501 - "dem": "val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif", # noqa: E501 + 'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', # noqa: E501 + 'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 }, ] def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - - if "float" in profile["dtype"]: - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + + if 'float' in profile['dtype']: + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) else: Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) -if __name__ == "__main__": +if __name__ == '__main__': for split in DFC2022.metadata: - directory = DFC2022.metadata[split]["directory"] - filename = DFC2022.metadata[split]["filename"] + directory = DFC2022.metadata[split]['directory'] + filename = DFC2022.metadata[split]['filename'] # Remove old data if os.path.isdir(directory): @@ -86,34 +86,34 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: if os.path.exists(filename): os.remove(filename) - if split == "train": + if split == 'train': files = train_set - elif split == "train-unlabeled": + elif split == 'train-unlabeled': files = unlabeled_set else: files = val_set for file_dict in files: # Create image file - path = file_dict["image"] + path = file_dict['image'] os.makedirs(os.path.dirname(path), exist_ok=True) - create_file(path, dtype="uint8", num_channels=3) + create_file(path, dtype='uint8', num_channels=3) # Create DEM file - path = file_dict["dem"] + path = file_dict['dem'] os.makedirs(os.path.dirname(path), exist_ok=True) - create_file(path, dtype="float32", num_channels=1) + create_file(path, dtype='float32', num_channels=1) # Create mask file - if split == "train": - path = file_dict["target"] + 
if split == 'train': + path = file_dict['target'] os.makedirs(os.path.dirname(path), exist_ok=True) - create_file(path, dtype="uint8", num_channels=1) + create_file(path, dtype='uint8', num_channels=1) # Compress data - shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory) + shutil.make_archive(filename.replace('.zip', ''), 'zip', '.', directory) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/eddmaps/data.py b/tests/data/eddmaps/data.py index 5af198e4738..4ebf094aba0 100755 --- a/tests/data/eddmaps/data.py +++ b/tests/data/eddmaps/data.py @@ -5,92 +5,92 @@ import pandas as pd -filename = "mappings.csv" +filename = 'mappings.csv' size = 3 data = { - "gbifID": [""] * size, - "decimalLatitude": [41.881832] * size, - "decimalLongitude": [""] + [-87.623177] * (size - 1), - "objectid": [""] * size, - "reporter": [""] * size, - "RecOwner": [""] * size, - "SciName": ["Homo sapiens"] * size, - "ComName": ["human"] * size, - "Nativity": ["Native"] * size, - "OccStatus": ["Detected"] * size, - "Status": ["Positive"] * size, - "ObsDate": ["", "", "05-07-22"], - "DateEnt": ["05-07-22"] * size, - "DateUp": ["05-07-22"] * size, - "Location": ["Chicago, Illinois, United States"] * size, - "Latitude": [41.881832] * size, - "Longitude": [""] + [-87.623177] * (size - 1), - "Datum": ["WGS84"] * size, - "Method": [""] * size, - "CoordAcc": [""] * size, - "DataType": [""] * size, - "Centroid": [""] * size, - "Abundance": [""] * size, - "InfestAcre": [""] * size, - "GrossAcre": [""] * size, - "Percentcov": [""] * size, - "Density": [""] * size, - "Quantity": [""] * size, - "QuantityU": [""] * size, - "APPXQuant": [""] * size, - "NumCollect": [""] * size, - "Smallest": [""] * size, - "Largest": [""] * size, - "Incidence": [""] * size, - "Severity": [""] * size, - "Host": [""] * size, - "Host_Name": [""] * size, - "HostPheno": [""] * size, - "HostDamage": [""] * size, - "ManageStat": ["Unknown"] * size, - "PopStat": [""] * size, - "Habitat": [""] * size, - "LocalOwner": [""] * size, - "Site": [""] * size, - "RecBasis": [""] * size, - "Museum": [""] * size, - "MuseumRec": [""] * size, - "Voucher": [""] * size, - "ObsIDer": [""] * size, - "CollectTme": [""] * size, - "UUID": [""] * size, - "OrgSrcID": [""] * size, - "OrigName": ["Homo sapiens"] * size, - "RecSrcTyp": ["Bulk Data"] * size, - "Surveyor": [""] * size, - "DateAcc": [""] * size, - "VisitType": [""] * size, - "DataMthd": [""] * size, - "TrapType": [""] * size, - "NumTraps": [""] * size, - "TargetName": [""] * size, - "TargetCnt": [""] * size, - "TargetRnge": [""] * size, - "Phenology": [""] * size, - "LifeStatus": [""] * size, - "Sex": [""] * size, - "PID": [""] * size, - "WaterName": [""] * size, - "WaterType": [""] * size, - "Substrate": [""] * size, - "TreatArea": [""] * size, - "PlantTreat": [""] * size, - "TreatComm": [""] * size, - "Reference": [""] * size, - "Locality": [""] * size, - "Comments": [""] * size, - "ReviewDate": ["05-07-22"] * size, - "Reviewer": ["Charles Darwin"] * size, - "VerifyMthd": ["Bulk Verified"] * size, - "Verified": ["Verified"] * size, - "IDCred": ["Credible"] * size, - "ReviewComm": [""] * size, + 'gbifID': [''] * size, + 'decimalLatitude': [41.881832] * size, + 'decimalLongitude': [''] + [-87.623177] * (size - 1), + 'objectid': [''] * size, + 'reporter': [''] * size, + 'RecOwner': [''] * size, + 'SciName': ['Homo sapiens'] * size, + 
'ComName': ['human'] * size, + 'Nativity': ['Native'] * size, + 'OccStatus': ['Detected'] * size, + 'Status': ['Positive'] * size, + 'ObsDate': ['', '', '05-07-22'], + 'DateEnt': ['05-07-22'] * size, + 'DateUp': ['05-07-22'] * size, + 'Location': ['Chicago, Illinois, United States'] * size, + 'Latitude': [41.881832] * size, + 'Longitude': [''] + [-87.623177] * (size - 1), + 'Datum': ['WGS84'] * size, + 'Method': [''] * size, + 'CoordAcc': [''] * size, + 'DataType': [''] * size, + 'Centroid': [''] * size, + 'Abundance': [''] * size, + 'InfestAcre': [''] * size, + 'GrossAcre': [''] * size, + 'Percentcov': [''] * size, + 'Density': [''] * size, + 'Quantity': [''] * size, + 'QuantityU': [''] * size, + 'APPXQuant': [''] * size, + 'NumCollect': [''] * size, + 'Smallest': [''] * size, + 'Largest': [''] * size, + 'Incidence': [''] * size, + 'Severity': [''] * size, + 'Host': [''] * size, + 'Host_Name': [''] * size, + 'HostPheno': [''] * size, + 'HostDamage': [''] * size, + 'ManageStat': ['Unknown'] * size, + 'PopStat': [''] * size, + 'Habitat': [''] * size, + 'LocalOwner': [''] * size, + 'Site': [''] * size, + 'RecBasis': [''] * size, + 'Museum': [''] * size, + 'MuseumRec': [''] * size, + 'Voucher': [''] * size, + 'ObsIDer': [''] * size, + 'CollectTme': [''] * size, + 'UUID': [''] * size, + 'OrgSrcID': [''] * size, + 'OrigName': ['Homo sapiens'] * size, + 'RecSrcTyp': ['Bulk Data'] * size, + 'Surveyor': [''] * size, + 'DateAcc': [''] * size, + 'VisitType': [''] * size, + 'DataMthd': [''] * size, + 'TrapType': [''] * size, + 'NumTraps': [''] * size, + 'TargetName': [''] * size, + 'TargetCnt': [''] * size, + 'TargetRnge': [''] * size, + 'Phenology': [''] * size, + 'LifeStatus': [''] * size, + 'Sex': [''] * size, + 'PID': [''] * size, + 'WaterName': [''] * size, + 'WaterType': [''] * size, + 'Substrate': [''] * size, + 'TreatArea': [''] * size, + 'PlantTreat': [''] * size, + 'TreatComm': [''] * size, + 'Reference': [''] * size, + 'Locality': [''] * size, + 'Comments': [''] * size, + 'ReviewDate': ['05-07-22'] * size, + 'Reviewer': ['Charles Darwin'] * size, + 'VerifyMthd': ['Bulk Verified'] * size, + 'Verified': ['Verified'] * size, + 'IDCred': ['Credible'] * size, + 'ReviewComm': [''] * size, } df = pd.DataFrame(data) diff --git a/tests/data/enviroatlas/data.py b/tests/data/enviroatlas/data.py index f2816e193a7..c694a315680 100755 --- a/tests/data/enviroatlas/data.py +++ b/tests/data/enviroatlas/data.py @@ -17,198 +17,198 @@ from torchvision.datasets.utils import calculate_md5 suffix_to_key_map = { - "a_naip": "naip", - "b_nlcd": "nlcd", - "c_roads": "roads", - "d_water": "water", - "d1_waterways": "waterways", - "d2_waterbodies": "waterbodies", - "e_buildings": "buildings", - "h_highres_labels": "lc", - "prior_from_cooccurrences_101_31": "prior", - "prior_from_cooccurrences_101_31_no_osm_no_buildings": "prior_no_osm_no_buildings", + 'a_naip': 'naip', + 'b_nlcd': 'nlcd', + 'c_roads': 'roads', + 'd_water': 'water', + 'd1_waterways': 'waterways', + 'd2_waterbodies': 'waterbodies', + 'e_buildings': 'buildings', + 'h_highres_labels': 'lc', + 'prior_from_cooccurrences_101_31': 'prior', + 'prior_from_cooccurrences_101_31_no_osm_no_buildings': 'prior_no_osm_no_buildings', } layer_data_profiles: dict[str, dict[Any, Any]] = { - "a_naip": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 4, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "pixel", + 'a_naip': { + 'profile': { + 'driver': 
'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 4, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'pixel', }, - "data_type": "continuous", - "vals": (4, 255), + 'data_type': 'continuous', + 'vals': (4, 255), }, - "b_nlcd": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 1, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'b_nlcd': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 1, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "categorical", - "vals": [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15], + 'data_type': 'categorical', + 'vals': [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15], }, - "c_roads": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 1, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'c_roads': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 1, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "categorical", - "vals": [0, 1], + 'data_type': 'categorical', + 'vals': [0, 1], }, - "d1_waterways": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 1, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'd1_waterways': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 1, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "categorical", - "vals": [0, 1], + 'data_type': 'categorical', + 'vals': [0, 1], }, - "d2_waterbodies": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 1, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'd2_waterbodies': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 1, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "categorical", - "vals": [0, 1], + 'data_type': 'categorical', + 'vals': [0, 1], }, - "d_water": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 1, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'd_water': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 1, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "categorical", - "vals": [0, 1], + 'data_type': 'categorical', + 'vals': [0, 1], }, - "e_buildings": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 1, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - 
"interleave": "band", + 'e_buildings': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 1, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "categorical", - "vals": [0, 1], + 'data_type': 'categorical', + 'vals': [0, 1], }, - "h_highres_labels": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 1, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'h_highres_labels': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 1, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "categorical", - "vals": [10, 20, 30, 40, 70], + 'data_type': 'categorical', + 'vals': [10, 20, 30, 40, 70], }, - "prior_from_cooccurrences_101_31": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 5, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'prior_from_cooccurrences_101_31': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 5, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "continuous", - "vals": (0, 225), + 'data_type': 'continuous', + 'vals': (0, 225), }, - "prior_from_cooccurrences_101_31_no_osm_no_buildings": { - "profile": { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "count": 5, - "crs": CRS.from_epsg(26914), - "blockxsize": 512, - "blockysize": 512, - "tiled": True, - "compress": "deflate", - "interleave": "band", + 'prior_from_cooccurrences_101_31_no_osm_no_buildings': { + 'profile': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'count': 5, + 'crs': CRS.from_epsg(26914), + 'blockxsize': 512, + 'blockysize': 512, + 'tiled': True, + 'compress': 'deflate', + 'interleave': 'band', }, - "data_type": "continuous", - "vals": (0, 220), + 'data_type': 'continuous', + 'vals': (0, 220), }, } tile_list = [ - "pittsburgh_pa-2010_1m-train_tiles-debuffered/4007925_se", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw", + 'pittsburgh_pa-2010_1m-train_tiles-debuffered/4007925_se', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw', ] def write_data(path: str, profile: dict[Any, Any], data_type: Any, vals: Any) -> None: - assert all(key in profile for key in ("count", "height", "width", "dtype")) - with rasterio.open(path, "w", **profile) as dst: - size = (profile["count"], profile["height"], profile["width"]) - dtype = np.dtype(profile["dtype"]) - if data_type == "continuous": + assert all(key in profile for key in ('count', 'height', 'width', 'dtype')) + with rasterio.open(path, 'w', **profile) as dst: + size = (profile['count'], profile['height'], profile['width']) + dtype = np.dtype(profile['dtype']) + if data_type == 'continuous': data = np.random.randint(vals[0], vals[1] + 1, size=size, dtype=dtype) - elif data_type == "categorical": + elif data_type == 'categorical': data = np.random.choice(vals, size=size).astype(dtype) else: - raise ValueError(f"{data_type} is not recognized") + raise ValueError(f'{data_type} is not recognized') dst.write(data) @@ -222,82 +222,82 @@ def generate_test_data(root: 
str) -> str: str: md5 hash of created archive """ size = (64, 64) - folder_path = os.path.join(root, "enviroatlas_lotp") + folder_path = os.path.join(root, 'enviroatlas_lotp') if not os.path.exists(folder_path): os.makedirs(folder_path) for prefix in tile_list: for suffix, data_profile in layer_data_profiles.items(): - img_path = os.path.join(folder_path, f"{prefix}_{suffix}.tif") + img_path = os.path.join(folder_path, f'{prefix}_{suffix}.tif') img_dir = os.path.dirname(img_path) if not os.path.exists(img_dir): os.makedirs(img_dir) - data_profile["profile"]["height"] = size[0] - data_profile["profile"]["width"] = size[1] - data_profile["profile"]["transform"] = Affine( + data_profile['profile']['height'] = size[0] + data_profile['profile']['width'] = size[1] + data_profile['profile']['transform'] = Affine( 1.0, 0.0, 608170.0, 0.0, -1.0, 3381430.0 ) write_data( img_path, - data_profile["profile"], - data_profile["data_type"], - data_profile["vals"], + data_profile['profile'], + data_profile['data_type'], + data_profile['vals'], ) # build the spatial index schema = { - "geometry": "Polygon", - "properties": { - "split": "str", - "naip": "str", - "nlcd": "str", - "roads": "str", - "water": "str", - "waterways": "str", - "waterbodies": "str", - "buildings": "str", - "lc": "str", - "prior_no_osm_no_buildings": "str", - "prior": "str", + 'geometry': 'Polygon', + 'properties': { + 'split': 'str', + 'naip': 'str', + 'nlcd': 'str', + 'roads': 'str', + 'water': 'str', + 'waterways': 'str', + 'waterbodies': 'str', + 'buildings': 'str', + 'lc': 'str', + 'prior_no_osm_no_buildings': 'str', + 'prior': 'str', }, } with fiona.open( - os.path.join(folder_path, "spatial_index.geojson"), - "w", - driver="GeoJSON", - crs="EPSG:3857", + os.path.join(folder_path, 'spatial_index.geojson'), + 'w', + driver='GeoJSON', + crs='EPSG:3857', schema=schema, ) as dst: for prefix in tile_list: - img_path = os.path.join(folder_path, f"{prefix}_a_naip.tif") + img_path = os.path.join(folder_path, f'{prefix}_a_naip.tif') with rasterio.open(img_path) as f: geom = shapely.geometry.mapping(shapely.geometry.box(*f.bounds)) geom = fiona.transform.transform_geom( - f.crs.to_string(), "EPSG:3857", geom + f.crs.to_string(), 'EPSG:3857', geom ) row = { - "geometry": geom, - "properties": { - "split": prefix.split("/")[0].replace("_tiles-debuffered", "") + 'geometry': geom, + 'properties': { + 'split': prefix.split('/')[0].replace('_tiles-debuffered', '') }, } for suffix, data_profile in layer_data_profiles.items(): key = suffix_to_key_map[suffix] - row["properties"][key] = f"{prefix}_{suffix}.tif" + row['properties'][key] = f'{prefix}_{suffix}.tif' dst.write(row) # Create archive - archive_path = os.path.join(root, "enviroatlas_lotp") - shutil.make_archive(archive_path, "zip", root_dir=root, base_dir="enviroatlas_lotp") + archive_path = os.path.join(root, 'enviroatlas_lotp') + shutil.make_archive(archive_path, 'zip', root_dir=root, base_dir='enviroatlas_lotp') shutil.rmtree(folder_path) - md5: str = calculate_md5(archive_path + ".zip") + md5: str = calculate_md5(archive_path + '.zip') return md5 -if __name__ == "__main__": +if __name__ == '__main__': md5_hash = generate_test_data(os.getcwd()) print(md5_hash) diff --git a/tests/data/esri2020/data.py b/tests/data/esri2020/data.py index 4bb37df701d..901947851a6 100755 --- a/tests/data/esri2020/data.py +++ b/tests/data/esri2020/data.py @@ -15,31 +15,31 @@ SIZE = 64 -files = [{"image": "N00E020_agb.tif"}, {"image": "N00E020_agb_err.tif"}] +files = [{'image': 'N00E020_agb.tif'}, {'image': 
'N00E020_agb_err.tif'}] def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(1, SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z) -if __name__ == "__main__": - dir = "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01" - tif_name = "00A_20200101-20210101.tif" +if __name__ == '__main__': + dir = 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01' + tif_name = '00A_20200101-20210101.tif' if os.path.exists(dir): shutil.rmtree(dir) @@ -47,12 +47,12 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: os.makedirs(dir) # Create mask file - create_file(os.path.join(dir, tif_name), dtype="int8", num_channels=1) + create_file(os.path.join(dir, tif_name), dtype='int8', num_channels=1) - shutil.make_archive(dir, "zip", base_dir=dir) + shutil.make_archive(dir, 'zip', base_dir=dir) # Compute checksums - zipfilename = dir + ".zip" - with open(zipfilename, "rb") as f: + zipfilename = dir + '.zip' + with open(zipfilename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{zipfilename}: {md5}") + print(f'{zipfilename}: {md5}') diff --git a/tests/data/etci2021/data.py b/tests/data/etci2021/data.py index b90b7f51b93..9b3a1900e23 100755 --- a/tests/data/etci2021/data.py +++ b/tests/data/etci2021/data.py @@ -13,39 +13,39 @@ metadatas = [ { - "filename": "train.zip", - "directory": "train", - "subdirs": [ - "nebraska_20170108t002112", - "bangladesh_20170314t115609", - "northal_20190302t234651", + 'filename': 'train.zip', + 'directory': 'train', + 'subdirs': [ + 'nebraska_20170108t002112', + 'bangladesh_20170314t115609', + 'northal_20190302t234651', ], }, { - "filename": "val_with_ref_labels.zip", - "directory": "test", - "subdirs": [ - "florence_20180510t231343", - "florence_20180522t231344", - "florence_20190302t234651", + 'filename': 'val_with_ref_labels.zip', + 'directory': 'test', + 'subdirs': [ + 'florence_20180510t231343', + 'florence_20180522t231344', + 'florence_20190302t234651', ], }, { - "filename": "test_without_ref_labels.zip", - "directory": "test_internal", - "subdirs": [ - "redrivernorth_20190104t002247", - "redrivernorth_20190116t002247", - "redrivernorth_20190302t234651", + 'filename': 'test_without_ref_labels.zip', + 'directory': 'test_internal', + 'subdirs': [ + 'redrivernorth_20190104t002247', + 'redrivernorth_20190116t002247', + 'redrivernorth_20190302t234651', ], }, ] -tiles = ["vh", "vv", "water_body_label", "flood_label"] +tiles = ['vh', 'vv', 'water_body_label', 'flood_label'] for metadata in metadatas: - filename = metadata["filename"] - directory = metadata["directory"] + filename = metadata['filename'] + directory = metadata['directory'] # Remove old data if 
os.path.exists(filename): @@ -54,25 +54,25 @@ shutil.rmtree(directory) # Create images - for subdir in metadata["subdirs"]: + for subdir in metadata['subdirs']: for tile in tiles: - if directory == "test_internal" and tile == "flood_label": + if directory == 'test_internal' and tile == 'flood_label': continue - fn = f"{subdir}_x-0_y-0" - if tile in ["vh", "vv"]: - fn += f"_{tile}" - fn += ".png" - fd = os.path.join(directory, subdir, "tiles", tile) + fn = f'{subdir}_x-0_y-0' + if tile in ['vh', 'vv']: + fn += f'_{tile}' + fn += '.png' + fd = os.path.join(directory, subdir, 'tiles', tile) os.makedirs(fd) - img = Image.new("RGB", (SIZE, SIZE)) + img = Image.new('RGB', (SIZE, SIZE)) img.save(os.path.join(fd, fn)) # Compress data - shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory) + shutil.make_archive(filename.replace('.zip', ''), 'zip', '.', directory) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(repr(filename) + ":", repr(md5) + ",") + print(repr(filename) + ':', repr(md5) + ',') diff --git a/tests/data/eudem/data.py b/tests/data/eudem/data.py index 5b707fae2f4..e36c56bfcd2 100755 --- a/tests/data/eudem/data.py +++ b/tests/data/eudem/data.py @@ -12,51 +12,51 @@ SIZE = 64 -files = [{"image": "eu_dem_v11_E30N10.TIF"}, {"image": "eu_dem_v11_E30N10.TIF.ovr"}] +files = [{'image': 'eu_dem_v11_E30N10.TIF'}, {'image': 'eu_dem_v11_E30N10.TIF.ovr'}] def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(1, SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z) -if __name__ == "__main__": - zipfilename = "eu_dem_v11_E30N10.zip" +if __name__ == '__main__': + zipfilename = 'eu_dem_v11_E30N10.zip' files_to_zip = [] for file_dict in files: - path = file_dict["image"] + path = file_dict['image'] # remove old data if os.path.exists(path): os.remove(path) # Create mask file - create_file(path, dtype="int32", num_channels=1) + create_file(path, dtype='int32', num_channels=1) files_to_zip.append(path) # Compress data - with zipfile.ZipFile(zipfilename, "w") as zip: + with zipfile.ZipFile(zipfilename, 'w') as zip: for file in files_to_zip: zip.write(file, arcname=file) # Compute checksums - with open(zipfilename, "rb") as f: + with open(zipfilename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{zipfilename}: {md5}") + print(f'{zipfilename}: {md5}') # remove TIF files for file_dict in files: - os.remove(file_dict["image"]) + os.remove(file_dict['image']) diff --git a/tests/data/eurocrops/data.py b/tests/data/eurocrops/data.py index 54a5867e01b..4128407a3e1 100755 --- a/tests/data/eurocrops/data.py +++ 
b/tests/data/eurocrops/data.py @@ -19,9 +19,9 @@ def create_data_file(dataname): - schema = {"geometry": "Polygon", "properties": {"EC_hcat_c": "str"}} + schema = {'geometry': 'Polygon', 'properties': {'EC_hcat_c': 'str'}} with fiona.open( - dataname, "w", crs=CRS.from_epsg(32616), driver="ESRI Shapefile", schema=schema + dataname, 'w', crs=CRS.from_epsg(32616), driver='ESRI Shapefile', schema=schema ) as shpfile: coordinates = [[0.0, 0.0], [0.0, SIZE], [SIZE, SIZE], [SIZE, 0.0], [0.0, 0.0]] # The offset aligns with tests/data/sentinel2/data.py. @@ -29,34 +29,34 @@ def create_data_file(dataname): coordinates = [[x + offset[0], y + offset[1]] for x, y in coordinates] polygon = Polygon(coordinates) - properties = {"EC_hcat_c": "1000000010"} - shpfile.write({"geometry": mapping(polygon), "properties": properties}) + properties = {'EC_hcat_c': '1000000010'} + shpfile.write({'geometry': mapping(polygon), 'properties': properties}) def create_csv(fname): - with open(fname, "w") as f: - writer = csv.DictWriter(f, fieldnames=["HCAT2_code"]) + with open(fname, 'w') as f: + writer = csv.DictWriter(f, fieldnames=['HCAT2_code']) writer.writeheader() - writer.writerow({"HCAT2_code": "1000000000"}) - writer.writerow({"HCAT2_code": "1000000010"}) + writer.writerow({'HCAT2_code': '1000000000'}) + writer.writerow({'HCAT2_code': '1000000010'}) -if __name__ == "__main__": - csvname = "HCAT2.csv" - dataname = "AA_2022_EC21.shp" +if __name__ == '__main__': + csvname = 'HCAT2.csv' + dataname = 'AA_2022_EC21.shp' supportnames = [ - "AA_2022_EC21.cpg", - "AA_2022_EC21.dbf", - "AA_2022_EC21.prj", - "AA_2022_EC21.shx", + 'AA_2022_EC21.cpg', + 'AA_2022_EC21.dbf', + 'AA_2022_EC21.prj', + 'AA_2022_EC21.shx', ] - zipfilename = "AA.zip" + zipfilename = 'AA.zip' # create crop type data geojson_data = create_data_file(dataname) # archive the geojson to zip - with zipfile.ZipFile(zipfilename, "w") as zipf: + with zipfile.ZipFile(zipfilename, 'w') as zipf: zipf.write(dataname) for name in supportnames: zipf.write(name) @@ -68,9 +68,9 @@ def create_csv(fname): create_csv(csvname) # Compute checksums - with open(zipfilename, "rb") as f: + with open(zipfilename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{zipfilename}: {md5}") - with open(csvname, "rb") as f: + print(f'{zipfilename}: {md5}') + with open(csvname, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{csvname}: {md5}") + print(f'{csvname}: {md5}') diff --git a/tests/data/fire_risk/data.py b/tests/data/fire_risk/data.py index 6f6613037be..59e011a6d01 100755 --- a/tests/data/fire_risk/data.py +++ b/tests/data/fire_risk/data.py @@ -16,65 +16,65 @@ PATHS = [ os.path.join( - "FireRisk", "train", "High", "27032281_4_-103.430441201095_44.2804260315038.png" + 'FireRisk', 'train', 'High', '27032281_4_-103.430441201095_44.2804260315038.png' ), os.path.join( - "FireRisk", "train", "Low", "27032391_2_-103.058289903541_44.3007203324261.png" + 'FireRisk', 'train', 'Low', '27032391_2_-103.058289903541_44.3007203324261.png' ), os.path.join( - "FireRisk", - "train", - "Moderate", - "27033601_3_-98.95279624632_44.455109470962.png", + 'FireRisk', + 'train', + 'Moderate', + '27033601_3_-98.95279624632_44.455109470962.png', ), os.path.join( - "FireRisk", - "train", - "Non-burnable", - "27033161_6_-100.447787439271_44.4136022778593.png", + 'FireRisk', + 'train', + 'Non-burnable', + '27033161_6_-100.447787439271_44.4136022778593.png', ), os.path.join( - "FireRisk", - "train", - "Very_High", - "27041631_5_-123.547051830273_41.5463004986268.png", + 'FireRisk', 
+ 'train', + 'Very_High', + '27041631_5_-123.547051830273_41.5463004986268.png', ), os.path.join( - "FireRisk", "val", "High", "35501951_4_-73.9911660056379_41.2755665931274.png" + 'FireRisk', 'val', 'High', '35501951_4_-73.9911660056379_41.2755665931274.png' ), os.path.join( - "FireRisk", "val", "Low", "35501621_2_-75.0371666057303_41.4540009148918.png" + 'FireRisk', 'val', 'Low', '35501621_2_-75.0371666057303_41.4540009148918.png' ), os.path.join( - "FireRisk", - "val", - "Moderate", - "35501731_3_-74.6879125510064_41.3954685534897.png", + 'FireRisk', + 'val', + 'Moderate', + '35501731_3_-74.6879125510064_41.3954685534897.png', ), os.path.join( - "FireRisk", - "val", - "Non-burnable", - "35502061_6_-73.6436892181052_41.2142019946826.png", + 'FireRisk', + 'val', + 'Non-burnable', + '35502061_6_-73.6436892181052_41.2142019946826.png', ), os.path.join( - "FireRisk", - "val", - "Very_High", - "35502941_5_-122.968467383602_40.2960022654498.png", + 'FireRisk', + 'val', + 'Very_High', + '35502941_5_-122.968467383602_40.2960022654498.png', ), ] def create_file(path: str) -> None: Z = np.random.randint(255, size=(SIZE, SIZE, 3), dtype=np.uint8) - img = Image.fromarray(Z).convert("RGB") + img = Image.fromarray(Z).convert('RGB') img.save(path) -if __name__ == "__main__": - directory = "FireRisk" - filename = "FireRisk.zip" +if __name__ == '__main__': + directory = 'FireRisk' + filename = 'FireRisk.zip' # remove old data if os.path.isdir(directory): @@ -85,9 +85,9 @@ def create_file(path: str) -> None: create_file(path) # compress data - shutil.make_archive(directory, "zip", ".", directory) + shutil.make_archive(directory, 'zip', '.', directory) # compute checksum - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/forestdamage/data.py b/tests/data/forestdamage/data.py index 264b5e97ffb..911af1bcf63 100755 --- a/tests/data/forestdamage/data.py +++ b/tests/data/forestdamage/data.py @@ -16,40 +16,40 @@ np.random.seed(0) PATHS = { - "images": [ - "Bebehojd_20190527/Images/B01_0004.JPG", - "Bebehojd_20190527/Images/B01_0005.JPG", + 'images': [ + 'Bebehojd_20190527/Images/B01_0004.JPG', + 'Bebehojd_20190527/Images/B01_0005.JPG', ], - "annotations": [ - "Bebehojd_20190527/Annotations/B01_0004.xml", - "Bebehojd_20190527/Annotations/B01_0005.xml", + 'annotations': [ + 'Bebehojd_20190527/Annotations/B01_0004.xml', + 'Bebehojd_20190527/Annotations/B01_0005.xml', ], - "labels": [True, False], + 'labels': [True, False], } def create_annotation(path: str) -> None: - root = ET.Element("annotation") + root = ET.Element('annotation') - ET.SubElement(root, "filename").text = os.path.basename(path) + ET.SubElement(root, 'filename').text = os.path.basename(path) - size = ET.SubElement(root, "size") + size = ET.SubElement(root, 'size') - ET.SubElement(size, "width").text = str(SIZE) - ET.SubElement(size, "height").text = str(SIZE) - ET.SubElement(size, "depth").text = str(3) + ET.SubElement(size, 'width').text = str(SIZE) + ET.SubElement(size, 'height').text = str(SIZE) + ET.SubElement(size, 'depth').text = str(3) - for label in PATHS["labels"]: - annotation = ET.SubElement(root, "object") + for label in PATHS['labels']: + annotation = ET.SubElement(root, 'object') if label: - ET.SubElement(annotation, "damage").text = "other" + ET.SubElement(annotation, 'damage').text = 'other' - bbox = ET.SubElement(annotation, "bndbox") - ET.SubElement(bbox, "xmin").text = str(0 + int(SIZE / 4)) 
- ET.SubElement(bbox, "ymin").text = str(0 + int(SIZE / 4)) - ET.SubElement(bbox, "xmax").text = str(SIZE - int(SIZE / 4)) - ET.SubElement(bbox, "ymax").text = str(SIZE - int(SIZE / 4)) + bbox = ET.SubElement(annotation, 'bndbox') + ET.SubElement(bbox, 'xmin').text = str(0 + int(SIZE / 4)) + ET.SubElement(bbox, 'ymin').text = str(0 + int(SIZE / 4)) + ET.SubElement(bbox, 'xmax').text = str(SIZE - int(SIZE / 4)) + ET.SubElement(bbox, 'ymax').text = str(SIZE - int(SIZE / 4)) tree = ET.ElementTree(root) tree.write(path) @@ -57,30 +57,30 @@ def create_annotation(path: str) -> None: def create_file(path: str) -> None: Z = np.random.rand(SIZE, SIZE, 3) * 255 - img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img = Image.fromarray(Z.astype('uint8')).convert('RGB') img.save(path) -if __name__ == "__main__": - data_root = "Data_Set_Larch_Casebearer" +if __name__ == '__main__': + data_root = 'Data_Set_Larch_Casebearer' # remove old data if os.path.isdir(data_root): shutil.rmtree(data_root) else: os.makedirs(data_root) - for path in PATHS["images"]: + for path in PATHS['images']: os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) create_file(os.path.join(data_root, path)) - for path in PATHS["annotations"]: + for path in PATHS['annotations']: os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) create_annotation(os.path.join(data_root, path)) # compress data - shutil.make_archive(data_root, "zip", ".", data_root) + shutil.make_archive(data_root, 'zip', '.', data_root) # Compute checksums - with open(data_root + ".zip", "rb") as f: + with open(data_root + '.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{data_root}: {md5}") + print(f'{data_root}: {md5}') diff --git a/tests/data/gbif/data.py b/tests/data/gbif/data.py index b2f001e407d..34ddb3fa696 100755 --- a/tests/data/gbif/data.py +++ b/tests/data/gbif/data.py @@ -5,61 +5,61 @@ import pandas as pd -filename = "0123456-012345678901234.csv" +filename = '0123456-012345678901234.csv' size = 6 data = { - "gbifID": [""] * size, - "datasetKey": [""] * size, - "occurrenceID": [""] * size, - "kingdom": ["Animalia"] * size, - "phylum": ["Chordata"] * size, - "class": ["Mammalia"] * size, - "order": ["Primates"] * size, - "family": ["Hominidae"] * size, - "genus": ["Homo"] * size, - "species": ["Homo sapiens"] * size, - "infraspecificEpithet": [""] * size, - "taxonRank": ["SPECIES"] * size, - "scientificName": ["Homo sapiens Linnaeus, 1758"] * size, - "verbatimScientificName": ["Homo sapiens Linnaeus, 1758"] * size, - "verbatimScientificNameAuthorship": ["Linnaeus, 1758"] * size, - "countryCode": ["US"] * size, - "locality": ["Chicago"] * size, - "stateProvince": ["Illinois"] * size, - "occurrenceStatus": ["PRESENT"] * size, - "individualCount": [1] * size, - "publishingOrgKey": [""] * size, - "decimalLatitude": [41.881832] * size, - "decimalLongitude": [""] + [-87.623177] * (size - 1), - "coordinateUncertaintyInMeters": [5] * size, - "coordinatePrecision": [""] * size, - "elevation": [""] * size, - "elevationAccuracy": [""] * size, - "depth": [""] * size, - "depthAccuracy": [""] * size, - "eventDate": ["", "", "", "", -450, "2022-04-16T10:13:35.123Z"], - "day": [16, "", "", "", "", 16], - "month": [4, "", "", 12, 4, 4], - "year": [2022, "", 2022, 2022, 2022, 2022], - "taxonKey": [1] * size, - "speciesKey": [1] * size, - "basisOfRecord": ["HUMAN_OBSERVATION"] * size, - "institutionCode": [""] * size, - "collectionCode": [""] * size, - "catalogNumber": [""] * size, - "recordNumber": [""] 
* size, - "identifiedBy": [""] * size, - "dateIdentified": [""] * size, - "license": [""] * size, - "rightsHolder": [""] * size, - "recordedBy": [""] * size, - "typeStatus": [""] * size, - "establishmentMeans": [""] * size, - "lastInterpreted": [""] * size, - "mediaType": [""] * size, - "issue": [""] * size, + 'gbifID': [''] * size, + 'datasetKey': [''] * size, + 'occurrenceID': [''] * size, + 'kingdom': ['Animalia'] * size, + 'phylum': ['Chordata'] * size, + 'class': ['Mammalia'] * size, + 'order': ['Primates'] * size, + 'family': ['Hominidae'] * size, + 'genus': ['Homo'] * size, + 'species': ['Homo sapiens'] * size, + 'infraspecificEpithet': [''] * size, + 'taxonRank': ['SPECIES'] * size, + 'scientificName': ['Homo sapiens Linnaeus, 1758'] * size, + 'verbatimScientificName': ['Homo sapiens Linnaeus, 1758'] * size, + 'verbatimScientificNameAuthorship': ['Linnaeus, 1758'] * size, + 'countryCode': ['US'] * size, + 'locality': ['Chicago'] * size, + 'stateProvince': ['Illinois'] * size, + 'occurrenceStatus': ['PRESENT'] * size, + 'individualCount': [1] * size, + 'publishingOrgKey': [''] * size, + 'decimalLatitude': [41.881832] * size, + 'decimalLongitude': [''] + [-87.623177] * (size - 1), + 'coordinateUncertaintyInMeters': [5] * size, + 'coordinatePrecision': [''] * size, + 'elevation': [''] * size, + 'elevationAccuracy': [''] * size, + 'depth': [''] * size, + 'depthAccuracy': [''] * size, + 'eventDate': ['', '', '', '', -450, '2022-04-16T10:13:35.123Z'], + 'day': [16, '', '', '', '', 16], + 'month': [4, '', '', 12, 4, 4], + 'year': [2022, '', 2022, 2022, 2022, 2022], + 'taxonKey': [1] * size, + 'speciesKey': [1] * size, + 'basisOfRecord': ['HUMAN_OBSERVATION'] * size, + 'institutionCode': [''] * size, + 'collectionCode': [''] * size, + 'catalogNumber': [''] * size, + 'recordNumber': [''] * size, + 'identifiedBy': [''] * size, + 'dateIdentified': [''] * size, + 'license': [''] * size, + 'rightsHolder': [''] * size, + 'recordedBy': [''] * size, + 'typeStatus': [''] * size, + 'establishmentMeans': [''] * size, + 'lastInterpreted': [''] * size, + 'mediaType': [''] * size, + 'issue': [''] * size, } df = pd.DataFrame(data) -df.to_csv(filename, sep="\t", index=False) +df.to_csv(filename, sep='\t', index=False) diff --git a/tests/data/globbiomass/data.py b/tests/data/globbiomass/data.py index f407d26e879..31c673fceef 100755 --- a/tests/data/globbiomass/data.py +++ b/tests/data/globbiomass/data.py @@ -16,48 +16,48 @@ files = { - "agb": ["N00E020_agb.tif", "N00E020_agb_err.tif"], - "gsv": ["N00E020_gsv.tif", "N00E020_gsv_err.tif"], + 'agb': ['N00E020_agb.tif', 'N00E020_agb_err.tif'], + 'gsv': ['N00E020_gsv.tif', 'N00E020_gsv_err.tif'], } def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(1, SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(1, SIZE, SIZE), dtype=profile['dtype'] ) - 
with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z) -if __name__ == "__main__": +if __name__ == '__main__': for measurement, file_paths in files.items(): - zipfilename = f"N00E020_{measurement}.zip" + zipfilename = f'N00E020_{measurement}.zip' files_to_zip = [] for path in file_paths: # remove old data if os.path.exists(path): os.remove(path) # Create mask file - create_file(path, dtype="int32", num_channels=1) + create_file(path, dtype='int32', num_channels=1) files_to_zip.append(path) # Compress data - with zipfile.ZipFile(zipfilename, "w") as zip: + with zipfile.ZipFile(zipfilename, 'w') as zip: for file in files_to_zip: zip.write(file, arcname=file) # Compute checksums - with open(zipfilename, "rb") as f: + with open(zipfilename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{zipfilename}: {md5}") + print(f'{zipfilename}: {md5}') diff --git a/tests/data/inaturalist/data.py b/tests/data/inaturalist/data.py index 6bfbc685008..1c9a39107a1 100755 --- a/tests/data/inaturalist/data.py +++ b/tests/data/inaturalist/data.py @@ -5,53 +5,53 @@ import pandas as pd -filename = "observations-012345.csv" +filename = 'observations-012345.csv' # User can select which columns to export. The following are the default columns. # Not all columns may exist in the actual dataset. size = 4 data = { - "id": [""] * size, - "observed_on_string": [""] * size, - "observed_on": ["", "", "2022-05-07", "2022-05-07"], - "time_observed_at": ["", "", "", "2022-05-07 11:02:53 +0100"], - "time_zone": ["Central Time (US & Canada)"] * size, - "user_id": [123] * size, - "user_login": ["darwin"] * size, - "created_at": ["2022-05-07 11:02:53 +0100"] * size, - "updated_at": ["2022-05-07 11:02:53 +0100"] * size, - "quality_grade": ["research"] * size, - "license": ["CCO"] * size, - "url": ["https://inaturalist.org/observations/123"] * size, - "image_url": [ - "https://inaturalist-open-data.s3.amazonaws.com/photos/123/medium.jpg" + 'id': [''] * size, + 'observed_on_string': [''] * size, + 'observed_on': ['', '', '2022-05-07', '2022-05-07'], + 'time_observed_at': ['', '', '', '2022-05-07 11:02:53 +0100'], + 'time_zone': ['Central Time (US & Canada)'] * size, + 'user_id': [123] * size, + 'user_login': ['darwin'] * size, + 'created_at': ['2022-05-07 11:02:53 +0100'] * size, + 'updated_at': ['2022-05-07 11:02:53 +0100'] * size, + 'quality_grade': ['research'] * size, + 'license': ['CCO'] * size, + 'url': ['https://inaturalist.org/observations/123'] * size, + 'image_url': [ + 'https://inaturalist-open-data.s3.amazonaws.com/photos/123/medium.jpg' ] * size, - "sound_url": ["https://static.inaturalist.org/sounds/123.m4a?123"] * size, - "tag_list": ["Chicago"] * size, - "description": [""] * size, - "num_identification_agreements": [1] * size, - "num_identification_disagreements": [0] * size, - "captive_cultivated": ["false"] * size, - "oauth_application_id": [""] * size, - "place_guess": ["Chicago"] * size, - "latitude": [41.881832] * size, - "longitude": [""] + [-87.623177] * (size - 1), - "positional_accuracy": [5] * size, - "private_place_guess": [""] * size, - "private_latitude": [""] * size, - "private_longitude": [""] * size, - "public_positional_accuracy": [5] * size, - "geoprivacy": [""] * size, - "taxon_geoprivacy": [""] * size, - "coordinates_obscured": ["false"] * size, - "positioning_method": ["gps"] * size, - "positioning_device": ["gps"] * size, - "species_guess": ["Homo sapiens"] * size, - "scientific_name": ["Homo sapiens"] * size, - 
"common_name": ["human"] * size, - "iconic_taxon_name": ["Animalia"] * size, - "taxon_id": [123] * size, + 'sound_url': ['https://static.inaturalist.org/sounds/123.m4a?123'] * size, + 'tag_list': ['Chicago'] * size, + 'description': [''] * size, + 'num_identification_agreements': [1] * size, + 'num_identification_disagreements': [0] * size, + 'captive_cultivated': ['false'] * size, + 'oauth_application_id': [''] * size, + 'place_guess': ['Chicago'] * size, + 'latitude': [41.881832] * size, + 'longitude': [''] + [-87.623177] * (size - 1), + 'positional_accuracy': [5] * size, + 'private_place_guess': [''] * size, + 'private_latitude': [''] * size, + 'private_longitude': [''] * size, + 'public_positional_accuracy': [5] * size, + 'geoprivacy': [''] * size, + 'taxon_geoprivacy': [''] * size, + 'coordinates_obscured': ['false'] * size, + 'positioning_method': ['gps'] * size, + 'positioning_device': ['gps'] * size, + 'species_guess': ['Homo sapiens'] * size, + 'scientific_name': ['Homo sapiens'] * size, + 'common_name': ['human'] * size, + 'iconic_taxon_name': ['Animalia'] * size, + 'taxon_id': [123] * size, } df = pd.DataFrame(data) diff --git a/tests/data/inria/data.py b/tests/data/inria/data.py index 4e304947adc..96626dea5fa 100755 --- a/tests/data/inria/data.py +++ b/tests/data/inria/data.py @@ -18,7 +18,7 @@ def write_data( ) -> None: with rio.open( path, - "w", + 'w', driver=driver, height=img.shape[0], width=img.shape[1], @@ -42,18 +42,18 @@ def generate_test_data(root: str, n_samples: int = 2) -> str: Returns: str: md5 hash of created archive """ - dtype = np.dtype("uint8") + dtype = np.dtype('uint8') size = (8, 8) - driver = "GTiff" + driver = 'GTiff' transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0) crs = CRS.from_epsg(26914) - folder_path = os.path.join(root, "AerialImageDataset") + folder_path = os.path.join(root, 'AerialImageDataset') - img_dir = os.path.join(folder_path, "train", "images") - lbl_dir = os.path.join(folder_path, "train", "gt") - timg_dir = os.path.join(folder_path, "test", "images") + img_dir = os.path.join(folder_path, 'train', 'images') + lbl_dir = os.path.join(folder_path, 'train', 'gt') + timg_dir = os.path.join(folder_path, 'test', 'images') if not os.path.exists(img_dir): os.makedirs(img_dir) @@ -68,22 +68,22 @@ def generate_test_data(root: str, n_samples: int = 2) -> str: lbl = np.random.randint(2, size=size, dtype=dtype) timg = np.random.randint(dtype_max, size=size, dtype=dtype) - img_path = os.path.join(img_dir, f"austin{i+1}.tif") - lbl_path = os.path.join(lbl_dir, f"austin{i+1}.tif") - timg_path = os.path.join(timg_dir, f"austin{i+10}.tif") + img_path = os.path.join(img_dir, f'austin{i+1}.tif') + lbl_path = os.path.join(lbl_dir, f'austin{i+1}.tif') + timg_path = os.path.join(timg_dir, f'austin{i+10}.tif') write_data(img_path, img, driver, crs, transform) write_data(lbl_path, lbl, driver, crs, transform) write_data(timg_path, timg, driver, crs, transform) # Create archive - archive_path = os.path.join(root, "NEW2-AerialImageDataset") + archive_path = os.path.join(root, 'NEW2-AerialImageDataset') shutil.make_archive( - archive_path, "zip", root_dir=root, base_dir="AerialImageDataset" + archive_path, 'zip', root_dir=root, base_dir='AerialImageDataset' ) - return calculate_md5(f"{archive_path}.zip") + return calculate_md5(f'{archive_path}.zip') -if __name__ == "__main__": +if __name__ == '__main__': md5_hash = generate_test_data(os.getcwd(), 7) print(md5_hash) diff --git a/tests/data/l7irish/data.py b/tests/data/l7irish/data.py index 
373b34eeeae..de6d6e66ac5 100755 --- a/tests/data/l7irish/data.py +++ b/tests/data/l7irish/data.py @@ -16,47 +16,47 @@ np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] filenames: FILENAME_HIERARCHY = { - "l7irish": { - "austral": { - "p226_r98": ["L71226098_09820011112.TIF", "L7_p226_r98_newmask2015.TIF"], - "p227_r98": ["L71227098_09820011103.TIF", "L7_p227_r98_newmask2015.TIF"], - "p231_r93_2": ["L71231093_09320010507.TIF", "L7_p231_r93_newmask2015.TIF"], + 'l7irish': { + 'austral': { + 'p226_r98': ['L71226098_09820011112.TIF', 'L7_p226_r98_newmask2015.TIF'], + 'p227_r98': ['L71227098_09820011103.TIF', 'L7_p227_r98_newmask2015.TIF'], + 'p231_r93_2': ['L71231093_09320010507.TIF', 'L7_p231_r93_newmask2015.TIF'], }, - "boreal": { - "p2_r27": ["L71002027_02720010604.TIF", "L7_p2_r27_newmask2015.TIF"], - "p143_r21_3": ["L71143021_02120010803.TIF", "L7_p143_r21_newmask2015.TIF"], + 'boreal': { + 'p2_r27': ['L71002027_02720010604.TIF', 'L7_p2_r27_newmask2015.TIF'], + 'p143_r21_3': ['L71143021_02120010803.TIF', 'L7_p143_r21_newmask2015.TIF'], }, } } def create_file(path: str) -> None: - dtype = "uint8" + dtype = 'uint8' profile = { - "driver": "COG", - "compression": "LZW", - "predictor": 2, - "dtype": dtype, - "width": SIZE, - "height": SIZE, - "crs": CRS.from_epsg(32719), - "transform": Affine(30.0, 0.0, 462884.99999999994, 0.0, -30.0, 4071915.0), + 'driver': 'COG', + 'compression': 'LZW', + 'predictor': 2, + 'dtype': dtype, + 'width': SIZE, + 'height': SIZE, + 'crs': CRS.from_epsg(32719), + 'transform': Affine(30.0, 0.0, 462884.99999999994, 0.0, -30.0, 4071915.0), } - if path.endswith("_newmask2015.TIF"): + if path.endswith('_newmask2015.TIF'): Z = np.random.choice( np.array([0, 64, 128, 192, 255], dtype=dtype), size=(SIZE, SIZE) ) - profile["count"] = 1 + profile['count'] = 1 else: Z = np.random.randint(256, size=(SIZE, SIZE), dtype=dtype) - profile["count"] = 9 + profile['count'] = 9 - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) @@ -74,17 +74,17 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: create_file(path) -if __name__ == "__main__": - create_directory(".", filenames) +if __name__ == '__main__': + create_directory('.', filenames) - directories = ["austral", "boreal"] + directories = ['austral', 'boreal'] for directory in directories: filename = str(directory) # Create tarballs - shutil.make_archive(filename, "gztar", ".", os.path.join("l7irish", directory)) + shutil.make_archive(filename, 'gztar', '.', os.path.join('l7irish', directory)) # # Compute checksums - with open(f"{filename}.tar.gz", "rb") as f: + with open(f'{filename}.tar.gz', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() print(filename, md5) diff --git a/tests/data/l8biome/data.py b/tests/data/l8biome/data.py index 8a9503049a9..b0efd0fc7a6 100755 --- a/tests/data/l8biome/data.py +++ b/tests/data/l8biome/data.py @@ -16,32 +16,32 @@ np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] filenames: FILENAME_HIERARCHY = { - "l8biome": { - "barren": { - "LC80420082013220LGN00": [ - "LC80420082013220LGN00.TIF", - "LC80420082013220LGN00_fixedmask.TIF", + 'l8biome': { + 'barren': { + 'LC80420082013220LGN00': [ + 'LC80420082013220LGN00.TIF', + 
'LC80420082013220LGN00_fixedmask.TIF', ], - "LC80530022014156LGN00": [ - "LC80530022014156LGN00.TIF", - "LC80530022014156LGN00_fixedmask.TIF", + 'LC80530022014156LGN00': [ + 'LC80530022014156LGN00.TIF', + 'LC80530022014156LGN00_fixedmask.TIF', ], - "LC81360302014162LGN00": [ - "LC81360302014162LGN00.TIF", - "LC81360302014162LGN00_fixedmask.TIF", + 'LC81360302014162LGN00': [ + 'LC81360302014162LGN00.TIF', + 'LC81360302014162LGN00_fixedmask.TIF', ], }, - "forest": { - "LC80070662014234LGN00": [ - "LC80070662014234LGN00.TIF", - "LC80070662014234LGN00_fixedmask.TIF", + 'forest': { + 'LC80070662014234LGN00': [ + 'LC80070662014234LGN00.TIF', + 'LC80070662014234LGN00_fixedmask.TIF', ], - "LC80200462014005LGN00": [ - "LC80200462014005LGN00.TIF", - "LC80200462014005LGN00_fixedmask.TIF", + 'LC80200462014005LGN00': [ + 'LC80200462014005LGN00.TIF', + 'LC80200462014005LGN00_fixedmask.TIF', ], }, } @@ -49,31 +49,31 @@ def create_file(path: str) -> None: - dtype = "uint8" + dtype = 'uint8' profile = { - "driver": "COG", - "compression": "LZW", - "predictor": 2, - "dtype": dtype, - "width": SIZE, - "height": SIZE, - "crs": CRS.from_epsg(32615), - "transform": Affine(30.0, 0.0, 339885.0, 0.0, -30.0, 8286915.0), + 'driver': 'COG', + 'compression': 'LZW', + 'predictor': 2, + 'dtype': dtype, + 'width': SIZE, + 'height': SIZE, + 'crs': CRS.from_epsg(32615), + 'transform': Affine(30.0, 0.0, 339885.0, 0.0, -30.0, 8286915.0), } - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) - if path.endswith("_fixedmask.TIF"): + if path.endswith('_fixedmask.TIF'): Z = np.random.choice( np.array([0, 64, 128, 192, 255], dtype=dtype), size=(SIZE, SIZE) ) - profile["count"] = 1 + profile['count'] = 1 else: Z = np.random.randint(256, size=(SIZE, SIZE), dtype=dtype) - profile["count"] = 11 + profile['count'] = 11 - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) @@ -91,17 +91,17 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: create_file(path) -if __name__ == "__main__": - create_directory(".", filenames) +if __name__ == '__main__': + create_directory('.', filenames) - directories = ["barren", "forest"] + directories = ['barren', 'forest'] for directory in directories: filename = str(directory) # Create tarballs - shutil.make_archive(filename, "gztar", ".", os.path.join("l8biome", directory)) + shutil.make_archive(filename, 'gztar', '.', os.path.join('l8biome', directory)) # # Compute checksums - with open(f"{filename}.tar.gz", "rb") as f: + with open(f'{filename}.tar.gz', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() print(filename, md5) diff --git a/tests/data/landcoverai/data.py b/tests/data/landcoverai/data.py index 2e2fd219c50..87e27450878 100755 --- a/tests/data/landcoverai/data.py +++ b/tests/data/landcoverai/data.py @@ -42,47 +42,47 @@ dtype = np.uint8 kwargs = { - "driver": "GTiff", - "dtype": "uint8", - "crs": CRS.from_wkt(wkt), - "transform": Affine(0.25, 0.0, 280307.7499987148, 0.0, -0.25, 394546.9999900842), - "height": SIZE, - "width": SIZE, + 'driver': 'GTiff', + 'dtype': 'uint8', + 'crs': CRS.from_wkt(wkt), + 'transform': Affine(0.25, 0.0, 280307.7499987148, 0.0, -0.25, 394546.9999900842), + 'height': SIZE, + 'width': SIZE, } -filename = "M-33-20-D-c-4-2" +filename = 'M-33-20-D-c-4-2' # Remove old data -zipfilename = "landcover.ai.v1.zip" -for fn in 
["train.txt", "val.txt", "test.txt", "split.py", zipfilename]: +zipfilename = 'landcover.ai.v1.zip' +for fn in ['train.txt', 'val.txt', 'test.txt', 'split.py', zipfilename]: if os.path.exists(fn): os.remove(fn) -for directory in ["images", "masks", "output"]: +for directory in ['images', 'masks', 'output']: if os.path.exists(directory): shutil.rmtree(directory) # Create images -os.makedirs("images") +os.makedirs('images') Z = np.random.randint(np.iinfo(dtype).max, size=(SIZE, SIZE), dtype=dtype) with rasterio.open( - os.path.join("images", f"{filename}.tif"), "w", count=3, **kwargs + os.path.join('images', f'{filename}.tif'), 'w', count=3, **kwargs ) as f: for i in range(1, 4): f.write(Z, i) # Create masks -os.makedirs("masks") +os.makedirs('masks') Z = np.random.randint(4, size=(SIZE, SIZE), dtype=dtype) with rasterio.open( - os.path.join("masks", f"{filename}.tif"), "w", count=1, **kwargs + os.path.join('masks', f'{filename}.tif'), 'w', count=1, **kwargs ) as f: f.write(Z, 1) # Create train/val/test splits -files = ["M-33-20-D-c-4-2_0", "M-33-20-D-c-4-2_1"] -for split in ["train", "val", "test"]: - with open(f"{split}.txt", "w") as f: +files = ['M-33-20-D-c-4-2_0', 'M-33-20-D-c-4-2_1'] +for split in ['train', 'val', 'test']: + with open(f'{split}.txt', 'w') as f: for file in files: - f.write(f"{file}\n") + f.write(f'{file}\n') # Create split.py code = f"""\ @@ -98,28 +98,28 @@ cv2.imwrite(os.path.join("output", f"{filename}_{{i}}.jpg"), image) cv2.imwrite(os.path.join("output", f"{filename}_{{i}}_m.png"), mask) """ -with open("split.py", "w") as f: +with open('split.py', 'w') as f: f.write(code) # Create output -with open("split.py") as f: - split = f.read().encode("utf-8") +with open('split.py') as f: + split = f.read().encode('utf-8') exec(split) # Compress data -with zipfile.ZipFile(zipfilename, "w") as f: +with zipfile.ZipFile(zipfilename, 'w') as f: for file in [ - "images/M-33-20-D-c-4-2.tif", - "masks/M-33-20-D-c-4-2.tif", - "train.txt", - "val.txt", - "test.txt", - "split.py", + 'images/M-33-20-D-c-4-2.tif', + 'masks/M-33-20-D-c-4-2.tif', + 'train.txt', + 'val.txt', + 'test.txt', + 'split.py', ]: f.write(file, arcname=file) # Compute checksums -with open(zipfilename, "rb") as f: +with open(zipfilename, 'rb') as f: print(zipfilename, hashlib.md5(f.read()).hexdigest()) -with open("split.py", "rb") as f: - print("split.py", hashlib.sha256(f.read()).hexdigest()) +with open('split.py', 'rb') as f: + print('split.py', hashlib.sha256(f.read()).hexdigest()) diff --git a/tests/data/landcoverai/split.py b/tests/data/landcoverai/split.py index df6e8c53ee3..695b664e027 100644 --- a/tests/data/landcoverai/split.py +++ b/tests/data/landcoverai/split.py @@ -2,10 +2,10 @@ import cv2 -image = cv2.imread(os.path.join("images", "M-33-20-D-c-4-2.tif")) -mask = cv2.imread(os.path.join("masks", "M-33-20-D-c-4-2.tif")) +image = cv2.imread(os.path.join('images', 'M-33-20-D-c-4-2.tif')) +mask = cv2.imread(os.path.join('masks', 'M-33-20-D-c-4-2.tif')) -os.makedirs("output") +os.makedirs('output') for i in range(2): - cv2.imwrite(os.path.join("output", f"M-33-20-D-c-4-2_{i}.jpg"), image) - cv2.imwrite(os.path.join("output", f"M-33-20-D-c-4-2_{i}_m.png"), mask) + cv2.imwrite(os.path.join('output', f'M-33-20-D-c-4-2_{i}.jpg'), image) + cv2.imwrite(os.path.join('output', f'M-33-20-D-c-4-2_{i}_m.png'), mask) diff --git a/tests/data/levircd/levircd/data.py b/tests/data/levircd/levircd/data.py index 533e76d30d4..41b56c903f6 100644 --- a/tests/data/levircd/levircd/data.py +++ b/tests/data/levircd/levircd/data.py 
@@ -16,37 +16,37 @@ def create_image(path: str) -> None: Z = np.random.randint(255, size=(1, 1, 3), dtype=np.uint8) - img = Image.fromarray(Z).convert("RGB") + img = Image.fromarray(Z).convert('RGB') img.save(path) def create_mask(path: str) -> None: Z = np.random.randint(2, size=(1, 1, 3), dtype=np.uint8) * 255 - img = Image.fromarray(Z).convert("L") + img = Image.fromarray(Z).convert('L') img.save(path) -if __name__ == "__main__": - splits = ["train", "val", "test"] - filenames = ["train.zip", "val.zip", "test.zip"] - directories = ["A", "B", "label"] +if __name__ == '__main__': + splits = ['train', 'val', 'test'] + filenames = ['train.zip', 'val.zip', 'test.zip'] + directories = ['A', 'B', 'label'] for split, filename in zip(splits, filenames): for directory in directories: os.mkdir(directory) for i in range(2): - path = os.path.join("A", f"{split}_{i}.png") + path = os.path.join('A', f'{split}_{i}.png') create_image(path) - path = os.path.join("B", f"{split}_{i}.png") + path = os.path.join('B', f'{split}_{i}.png') create_image(path) - path = os.path.join("label", f"{split}_{i}.png") + path = os.path.join('label', f'{split}_{i}.png') create_mask(path) # compress data - with zipfile.ZipFile(filename, mode="a") as f: + with zipfile.ZipFile(filename, mode='a') as f: for directory in directories: for file in os.listdir(directory): f.write(os.path.join(directory, file)) @@ -55,6 +55,6 @@ def create_mask(path: str) -> None: shutil.rmtree(directory) # compute checksum - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/levircd/levircdplus/data.py b/tests/data/levircd/levircdplus/data.py index 5190acafe57..5ea6296e91b 100644 --- a/tests/data/levircd/levircdplus/data.py +++ b/tests/data/levircd/levircdplus/data.py @@ -15,20 +15,20 @@ def create_image(path: str) -> None: Z = np.random.randint(255, size=(1, 1, 3), dtype=np.uint8) - img = Image.fromarray(Z).convert("RGB") + img = Image.fromarray(Z).convert('RGB') img.save(path) def create_mask(path: str) -> None: Z = np.random.randint(2, size=(1, 1, 3), dtype=np.uint8) * 255 - img = Image.fromarray(Z).convert("L") + img = Image.fromarray(Z).convert('L') img.save(path) -if __name__ == "__main__": - root = "LEVIR-CD+" - splits = ["train", "test"] - directories = ["A", "B", "label"] +if __name__ == '__main__': + root = 'LEVIR-CD+' + splits = ['train', 'test'] + directories = ['A', 'B', 'label'] if os.path.exists(root): shutil.rmtree(root) @@ -38,24 +38,24 @@ def create_mask(path: str) -> None: os.makedirs(os.path.join(root, split, directory)) for i in range(2): - folder = os.path.join(root, split, "A") - path = os.path.join(folder, f"0{i}.png") + folder = os.path.join(root, split, 'A') + path = os.path.join(folder, f'0{i}.png') create_image(path) - folder = os.path.join(root, split, "B") - path = os.path.join(folder, f"0{i}.png") + folder = os.path.join(root, split, 'B') + path = os.path.join(folder, f'0{i}.png') create_image(path) - folder = os.path.join(root, split, "label") - path = os.path.join(folder, f"0{i}.png") + folder = os.path.join(root, split, 'label') + path = os.path.join(folder, f'0{i}.png') create_mask(path) # Compress data - shutil.make_archive(root, "zip", ".", root) + shutil.make_archive(root, 'zip', '.', root) # compute checksum - with open(f"{root}.zip", "rb") as f: + with open(f'{root}.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{root}.zip: {md5}") + 
print(f'{root}.zip: {md5}') shutil.rmtree(root) diff --git a/tests/data/mapinwild/data.py b/tests/data/mapinwild/data.py index f6d089add93..baf640b5dca 100644 --- a/tests/data/mapinwild/data.py +++ b/tests/data/mapinwild/data.py @@ -18,75 +18,75 @@ np.random.seed(0) meta = { - "driver": "GTiff", - "nodata": None, - "width": SIZE, - "height": SIZE, - "crs": CRS.from_epsg(32720), - "transform": Affine(10.0, 0.0, 612190.0, 0.0, -10.0, 7324250.0), + 'driver': 'GTiff', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'crs': CRS.from_epsg(32720), + 'transform': Affine(10.0, 0.0, 612190.0, 0.0, -10.0, 7324250.0), } count = { - "ESA_WC": 1, - "VIIRS": 1, - "mask": 1, - "s1_part1": 2, - "s1_part2": 2, - "s2_temporal_subset_part1": 10, - "s2_temporal_subset_part2": 10, - "s2_autumn_part1": 10, - "s2_autumn_part2": 10, - "s2_spring_part1": 10, - "s2_spring_part2": 10, - "s2_summer_part1": 10, - "s2_summer_part2": 10, - "s2_winter_part1": 10, - "s2_winter_part2": 10, + 'ESA_WC': 1, + 'VIIRS': 1, + 'mask': 1, + 's1_part1': 2, + 's1_part2': 2, + 's2_temporal_subset_part1': 10, + 's2_temporal_subset_part2': 10, + 's2_autumn_part1': 10, + 's2_autumn_part2': 10, + 's2_spring_part1': 10, + 's2_spring_part2': 10, + 's2_summer_part1': 10, + 's2_summer_part2': 10, + 's2_winter_part1': 10, + 's2_winter_part2': 10, } dtype = { - "ESA_WC": np.uint8, - "VIIRS": np.float32, - "mask": np.byte, - "s1_part1": np.float64, - "s1_part2": np.float64, - "s2_temporal_subset_part1": np.uint16, - "s2_temporal_subset_part2": np.uint16, - "s2_autumn_part1": np.uint16, - "s2_autumn_part2": np.uint16, - "s2_spring_part1": np.uint16, - "s2_spring_part2": np.uint16, - "s2_summer_part1": np.uint16, - "s2_summer_part2": np.uint16, - "s2_winter_part1": np.uint16, - "s2_winter_part2": np.uint16, + 'ESA_WC': np.uint8, + 'VIIRS': np.float32, + 'mask': np.byte, + 's1_part1': np.float64, + 's1_part2': np.float64, + 's2_temporal_subset_part1': np.uint16, + 's2_temporal_subset_part2': np.uint16, + 's2_autumn_part1': np.uint16, + 's2_autumn_part2': np.uint16, + 's2_spring_part1': np.uint16, + 's2_spring_part2': np.uint16, + 's2_summer_part1': np.uint16, + 's2_summer_part2': np.uint16, + 's2_winter_part1': np.uint16, + 's2_winter_part2': np.uint16, } stop = { - "ESA_WC": np.iinfo(np.uint8).max, - "VIIRS": np.finfo(np.float32).max, - "mask": np.iinfo(np.byte).max, - "s1_part1": np.finfo(np.float64).max, - "s1_part2": np.finfo(np.float64).max, - "s2_temporal_subset_part1": np.iinfo(np.uint16).max, - "s2_temporal_subset_part2": np.iinfo(np.uint16).max, - "s2_autumn_part1": np.iinfo(np.uint16).max, - "s2_autumn_part2": np.iinfo(np.uint16).max, - "s2_spring_part1": np.iinfo(np.uint16).max, - "s2_spring_part2": np.iinfo(np.uint16).max, - "s2_summer_part1": np.iinfo(np.uint16).max, - "s2_summer_part2": np.iinfo(np.uint16).max, - "s2_winter_part1": np.iinfo(np.uint16).max, - "s2_winter_part2": np.iinfo(np.uint16).max, + 'ESA_WC': np.iinfo(np.uint8).max, + 'VIIRS': np.finfo(np.float32).max, + 'mask': np.iinfo(np.byte).max, + 's1_part1': np.finfo(np.float64).max, + 's1_part2': np.finfo(np.float64).max, + 's2_temporal_subset_part1': np.iinfo(np.uint16).max, + 's2_temporal_subset_part2': np.iinfo(np.uint16).max, + 's2_autumn_part1': np.iinfo(np.uint16).max, + 's2_autumn_part2': np.iinfo(np.uint16).max, + 's2_spring_part1': np.iinfo(np.uint16).max, + 's2_spring_part2': np.iinfo(np.uint16).max, + 's2_summer_part1': np.iinfo(np.uint16).max, + 's2_summer_part2': np.iinfo(np.uint16).max, + 's2_winter_part1': np.iinfo(np.uint16).max, + 's2_winter_part2': 
np.iinfo(np.uint16).max, } -folder_path = os.path.join(os.getcwd(), "tests", "data", "mapinwild") +folder_path = os.path.join(os.getcwd(), 'tests', 'data', 'mapinwild') dict_all = { - "s2_sum": ["s2_summer_part1", "s2_summer_part2"], - "s2_spr": ["s2_spring_part1", "s2_spring_part2"], - "s2_win": ["s2_winter_part1", "s2_winter_part2"], - "s2_aut": ["s2_autumn_part1", "s2_autumn_part2"], - "s1": ["s1_part1", "s1_part2"], - "s2_temp": ["s2_temporal_subset_part1", "s2_temporal_subset_part2"], + 's2_sum': ['s2_summer_part1', 's2_summer_part2'], + 's2_spr': ['s2_spring_part1', 's2_spring_part2'], + 's2_win': ['s2_winter_part1', 's2_winter_part2'], + 's2_aut': ['s2_autumn_part1', 's2_autumn_part2'], + 's1': ['s1_part1', 's1_part2'], + 's2_temp': ['s2_temporal_subset_part1', 's2_temporal_subset_part2'], } md5s = {} @@ -103,14 +103,14 @@ # Random images for i in range(1, 3): - filename = f"{i}.tif" + filename = f'{i}.tif' filepath = os.path.join(directory, filename) - meta["count"] = count[source] - meta["dtype"] = dtype[source] - with rasterio.open(filepath, "w", **meta) as f: + meta['count'] = count[source] + meta['dtype'] = dtype[source] + with rasterio.open(filepath, 'w', **meta) as f: for j in range(1, count[source] + 1): - if meta["dtype"] is np.float32 or meta["dtype"] is np.float64: + if meta['dtype'] is np.float32 or meta['dtype'] is np.float64: data = np.random.randn(SIZE, SIZE).astype(dtype[source]) else: @@ -145,22 +145,22 @@ root = os.path.dirname(directory) # Compress data - shutil.make_archive(directory, "zip", root_dir=root, base_dir=source) + shutil.make_archive(directory, 'zip', root_dir=root, base_dir=source) # Compute checksums - with open(directory + ".zip", "rb") as f: + with open(directory + '.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{directory}: {md5}") - name = i + ".zip" + print(f'{directory}: {md5}') + name = i + '.zip' md5s[name] = md5 tvt_split = pd.DataFrame( - [["1", "2", "3"], [np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]], - index=["0", "1", "2"], - columns=["train", "validation", "test"], + [['1', '2', '3'], [np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]], + index=['0', '1', '2'], + columns=['train', 'validation', 'test'], ) tvt_split.dropna() -tvt_split.to_csv(os.path.join(folder_path, "split_IDs.csv")) +tvt_split.to_csv(os.path.join(folder_path, 'split_IDs.csv')) -with open(os.path.join(folder_path, "split_IDs.csv"), "rb") as f: +with open(os.path.join(folder_path, 'split_IDs.csv'), 'rb') as f: csv_md5 = hashlib.md5(f.read()).hexdigest() diff --git a/tests/data/millionaid/data.py b/tests/data/millionaid/data.py index 03ea05ad5df..41e024c4532 100755 --- a/tests/data/millionaid/data.py +++ b/tests/data/millionaid/data.py @@ -15,26 +15,26 @@ np.random.seed(0) PATHS = { - "train": [ + 'train': [ os.path.join( - "train", "agriculture_land", "grassland", "meadow", "P0115918.jpg" + 'train', 'agriculture_land', 'grassland', 'meadow', 'P0115918.jpg' ), - os.path.join("train", "water_area", "beach", "P0060208.jpg"), + os.path.join('train', 'water_area', 'beach', 'P0060208.jpg'), ], - "test": [ - os.path.join("test", "agriculture_land", "grassland", "meadow", "P0115918.jpg"), - os.path.join("test", "water_area", "beach", "P0060208.jpg"), + 'test': [ + os.path.join('test', 'agriculture_land', 'grassland', 'meadow', 'P0115918.jpg'), + os.path.join('test', 'water_area', 'beach', 'P0060208.jpg'), ], } def create_file(path: str) -> None: Z = np.random.rand(SIZE, SIZE, 3) * 255 - img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img = 
Image.fromarray(Z.astype('uint8')).convert('RGB') img.save(path) -if __name__ == "__main__": +if __name__ == '__main__': for split, paths in PATHS.items(): # remove old data if os.path.isdir(split): @@ -44,9 +44,9 @@ def create_file(path: str) -> None: create_file(path) # compress data - shutil.make_archive(split, "zip", ".", split) + shutil.make_archive(split, 'zip', '.', split) # Compute checksums - with open(split + ".zip", "rb") as f: + with open(split + '.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{split}: {md5}") + print(f'{split}: {md5}') diff --git a/tests/data/naip/data.py b/tests/data/naip/data.py index cf9359b77c3..002a572e01f 100755 --- a/tests/data/naip/data.py +++ b/tests/data/naip/data.py @@ -42,32 +42,32 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = CRS.from_wkt(wkt) - profile["transform"] = Affine( + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = CRS.from_wkt(wkt) + profile['transform'] = Affine( 1.0, 0.0, 1303555.0000000005, 0.0, -1.0, 2535064.999999998 ) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 - if "float" in profile["dtype"]: - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + if 'float' in profile['dtype']: + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) else: Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) -if __name__ == "__main__": - filenames = ["m_3807511_ne_18_060_20181104.tif", "m_3807511_ne_18_060_20190605.tif"] +if __name__ == '__main__': + filenames = ['m_3807511_ne_18_060_20181104.tif', 'm_3807511_ne_18_060_20190605.tif'] for f in filenames: - create_file(os.path.join(os.getcwd(), f), "uint8", 4) + create_file(os.path.join(os.getcwd(), f), 'uint8', 4) diff --git a/tests/data/nccm/data.py b/tests/data/nccm/data.py index db629664a1d..9dda733f4b3 100644 --- a/tests/data/nccm/data.py +++ b/tests/data/nccm/data.py @@ -14,40 +14,40 @@ SIZE = 128 np.random.seed(0) -files = ["CDL2017_clip.tif", "CDL2018_clip1.tif", "CDL2019_clip.tif"] +files = ['CDL2017_clip.tif', 'CDL2018_clip1.tif', 'CDL2019_clip.tif'] def create_file(path: str, dtype: str): """Create the testing file.""" profile = { - "driver": "GTiff", - "dtype": dtype, - "count": 1, - "crs": CRS.from_epsg(32616), - "transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0), - "height": SIZE, - "width": SIZE, - "compress": "lzw", - "predictor": 2, + 'driver': 'GTiff', + 'dtype': dtype, + 'count': 1, + 'crs': CRS.from_epsg(32616), + 'transform': Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0), + 'height': SIZE, + 'width': SIZE, + 'compress': 'lzw', + 'predictor': 2, } allowed_values = [0, 1, 2, 3, 15] Z = np.random.choice(allowed_values, size=(SIZE, SIZE)) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z, 1) -if __name__ == "__main__": +if __name__ == '__main__': dir = os.path.join(os.getcwd()) 
os.makedirs(dir, exist_ok=True) for file in files: - create_file(os.path.join(dir, file), dtype="int8") + create_file(os.path.join(dir, file), dtype='int8') # Compute checksums for file in files: - with open(file, "rb") as f: + with open(file, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{file}: {md5}") + print(f'{file}: {md5}') diff --git a/tests/data/nlcd/data.py b/tests/data/nlcd/data.py index fa1c592ea29..072b6637500 100755 --- a/tests/data/nlcd/data.py +++ b/tests/data/nlcd/data.py @@ -16,7 +16,7 @@ np.random.seed(0) -dir = "nlcd_{}_land_cover_l48_20210604" +dir = 'nlcd_{}_land_cover_l48_20210604' years = [2011, 2019] @@ -46,26 +46,26 @@ def create_file(path: str, dtype: str): """Create the testing file.""" profile = { - "driver": "GTiff", - "dtype": dtype, - "count": 1, - "crs": CRS.from_wkt(wkt), - "transform": Affine(30.0, 0.0, -2493045.0, 0.0, -30.0, 3310005.0), - "height": SIZE, - "width": SIZE, - "compress": "lzw", - "predictor": 2, + 'driver': 'GTiff', + 'dtype': dtype, + 'count': 1, + 'crs': CRS.from_wkt(wkt), + 'transform': Affine(30.0, 0.0, -2493045.0, 0.0, -30.0, 3310005.0), + 'height': SIZE, + 'width': SIZE, + 'compress': 'lzw', + 'predictor': 2, } allowed_values = [0, 11, 12, 21, 22, 23, 24, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95] Z = np.random.choice(allowed_values, size=(SIZE, SIZE)) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z, 1) -if __name__ == "__main__": +if __name__ == '__main__': for year in years: year_dir = dir.format(year) # Remove old data @@ -74,14 +74,14 @@ def create_file(path: str, dtype: str): os.makedirs(os.path.join(os.getcwd(), year_dir)) - zip_filename = year_dir + ".zip" - filename = year_dir + ".img" - create_file(os.path.join(year_dir, filename), dtype="int8") + zip_filename = year_dir + '.zip' + filename = year_dir + '.img' + create_file(os.path.join(year_dir, filename), dtype='int8') # Compress data - shutil.make_archive(year_dir, "zip", ".", year_dir) + shutil.make_archive(year_dir, 'zip', '.', year_dir) # Compute checksums - with open(zip_filename, "rb") as f: + with open(zip_filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{zip_filename}: {md5}") + print(f'{zip_filename}: {md5}') diff --git a/tests/data/openbuildings/data.py b/tests/data/openbuildings/data.py index 3e363cf611c..8babe48758d 100755 --- a/tests/data/openbuildings/data.py +++ b/tests/data/openbuildings/data.py @@ -17,20 +17,20 @@ def create_meta_data_file(zipfilename): meta_data = { - "type": "FeatureCollection", - "features": [ + 'type': 'FeatureCollection', + 'features': [ { - "type": "Feature", - "geometry": { - "type": "Polygon", - "coordinates": [ + 'type': 'Feature', + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, SIZE], [SIZE, SIZE], [SIZE, 0.0], [0.0, 0.0]] ], }, - "properties": { - "tile_id": "025", - "tile_url": f"polygons_s2_level_4_gzip/{zipfilename}", - "size_mb": 0.2, + 'properties': { + 'tile_id': '025', + 'tile_url': f'polygons_s2_level_4_gzip/{zipfilename}', + 'size_mb': 0.2, }, } ], @@ -48,12 +48,12 @@ def create_csv_data_row(lat, long): polygon = Polygon(coordinates) data_row = { - "latitude": lat, - "longitude": long, - "area_in_meters": 1.0, - "confidence": 1.0, - "geometry": polygon.wkt, - "full_plus_code": "ABC", + 'latitude': lat, + 'longitude': long, + 'area_in_meters': 1.0, + 'confidence': 1.0, + 'geometry': polygon.wkt, + 'full_plus_code': 'ABC', } return data_row @@ -69,32 +69,32 @@ def create_buildings_data(): return 
dict_data -if __name__ == "__main__": - csvname = "000_buildings.csv" - zipfilename = csvname + ".gz" +if __name__ == '__main__': + csvname = '000_buildings.csv' + zipfilename = csvname + '.gz' # create and save metadata meta_data = create_meta_data_file(zipfilename) - with open("tiles.geojson", "w") as fp: + with open('tiles.geojson', 'w') as fp: json.dump(meta_data, fp) # create and archive buildings data buildings_data = create_buildings_data() keys = buildings_data[0].keys() - with open(csvname, "w") as f: + with open(csvname, 'w') as f: w = csv.DictWriter(f, keys) w.writeheader() w.writerows(buildings_data) # archive the csv to gzip - with open(csvname, "rb") as f_in: - with gzip.open(zipfilename, "wb") as f_out: + with open(csvname, 'rb') as f_in: + with gzip.open(zipfilename, 'wb') as f_out: shutil.copyfileobj(f_in, f_out) # Compute checksums - with open(zipfilename, "rb") as f: + with open(zipfilename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{zipfilename}: {md5}") + print(f'{zipfilename}: {md5}') # remove csv file os.remove(csvname) diff --git a/tests/data/oscd/data.py b/tests/data/oscd/data.py index 4a6256ad666..56f9e1d57eb 100755 --- a/tests/data/oscd/data.py +++ b/tests/data/oscd/data.py @@ -15,29 +15,29 @@ np.random.seed(0) directories = [ - "Onera Satellite Change Detection dataset - Images", - "Onera Satellite Change Detection dataset - Train Labels", - "Onera Satellite Change Detection dataset - Test Labels", + 'Onera Satellite Change Detection dataset - Images', + 'Onera Satellite Change Detection dataset - Train Labels', + 'Onera Satellite Change Detection dataset - Test Labels', ] bands = [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B09", - "B10", - "B11", - "B12", - "B8A", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B09', + 'B10', + 'B11', + 'B12', + 'B8A', ] # Remove old data for directory in directories: - filename = f"{directory}.zip" + filename = f'{directory}.zip' if os.path.exists(filename): os.remove(filename) @@ -45,39 +45,39 @@ shutil.rmtree(directory) # Create images -for subdir in ["train1", "train2", "test"]: - for rect in ["imgs_1_rect", "imgs_2_rect"]: +for subdir in ['train1', 'train2', 'test']: + for rect in ['imgs_1_rect', 'imgs_2_rect']: directory = os.path.join(directories[0], subdir, rect) os.makedirs(directory) for band in bands: - filename = os.path.join(directory, f"{band}.tif") + filename = os.path.join(directory, f'{band}.tif') arr = np.random.randint( np.iinfo(np.uint16).max, size=(SIZE, SIZE), dtype=np.uint16 ) img = Image.fromarray(arr) img.save(filename) - filename = os.path.join(directories[0], subdir, "dates.txt") - with open(filename, "w") as f: - for key, value in [("date_1", "20161130"), ("date_2", "20170829")]: - f.write(f"{key}: {value}\n") + filename = os.path.join(directories[0], subdir, 'dates.txt') + with open(filename, 'w') as f: + for key, value in [('date_1', '20161130'), ('date_2', '20170829')]: + f.write(f'{key}: {value}\n') # Create labels -for i, subdir in [(1, "train1"), (1, "train2"), (2, "test")]: - directory = os.path.join(directories[i], subdir, "cm") +for i, subdir in [(1, 'train1'), (1, 'train2'), (2, 'test')]: + directory = os.path.join(directories[i], subdir, 'cm') os.makedirs(directory) - filename = os.path.join(directory, "cm.png") + filename = os.path.join(directory, 'cm.png') arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8) img = Image.fromarray(arr) img.save(filename) for directory in 
directories: # Compress data - shutil.make_archive(directory, "zip", ".", directory) + shutil.make_archive(directory, 'zip', '.', directory) # Compute checksums - filename = f"{directory}.zip" - with open(filename, "rb") as f: + filename = f'{directory}.zip' + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(repr(filename) + ": " + repr(md5) + ",") + print(repr(filename) + ': ' + repr(md5) + ',') diff --git a/tests/data/pastis/data.py b/tests/data/pastis/data.py index 67ea3c9d7ce..22af29114ec 100644 --- a/tests/data/pastis/data.py +++ b/tests/data/pastis/data.py @@ -15,31 +15,31 @@ MAX_NUM_TIME_STEPS = 10 np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] filenames: FILENAME_HIERARCHY = { - "DATA_S2": ["S2"], - "DATA_S1A": ["S1A"], - "DATA_S1D": ["S1D"], - "ANNOTATIONS": ["TARGET"], - "INSTANCE_ANNOTATIONS": ["INSTANCES"], + 'DATA_S2': ['S2'], + 'DATA_S1A': ['S1A'], + 'DATA_S1D': ['S1D'], + 'ANNOTATIONS': ['TARGET'], + 'INSTANCE_ANNOTATIONS': ['INSTANCES'], } def create_file(path: str) -> None: for i in range(NUM_SAMPLES): - new_path = f"{path}_{i}.npy" + new_path = f'{path}_{i}.npy' fn = os.path.basename(new_path) t = np.random.randint(1, MAX_NUM_TIME_STEPS) - if fn.startswith("S2"): + if fn.startswith('S2'): data = np.random.randint(0, 256, size=(t, 10, SIZE, SIZE)).astype(np.int16) - elif fn.startswith("S1A"): + elif fn.startswith('S1A'): data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16) - elif fn.startswith("S1D"): + elif fn.startswith('S1D'): data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16) - elif fn.startswith("TARGET"): + elif fn.startswith('TARGET'): data = np.random.randint(0, 20, size=(3, SIZE, SIZE)).astype(np.uint8) - elif fn.startswith("INSTANCES"): + elif fn.startswith('INSTANCES'): data = np.random.randint(0, 100, size=(SIZE, SIZE)).astype(np.int64) np.save(new_path, data) @@ -58,33 +58,33 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: create_file(path) -if __name__ == "__main__": - create_directory("PASTIS-R", filenames) +if __name__ == '__main__': + create_directory('PASTIS-R', filenames) - schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}} + schema = {'geometry': 'Polygon', 'properties': {'Fold': 'int', 'ID_PATCH': 'int'}} with fiona.open( - os.path.join("PASTIS-R", "metadata.geojson"), - "w", - "GeoJSON", - crs="EPSG:4326", + os.path.join('PASTIS-R', 'metadata.geojson'), + 'w', + 'GeoJSON', + crs='EPSG:4326', schema=schema, ) as f: for i in range(NUM_SAMPLES): f.write( { - "geometry": { - "type": "Polygon", - "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]], + 'geometry': { + 'type': 'Polygon', + 'coordinates': [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]], }, - "id": str(i), - "properties": {"Fold": (i % 5) + 1, "ID_PATCH": i}, + 'id': str(i), + 'properties': {'Fold': (i % 5) + 1, 'ID_PATCH': i}, } ) - filename = "PASTIS-R.zip" - shutil.make_archive(filename.replace(".zip", ""), "zip", ".", "PASTIS-R") + filename = 'PASTIS-R.zip' + shutil.make_archive(filename.replace('.zip', ''), 'zip', '.', 'PASTIS-R') # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/prisma/data.py b/tests/data/prisma/data.py index aa6255f521d..9ec8238c194 100755 --- 
a/tests/data/prisma/data.py +++ b/tests/data/prisma/data.py @@ -14,31 +14,31 @@ files = [ - "PRS_L0S_EO_NRT_20191215092453_20191215092457_0001.tif", - "PRS_L1_STD_OFFL_20191215092453_20191215092457_0002.tif", - "PRS_L2D_STD_20191215092453_20191215092457_0003.tif", - "PRS_CF_AX_FDP_REPR_20191215092453_20191215092457_0004_0.tif", + 'PRS_L0S_EO_NRT_20191215092453_20191215092457_0001.tif', + 'PRS_L1_STD_OFFL_20191215092453_20191215092457_0002.tif', + 'PRS_L2D_STD_20191215092453_20191215092457_0003.tif', + 'PRS_CF_AX_FDP_REPR_20191215092453_20191215092457_0004_0.tif', ] for file in files: res = 10 profile = { - "driver": "GTiff", - "dtype": "uint16", - "count": 239, - "crs": CRS.from_epsg(32634), - "transform": Affine( + 'driver': 'GTiff', + 'dtype': 'uint16', + 'count': 239, + 'crs': CRS.from_epsg(32634), + 'transform': Affine( 29.974884647651006, 0.0, 718687.5, 0.0, -29.97457627118644, 4503407.5 ), - "height": SIZE, - "width": SIZE, + 'height': SIZE, + 'width': SIZE, } Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(file, "w", **profile) as src: - for i in range(profile["count"]): + with rasterio.open(file, 'w', **profile) as src: + for i in range(profile['count']): src.write(Z, i + 1) diff --git a/tests/data/raster/data.py b/tests/data/raster/data.py index 304c9f96f16..c11722f7e42 100755 --- a/tests/data/raster/data.py +++ b/tests/data/raster/data.py @@ -16,7 +16,7 @@ def write_raster( res: int = RES[0], epsg: int = EPSG[0], - dtype: str = "uint8", + dtype: str = 'uint8', path: str | None = None, ) -> None: """Write a raster file. @@ -29,24 +29,24 @@ def write_raster( """ size = SIZE // res profile = { - "driver": "GTiff", - "dtype": dtype, - "count": 1, - "crs": f"epsg:{epsg}", - "transform": from_bounds(0, 0, SIZE, SIZE, size, size), - "height": size, - "width": size, - "nodata": 0, + 'driver': 'GTiff', + 'dtype': dtype, + 'count': 1, + 'crs': f'epsg:{epsg}', + 'transform': from_bounds(0, 0, SIZE, SIZE, size, size), + 'height': size, + 'width': size, + 'nodata': 0, } if path is None: - name = f"res_{res}_epsg_{epsg}" - path = os.path.join(name, f"{name}.tif") + name = f'res_{res}_epsg_{epsg}' + path = os.path.join(name, f'{name}.tif') directory = os.path.dirname(path) os.makedirs(directory, exist_ok=True) - with rio.open(path, "w", **profile) as f: + with rio.open(path, 'w', **profile) as f: x = np.ones((1, size, size)) f.write(x) @@ -59,21 +59,21 @@ def reproject_raster(res: int, src_epsg: int, dst_epsg: int) -> None: src_epsg: EPSG of source file. dst_epsg: EPSG of destination file. 
""" - src_name = f"res_{res}_epsg_{src_epsg}" - src_path = os.path.join(src_name, f"{src_name}.tif") + src_name = f'res_{res}_epsg_{src_epsg}' + src_path = os.path.join(src_name, f'{src_name}.tif') with rio.open(src_path) as src: - dst_crs = f"epsg:{dst_epsg}" + dst_crs = f'epsg:{dst_epsg}' transform, width, height = calculate_default_transform( src.crs, dst_crs, src.width, src.height, *src.bounds ) profile = src.profile.copy() profile.update( - {"crs": dst_crs, "transform": transform, "width": width, "height": height} + {'crs': dst_crs, 'transform': transform, 'width': width, 'height': height} ) - dst_name = f"res_{res}_epsg_{dst_epsg}" + dst_name = f'res_{res}_epsg_{dst_epsg}' os.makedirs(dst_name, exist_ok=True) - dst_path = os.path.join(dst_name, f"{dst_name}.tif") - with rio.open(dst_path, "w", **profile) as dst: + dst_path = os.path.join(dst_name, f'{dst_name}.tif') + with rio.open(dst_path, 'w', **profile) as dst: reproject( source=rio.band(src, 1), destination=rio.band(dst, 1), @@ -84,7 +84,7 @@ def reproject_raster(res: int, src_epsg: int, dst_epsg: int) -> None: ) -if __name__ == "__main__": +if __name__ == '__main__': for res in RES: src_epsg = EPSG[0] write_raster(res, src_epsg) @@ -92,8 +92,8 @@ def reproject_raster(res: int, src_epsg: int, dst_epsg: int) -> None: for dst_epsg in EPSG[1:]: reproject_raster(res, src_epsg, dst_epsg) - for dtype in ["uint16", "uint32"]: - path = os.path.join(dtype, f"{dtype}.tif") + for dtype in ['uint16', 'uint32']: + path = os.path.join(dtype, f'{dtype}.tif') write_raster(dtype=dtype, path=path) - with open(os.path.join(dtype, "corrupted.tif"), "w") as f: - f.write("not a tif file\n") + with open(os.path.join(dtype, 'corrupted.tif'), 'w') as f: + f.write('not a tif file\n') diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/data.py b/tests/data/ref_cloud_cover_detection_challenge_v1/data.py index a66f22b76cb..e8a771e0fa5 100755 --- a/tests/data/ref_cloud_cover_detection_challenge_v1/data.py +++ b/tests/data/ref_cloud_cover_detection_challenge_v1/data.py @@ -32,17 +32,17 @@ np.random.seed(0) SIZE = 512 -BANDS = ["B02", "B03", "B04", "B08"] +BANDS = ['B02', 'B03', 'B04', 'B08'] -SOURCE_COLLECTION_ID = "ref_cloud_cover_detection_challenge_v1_test_source" -SOURCE_ITEM_ID = "ref_cloud_cover_detection_challenge_v1_test_source_aaaa" -LABEL_COLLECTION_ID = "ref_cloud_cover_detection_challenge_v1_test_labels" -LABEL_ITEM_ID = "ref_cloud_cover_detection_challenge_v1_test_labels_aaaa" +SOURCE_COLLECTION_ID = 'ref_cloud_cover_detection_challenge_v1_test_source' +SOURCE_ITEM_ID = 'ref_cloud_cover_detection_challenge_v1_test_source_aaaa' +LABEL_COLLECTION_ID = 'ref_cloud_cover_detection_challenge_v1_test_labels' +LABEL_ITEM_ID = 'ref_cloud_cover_detection_challenge_v1_test_labels_aaaa' # geometry used by both source and label items TEST_GEOMETRY = { - "type": "Polygon", - "coordinates": [ + 'type': 'Polygon', + 'coordinates': [ [ [137.86580132892396, -29.52744848758255], [137.86450090473795, -29.481297003404038], @@ -63,14 +63,14 @@ # sentinel-2 bands for EO extension S2_BANDS = [ - Band.create(name="B02", common_name="blue", description="Blue"), - Band.create(name="B03", common_name="green", description="Green"), - Band.create(name="B04", common_name="red", description="Red"), - Band.create(name="B08", common_name="nir", description="NIR"), + Band.create(name='B02', common_name='blue', description='Blue'), + Band.create(name='B03', common_name='green', description='Green'), + Band.create(name='B04', common_name='red', description='Red'), + 
Band.create(name='B08', common_name='nir', description='NIR'), ] # class map for overviews -CLASS_COUNT_MAP = {"0": "no cloud", "1": "cloud"} +CLASS_COUNT_MAP = {'0': 'no cloud', '1': 'cloud'} # define the spatial and temporal extent of collections TEST_EXTENT = Extent( @@ -87,8 +87,8 @@ temporal=TemporalExtent( intervals=[ [ - dt.strptime("2018-02-18", "%Y-%m-%d"), - dt.strptime("2020-09-13", "%Y-%m-%d"), + dt.strptime('2018-02-18', '%Y-%m-%d'), + dt.strptime('2020-09-13', '%Y-%m-%d'), ] ] ), @@ -100,30 +100,30 @@ def create_raster(path: str, dtype: str, num_channels: int, collection: str) -> Path(os.path.split(path)[0]).mkdir(parents=True) profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = CRS.from_epsg(32753) - profile["transform"] = Affine(1.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 - - if collection == "source": - if "float" in profile["dtype"]: - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = CRS.from_epsg(32753) + profile['transform'] = Affine(1.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0) + profile['height'] = SIZE + profile['width'] = SIZE + profile['compress'] = 'lzw' + profile['predictor'] = 2 + + if collection == 'source': + if 'float' in profile['dtype']: + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) else: Z = np.random.randint( - np.iinfo(profile["dtype"]).max, + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), - dtype=profile["dtype"], + dtype=profile['dtype'], ) - elif collection == "labels": - Z = np.random.randint(0, 2, (SIZE, SIZE)).astype(profile["dtype"]) + elif collection == 'labels': + Z = np.random.randint(0, 2, (SIZE, SIZE)).astype(profile['dtype']) - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) @@ -133,14 +133,14 @@ def create_source_item() -> Item: id=SOURCE_ITEM_ID, geometry=TEST_GEOMETRY, bbox=TEST_BBOX, - datetime=dt.strptime("2020-06-03", "%Y-%m-%d"), + datetime=dt.strptime('2020-06-03', '%Y-%m-%d'), properties={}, ) # add Asset with EO Extension for each S2 band for band in BANDS: img_path = os.path.join( - os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f"{band}.tif" + os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f'{band}.tif' ) image_asset = Asset(href=img_path, media_type=MediaType.GEOTIFF) eo_asset_ext = EOExtension.ext(image_asset) @@ -157,8 +157,8 @@ def create_source_item() -> Item: def get_class_label_list(overview: LabelOverview) -> LabelClasses: - label_list = [d["name"] for d in overview.properties["counts"]] - label_classes = LabelClasses.create(classes=label_list, name="labels") + label_list = [d['name'] for d in overview.properties['counts']] + label_classes = LabelClasses.create(classes=label_list, name='labels') return label_classes @@ -189,7 +189,7 @@ def get_item_class_overview(label_type: LabelType, asset_path: str) -> LabelOver count_list.append(label_count) overview = LabelOverview(properties={}) - overview.apply(property_key="labels", counts=count_list) + overview.apply(property_key='labels', counts=count_list) return overview @@ -200,7 +200,7 @@ def create_label_item() -> Item: id=LABEL_ITEM_ID, geometry=TEST_GEOMETRY, bbox=TEST_BBOX, - 
datetime=dt.strptime("2020-06-03", "%Y-%m-%d"), + datetime=dt.strptime('2020-06-03', '%Y-%m-%d'), properties={}, ) @@ -209,39 +209,39 @@ def create_label_item() -> Item: label_ext = LabelExtension.ext(test_label_item, add_if_missing=True) label_ext.apply( - label_description="Sentinel-2 Cloud Cover Segmentation Test Labels", + label_description='Sentinel-2 Cloud Cover Segmentation Test Labels', label_type=LabelType.RASTER, label_classes=[label_list], label_overviews=[label_overview], ) label_asset = Asset(href=label_path, media_type=MediaType.GEOTIFF) - test_label_item.add_asset(key="labels", asset=label_asset) + test_label_item.add_asset(key='labels', asset=label_asset) return test_label_item -if __name__ == "__main__": +if __name__ == '__main__': # create a geotiff for each s2 band for b in BANDS: tif_path = os.path.join( - os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f"{b}.tif" + os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f'{b}.tif' ) - create_raster(tif_path, "uint8", 1, "source") + create_raster(tif_path, 'uint8', 1, 'source') # create a geotiff for label label_path = os.path.join( - os.getcwd(), LABEL_COLLECTION_ID, LABEL_ITEM_ID, "labels.tif" + os.getcwd(), LABEL_COLLECTION_ID, LABEL_ITEM_ID, 'labels.tif' ) - create_raster(label_path, "uint8", 1, "labels") + create_raster(label_path, 'uint8', 1, 'labels') # instantiate the source Collection test_source_collection = Collection( id=SOURCE_COLLECTION_ID, - description="Test Source Collection for Torchgo Cloud Cover Detection Dataset", + description='Test Source Collection for Torchgo Cloud Cover Detection Dataset', extent=TEST_EXTENT, catalog_type=CatalogType.RELATIVE_PUBLISHED, - license="CC-BY-4.0", + license='CC-BY-4.0', ) source_item = create_source_item() @@ -256,15 +256,15 @@ def create_label_item() -> Item: # instantiate the label Collection test_label_collection = Collection( id=LABEL_COLLECTION_ID, - description="Test Label Collection for Torchgo Cloud Cover Detection Dataset", + description='Test Label Collection for Torchgo Cloud Cover Detection Dataset', extent=TEST_EXTENT, catalog_type=CatalogType.RELATIVE_PUBLISHED, - license="CC-BY-4.0", + license='CC-BY-4.0', ) label_item = create_label_item() label_item.add_link( - Link(rel="source", target=source_item, media_type=MediaType.GEOTIFF) + Link(rel='source', target=source_item, media_type=MediaType.GEOTIFF) ) test_label_collection.add_item(label_item) diff --git a/tests/data/reforestree/data.py b/tests/data/reforestree/data.py index ee5515e49b1..6bb6c5ce2f5 100755 --- a/tests/data/reforestree/data.py +++ b/tests/data/reforestree/data.py @@ -16,26 +16,26 @@ np.random.seed(0) PATHS = { - "images": [ - "tiles/Site1/Site1_RGB_0_0_0_4000_4000.png", - "tiles/Site2/Site2_RGB_0_0_0_4000_4000.png", + 'images': [ + 'tiles/Site1/Site1_RGB_0_0_0_4000_4000.png', + 'tiles/Site2/Site2_RGB_0_0_0_4000_4000.png', ], - "annotation": "mapping/final_dataset.csv", + 'annotation': 'mapping/final_dataset.csv', } def create_annotation(path: str, img_paths: list[str]) -> None: - cols = ["img_path", "xmin", "ymin", "xmax", "ymax", "group", "AGB"] + cols = ['img_path', 'xmin', 'ymin', 'xmax', 'ymax', 'group', 'AGB'] data = [] for img_path in img_paths: data.append( - [os.path.basename(img_path), 0, 0, SIZE / 2, SIZE / 2, "banana", 6.75] + [os.path.basename(img_path), 0, 0, SIZE / 2, SIZE / 2, 'banana', 6.75] ) data.append( - [os.path.basename(img_path), SIZE / 2, SIZE / 2, SIZE, SIZE, "cacao", 6.75] + [os.path.basename(img_path), SIZE / 2, SIZE / 2, SIZE, SIZE, 'cacao', 6.75] ) - with 
open(path, "w", newline="") as f: + with open(path, 'w', newline='') as f: writer = csv.writer(f) writer.writerow(cols) writer.writerows(data) @@ -43,32 +43,32 @@ def create_annotation(path: str, img_paths: list[str]) -> None: def create_img(path: str) -> None: Z = np.random.rand(SIZE, SIZE, 3) * 255 - img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img = Image.fromarray(Z.astype('uint8')).convert('RGB') img.save(path) -if __name__ == "__main__": - data_root = "reforesTree" +if __name__ == '__main__': + data_root = 'reforesTree' # remove old data if os.path.isdir(data_root): shutil.rmtree(data_root) # create imagery - for path in PATHS["images"]: + for path in PATHS['images']: os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) create_img(os.path.join(data_root, path)) # create annotations os.makedirs( - os.path.join(data_root, os.path.dirname(PATHS["annotation"])), exist_ok=True + os.path.join(data_root, os.path.dirname(PATHS['annotation'])), exist_ok=True ) - create_annotation(os.path.join(data_root, PATHS["annotation"]), PATHS["images"]) + create_annotation(os.path.join(data_root, PATHS['annotation']), PATHS['images']) # compress data - shutil.make_archive(data_root, "zip", data_root) + shutil.make_archive(data_root, 'zip', data_root) # Compute checksums - with open(data_root + ".zip", "rb") as f: + with open(data_root + '.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{data_root}: {md5}") + print(f'{data_root}: {md5}') diff --git a/tests/data/rwanda_field_boundary/data.py b/tests/data/rwanda_field_boundary/data.py index 7a23b385cf6..a3522e8c962 100644 --- a/tests/data/rwanda_field_boundary/data.py +++ b/tests/data/rwanda_field_boundary/data.py @@ -10,8 +10,8 @@ import numpy as np import rasterio -dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12") -all_bands = ("B01", "B02", "B03", "B04") +dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') +all_bands = ('B01', 'B02', 'B03', 'B04') SIZE = 32 NUM_SAMPLES = 5 @@ -20,82 +20,82 @@ def create_mask(fn: str) -> None: profile = { - "driver": "GTiff", - "dtype": "uint8", - "nodata": 0.0, - "width": SIZE, - "height": SIZE, - "count": 1, - "crs": "epsg:3857", - "compress": "lzw", - "predictor": 2, - "transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - "blockysize": 32, - "tiled": False, - "interleave": "band", + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': 0.0, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': 'epsg:3857', + 'compress': 'lzw', + 'predictor': 2, + 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), + 'blockysize': 32, + 'tiled': False, + 'interleave': 'band', } - with rasterio.open(fn, "w", **profile) as f: + with rasterio.open(fn, 'w', **profile) as f: f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1) def create_img(fn: str) -> None: profile = { - "driver": "GTiff", - "dtype": "uint16", - "nodata": 0.0, - "width": SIZE, - "height": SIZE, - "count": 1, - "crs": "epsg:3857", - "compress": "lzw", - "predictor": 2, - "blockysize": 16, - "transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - "tiled": False, - "interleave": "band", + 'driver': 'GTiff', + 'dtype': 'uint16', + 'nodata': 0.0, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': 'epsg:3857', + 'compress': 'lzw', + 'predictor': 2, + 'blockysize': 16, + 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), + 'tiled': False, + 'interleave': 'band', } - with rasterio.open(fn, "w", **profile) as f: 
+ with rasterio.open(fn, 'w', **profile) as f: f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1) -if __name__ == "__main__": +if __name__ == '__main__': # Train and test images - for split in ("train", "test"): + for split in ('train', 'test'): for i in range(NUM_SAMPLES): for date in dates: directory = os.path.join( - f"nasa_rwanda_field_boundary_competition_source_{split}", - f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}", # noqa: E501 + f'nasa_rwanda_field_boundary_competition_source_{split}', + f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 ) os.makedirs(directory, exist_ok=True) for band in all_bands: - create_img(os.path.join(directory, f"{band}.tif")) + create_img(os.path.join(directory, f'{band}.tif')) # Create collections.json, this isn't used by the dataset but is checked to # exist with open( - f"nasa_rwanda_field_boundary_competition_source_{split}/collections.json", - "w", + f'nasa_rwanda_field_boundary_competition_source_{split}/collections.json', + 'w', ) as f: - f.write("Not used") + f.write('Not used') # Train labels for i in range(NUM_SAMPLES): directory = os.path.join( - "nasa_rwanda_field_boundary_competition_labels_train", - f"nasa_rwanda_field_boundary_competition_labels_train_{i:02d}", + 'nasa_rwanda_field_boundary_competition_labels_train', + f'nasa_rwanda_field_boundary_competition_labels_train_{i:02d}', ) os.makedirs(directory, exist_ok=True) - create_mask(os.path.join(directory, "raster_labels.tif")) + create_mask(os.path.join(directory, 'raster_labels.tif')) # Create directories and compute checksums for filename in [ - "nasa_rwanda_field_boundary_competition_source_train", - "nasa_rwanda_field_boundary_competition_source_test", - "nasa_rwanda_field_boundary_competition_labels_train", + 'nasa_rwanda_field_boundary_competition_source_train', + 'nasa_rwanda_field_boundary_competition_source_test', + 'nasa_rwanda_field_boundary_competition_labels_train', ]: - shutil.make_archive(filename, "gztar", ".", filename) + shutil.make_archive(filename, 'gztar', '.', filename) # Compute checksums - with open(f"{filename}.tar.gz", "rb") as f: + with open(f'{filename}.tar.gz', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py index 6befc1fd4ca..68fa8ffe397 100644 --- a/tests/data/seasonet/data.py +++ b/tests/data/seasonet/data.py @@ -15,40 +15,40 @@ np.random.seed(0) meta = { - "driver": "GTiff", - "nodata": None, - "crs": CRS.from_epsg(32632), - "transform": Affine(10.0, 0.0, 664800.0, 0.0, -10.0, 5342400.0), - "compress": "zstd", + 'driver': 'GTiff', + 'nodata': None, + 'crs': CRS.from_epsg(32632), + 'transform': Affine(10.0, 0.0, 664800.0, 0.0, -10.0, 5342400.0), + 'compress': 'zstd', } -bands = ["10m_RGB", "10m_IR", "20m", "60m", "labels"] -count = {"10m_RGB": 3, "10m_IR": 1, "20m": 6, "60m": 2, "labels": 1} +bands = ['10m_RGB', '10m_IR', '20m', '60m', 'labels'] +count = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2, 'labels': 1} dtype = { - "10m_RGB": np.uint16, - "10m_IR": np.uint16, - "20m": np.uint16, - "60m": np.uint16, - "labels": np.uint8, + '10m_RGB': np.uint16, + '10m_IR': np.uint16, + '20m': np.uint16, + '60m': np.uint16, + 'labels': np.uint8, } -size = {"10m_RGB": 120, "10m_IR": 120, "20m": 60, "60m": 20, "labels": 120} -start = {"10m_RGB": 0, "10m_IR": 0, "20m": 0, "60m": 0, "labels": 1} +size = {'10m_RGB': 120, '10m_IR': 120, '20m': 60, '60m': 
20, 'labels': 120} +start = {'10m_RGB': 0, '10m_IR': 0, '20m': 0, '60m': 0, 'labels': 1} stop = { - "10m_RGB": np.iinfo(np.uint16).max, - "10m_IR": np.iinfo(np.uint16).max, - "20m": np.iinfo(np.uint16).max, - "60m": np.iinfo(np.uint16).max, - "labels": 34, + '10m_RGB': np.iinfo(np.uint16).max, + '10m_IR': np.iinfo(np.uint16).max, + '20m': np.iinfo(np.uint16).max, + '60m': np.iinfo(np.uint16).max, + 'labels': 34, } meta_lines = [ - "Index,Season,Grid,Latitude,Longitude,Satellite,Year,Month,Day," - "Hour,Minute,Second,Clouds,Snow,Classes,SLRAUM,RTYP3,KTYP4,Path\n" + 'Index,Season,Grid,Latitude,Longitude,Satellite,Year,Month,Day,' + 'Hour,Minute,Second,Clouds,Snow,Classes,SLRAUM,RTYP3,KTYP4,Path\n' ] -seasons = ["spring", "summer", "fall", "winter", "snow"] +seasons = ['spring', 'summer', 'fall', 'winter', 'snow'] grids = [1, 2] name_comps = [ - ["32UME", "2018", "04", "18", "T", "10", "40", "21", "53", "928425", "7", "503876"], - ["32TMT", "2019", "02", "14", "T", "10", "31", "29", "47", "793488", "7", "808487"], + ['32UME', '2018', '04', '18', 'T', '10', '40', '21', '53', '928425', '7', '503876'], + ['32TMT', '2019', '02', '14', 'T', '10', '31', '29', '47', '793488', '7', '808487'], ] index = 0 for season in seasons: @@ -56,7 +56,7 @@ if os.path.exists(season): shutil.rmtree(season) - archive = f"{season}.zip" + archive = f'{season}.zip' # Remove old data if os.path.exists(archive): @@ -64,16 +64,16 @@ for grid, comp in zip(grids, name_comps): file_name = f"{comp[0]}_{''.join(comp[1:8])}_{'_'.join(comp[8:])}" - dir = os.path.join(season, f"grid{grid}", file_name) + dir = os.path.join(season, f'grid{grid}', file_name) os.makedirs(dir) # Random images for band in bands: - meta["count"] = count[band] - meta["dtype"] = dtype[band] - meta["width"] = meta["height"] = size[band] + meta['count'] = count[band] + meta['dtype'] = dtype[band] + meta['width'] = meta['height'] = size[band] with rasterio.open( - os.path.join(dir, f"{file_name}_{band}.tif"), "w", **meta + os.path.join(dir, f'{file_name}_{band}.tif'), 'w', **meta ) as f: for j in range(1, count[band] + 1): data = np.random.randint( @@ -86,9 +86,9 @@ index, season.capitalize(), grid, - f"{comp[8]}.{comp[9]}", - f"{comp[10]}.{comp[11]}", - "A", + f'{comp[8]}.{comp[9]}', + f'{comp[10]}.{comp[11]}', + 'A', comp[1], comp[2], comp[3], @@ -103,39 +103,39 @@ 1, dir, ] - meta_lines.append(",".join(map(str, meta_entries)) + "\n") + meta_lines.append(','.join(map(str, meta_entries)) + '\n') index += 1 # Create archives - shutil.make_archive(season, "zip", ".", season) + shutil.make_archive(season, 'zip', '.', season) # Compute checksums - with open(archive, "rb") as f: + with open(archive, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{season}: {repr(md5)}") + print(f'{season}: {repr(md5)}') # Write meta.csv -with open("meta.csv", "w") as f: +with open('meta.csv', 'w') as f: f.writelines(meta_lines) # Compute checksums -with open("meta.csv", "rb") as f: +with open('meta.csv', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"meta.csv: {repr(md5)}") + print(f'meta.csv: {repr(md5)}') -os.makedirs("splits", exist_ok=True) +os.makedirs('splits', exist_ok=True) -for split in ["train", "val", "test"]: - filename = f"{split}.csv" +for split in ['train', 'val', 'test']: + filename = f'{split}.csv' # Create file list - with open(os.path.join("splits", filename), "w") as f: + with open(os.path.join('splits', filename), 'w') as f: for i in range(index): - f.write(str(i) + "\n") + f.write(str(i) + '\n') -shutil.make_archive("splits", 
"zip", ".", "splits") +shutil.make_archive('splits', 'zip', '.', 'splits') # Compute checksums -with open("splits.zip", "rb") as f: +with open('splits.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"splits: {repr(md5)}") + print(f'splits: {repr(md5)}') diff --git a/tests/data/sen12ms/data.py b/tests/data/sen12ms/data.py index e3fd63ccd9d..6fb899b058d 100755 --- a/tests/data/sen12ms/data.py +++ b/tests/data/sen12ms/data.py @@ -17,66 +17,66 @@ np.random.seed(0) meta = { - "driver": "GTiff", - "nodata": None, - "width": SIZE, - "height": SIZE, - "crs": CRS.from_epsg(32737), - "transform": Affine(10.0, 0.0, 390772.3389928384, 0.0, -10.0, 8114182.836060452), + 'driver': 'GTiff', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'crs': CRS.from_epsg(32737), + 'transform': Affine(10.0, 0.0, 390772.3389928384, 0.0, -10.0, 8114182.836060452), } -count = {"lc": 4, "s1": 2, "s2": 13} -dtype = {"lc": np.uint16, "s1": np.float32, "s2": np.uint16} -stop = {"lc": 11, "s1": np.iinfo(np.uint16).max, "s2": np.iinfo(np.uint16).max} +count = {'lc': 4, 's1': 2, 's2': 13} +dtype = {'lc': np.uint16, 's1': np.float32, 's2': np.uint16} +stop = {'lc': 11, 's1': np.iinfo(np.uint16).max, 's2': np.iinfo(np.uint16).max} file_list = [] -seasons = ["ROIs1158_spring", "ROIs1868_summer", "ROIs1970_fall", "ROIs2017_winter"] +seasons = ['ROIs1158_spring', 'ROIs1868_summer', 'ROIs1970_fall', 'ROIs2017_winter'] for season in seasons: # Remove old data if os.path.exists(season): shutil.rmtree(season) - for source in ["lc", "s1", "s2"]: - tarball = f"{season}_{source}.tar.gz" + for source in ['lc', 's1', 's2']: + tarball = f'{season}_{source}.tar.gz' # Remove old data if os.path.exists(tarball): os.remove(tarball) - directory = os.path.join(season, f"{source}_1") + directory = os.path.join(season, f'{source}_1') os.makedirs(directory) # Random images for i in range(1, 3): - filename = f"{season}_{source}_1_p{i}.tif" - meta["count"] = count[source] - meta["dtype"] = dtype[source] - with rasterio.open(os.path.join(directory, filename), "w", **meta) as f: + filename = f'{season}_{source}_1_p{i}.tif' + meta['count'] = count[source] + meta['dtype'] = dtype[source] + with rasterio.open(os.path.join(directory, filename), 'w', **meta) as f: for j in range(1, count[source] + 1): data = np.random.randint(stop[source], size=(SIZE, SIZE)).astype( dtype[source] ) f.write(data, j) - if source == "s2": + if source == 's2': file_list.append(filename) # Create tarballs - shutil.make_archive(f"{season}_{source}", "gztar", ".", directory) + shutil.make_archive(f'{season}_{source}', 'gztar', '.', directory) # Compute checksums - with open(tarball, "rb") as f: + with open(tarball, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(repr(md5) + ",") + print(repr(md5) + ',') -for split in ["train", "test"]: - filename = f"{split}_list.txt" +for split in ['train', 'test']: + filename = f'{split}_list.txt' # Create file list - with open(filename, "w") as f: + with open(filename, 'w') as f: for fname in file_list: - f.write(f"{fname}\n") + f.write(f'{fname}\n') # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(repr(md5) + ",") + print(repr(md5) + ',') diff --git a/tests/data/sentinel1/data.py b/tests/data/sentinel1/data.py index edfa589a2dc..e7ff885cbfc 100755 --- a/tests/data/sentinel1/data.py +++ b/tests/data/sentinel1/data.py @@ -14,34 +14,34 @@ np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] 
+FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] filenames: FILENAME_HIERARCHY = { # ASF DAAC - "S1A_IW_20221204T161641_DVR_RTC30_G_gpuned_1AE1": [ - "S1A_IW_20221204T161641_DVR_RTC30_G_gpuned_1AE1_VH.tif", - "S1A_IW_20221204T161641_DVR_RTC30_G_gpuned_1AE1_VV.tif", + 'S1A_IW_20221204T161641_DVR_RTC30_G_gpuned_1AE1': [ + 'S1A_IW_20221204T161641_DVR_RTC30_G_gpuned_1AE1_VH.tif', + 'S1A_IW_20221204T161641_DVR_RTC30_G_gpuned_1AE1_VV.tif', ], - "S1B_IW_20161021T042948_DHP_RTC30_G_gpuned_A784": [ - "S1B_IW_20161021T042948_DHP_RTC30_G_gpuned_A784_HH.tif", - "S1B_IW_20161021T042948_DHP_RTC30_G_gpuned_A784_HV.tif", + 'S1B_IW_20161021T042948_DHP_RTC30_G_gpuned_A784': [ + 'S1B_IW_20161021T042948_DHP_RTC30_G_gpuned_A784_HH.tif', + 'S1B_IW_20161021T042948_DHP_RTC30_G_gpuned_A784_HV.tif', ], } def create_file(path: str) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = "float32" - profile["count"] = 1 - profile["crs"] = CRS.from_epsg(32605) - profile["transform"] = Affine(30.0, 0.0, 79860.0, 0.0, -30.0, 2298240.0) - profile["height"] = SIZE - profile["width"] = SIZE + profile['driver'] = 'GTiff' + profile['dtype'] = 'float32' + profile['count'] = 1 + profile['crs'] = CRS.from_epsg(32605) + profile['transform'] = Affine(30.0, 0.0, 79860.0, 0.0, -30.0, 2298240.0) + profile['height'] = SIZE + profile['width'] = SIZE - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z, 1) @@ -59,5 +59,5 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: create_file(path) -if __name__ == "__main__": - create_directory(".", filenames) +if __name__ == '__main__': + create_directory('.', filenames) diff --git a/tests/data/sentinel2/data.py b/tests/data/sentinel2/data.py index e31628a9998..426195a4869 100755 --- a/tests/data/sentinel2/data.py +++ b/tests/data/sentinel2/data.py @@ -14,128 +14,128 @@ np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] filenames: FILENAME_HIERARCHY = { # USGS Earth Explorer - "S2A_MSIL1C_20220412T162841_N0400_R083_T16TFM_20220412T202300.SAFE": { - "GRANULE": { - "L1C_T16TFM_A035544_20220412T163959": { - "IMG_DATA": [ - "T16TFM_20220412T162841_B01.jp2", - "T16TFM_20220412T162841_B02.jp2", - "T16TFM_20220412T162841_B03.jp2", - "T16TFM_20220412T162841_B04.jp2", - "T16TFM_20220412T162841_B05.jp2", - "T16TFM_20220412T162841_B06.jp2", - "T16TFM_20220412T162841_B07.jp2", - "T16TFM_20220412T162841_B08.jp2", - "T16TFM_20220412T162841_B09.jp2", - "T16TFM_20220412T162841_B10.jp2", - "T16TFM_20220412T162841_B11.jp2", - "T16TFM_20220412T162841_B12.jp2", - "T16TFM_20220412T162841_B8A.jp2", - "T16TFM_20220412T162841_TCI.jp2", - "T16TFM_20190412T162841_B01.jp2", - "T16TFM_20190412T162841_B02.jp2", - "T16TFM_20190412T162841_B03.jp2", - "T16TFM_20190412T162841_B04.jp2", - "T16TFM_20190412T162841_B05.jp2", - "T16TFM_20190412T162841_B06.jp2", - "T16TFM_20190412T162841_B07.jp2", - "T16TFM_20190412T162841_B08.jp2", - "T16TFM_20190412T162841_B09.jp2", - "T16TFM_20190412T162841_B10.jp2", - "T16TFM_20190412T162841_B11.jp2", - "T16TFM_20190412T162841_B12.jp2", - "T16TFM_20190412T162841_B8A.jp2", - "T16TFM_20190412T162841_TCI.jp2", + 'S2A_MSIL1C_20220412T162841_N0400_R083_T16TFM_20220412T202300.SAFE': { + 'GRANULE': { + 'L1C_T16TFM_A035544_20220412T163959': { + 'IMG_DATA': [ + 
'T16TFM_20220412T162841_B01.jp2', + 'T16TFM_20220412T162841_B02.jp2', + 'T16TFM_20220412T162841_B03.jp2', + 'T16TFM_20220412T162841_B04.jp2', + 'T16TFM_20220412T162841_B05.jp2', + 'T16TFM_20220412T162841_B06.jp2', + 'T16TFM_20220412T162841_B07.jp2', + 'T16TFM_20220412T162841_B08.jp2', + 'T16TFM_20220412T162841_B09.jp2', + 'T16TFM_20220412T162841_B10.jp2', + 'T16TFM_20220412T162841_B11.jp2', + 'T16TFM_20220412T162841_B12.jp2', + 'T16TFM_20220412T162841_B8A.jp2', + 'T16TFM_20220412T162841_TCI.jp2', + 'T16TFM_20190412T162841_B01.jp2', + 'T16TFM_20190412T162841_B02.jp2', + 'T16TFM_20190412T162841_B03.jp2', + 'T16TFM_20190412T162841_B04.jp2', + 'T16TFM_20190412T162841_B05.jp2', + 'T16TFM_20190412T162841_B06.jp2', + 'T16TFM_20190412T162841_B07.jp2', + 'T16TFM_20190412T162841_B08.jp2', + 'T16TFM_20190412T162841_B09.jp2', + 'T16TFM_20190412T162841_B10.jp2', + 'T16TFM_20190412T162841_B11.jp2', + 'T16TFM_20190412T162841_B12.jp2', + 'T16TFM_20190412T162841_B8A.jp2', + 'T16TFM_20190412T162841_TCI.jp2', ] } } }, # Copernicus Open Access Hub - "S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE": { - "GRANULE": { - "L2A_T26EMU_A035569_20220414T110747": { - "IMG_DATA": { - "R10m": [ - "T26EMU_20220414T110751_AOT_10m.jp2", - "T26EMU_20220414T110751_B02_10m.jp2", - "T26EMU_20220414T110751_B03_10m.jp2", - "T26EMU_20220414T110751_B04_10m.jp2", - "T26EMU_20220414T110751_B08_10m.jp2", - "T26EMU_20220414T110751_TCI_10m.jp2", - "T26EMU_20220414T110751_WVP_10m.jp2", - "T26EMU_20190414T110751_AOT_10m.jp2", - "T26EMU_20190414T110751_B02_10m.jp2", - "T26EMU_20190414T110751_B03_10m.jp2", - "T26EMU_20190414T110751_B04_10m.jp2", - "T26EMU_20190414T110751_B08_10m.jp2", - "T26EMU_20190414T110751_TCI_10m.jp2", - "T26EMU_20190414T110751_WVP_10m.jp2", + 'S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE': { + 'GRANULE': { + 'L2A_T26EMU_A035569_20220414T110747': { + 'IMG_DATA': { + 'R10m': [ + 'T26EMU_20220414T110751_AOT_10m.jp2', + 'T26EMU_20220414T110751_B02_10m.jp2', + 'T26EMU_20220414T110751_B03_10m.jp2', + 'T26EMU_20220414T110751_B04_10m.jp2', + 'T26EMU_20220414T110751_B08_10m.jp2', + 'T26EMU_20220414T110751_TCI_10m.jp2', + 'T26EMU_20220414T110751_WVP_10m.jp2', + 'T26EMU_20190414T110751_AOT_10m.jp2', + 'T26EMU_20190414T110751_B02_10m.jp2', + 'T26EMU_20190414T110751_B03_10m.jp2', + 'T26EMU_20190414T110751_B04_10m.jp2', + 'T26EMU_20190414T110751_B08_10m.jp2', + 'T26EMU_20190414T110751_TCI_10m.jp2', + 'T26EMU_20190414T110751_WVP_10m.jp2', ], - "R20m": [ - "T26EMU_20220414T110751_AOT_20m.jp2", - "T26EMU_20220414T110751_B01_20m.jp2", - "T26EMU_20220414T110751_B02_20m.jp2", - "T26EMU_20220414T110751_B03_20m.jp2", - "T26EMU_20220414T110751_B04_20m.jp2", - "T26EMU_20220414T110751_B05_20m.jp2", - "T26EMU_20220414T110751_B06_20m.jp2", - "T26EMU_20220414T110751_B07_20m.jp2", - "T26EMU_20220414T110751_B11_20m.jp2", - "T26EMU_20220414T110751_B12_20m.jp2", - "T26EMU_20220414T110751_B8A_20m.jp2", - "T26EMU_20220414T110751_SCL_20m.jp2", - "T26EMU_20220414T110751_TCI_20m.jp2", - "T26EMU_20220414T110751_WVP_20m.jp2", - "T26EMU_20190414T110751_AOT_20m.jp2", - "T26EMU_20190414T110751_B01_20m.jp2", - "T26EMU_20190414T110751_B02_20m.jp2", - "T26EMU_20190414T110751_B03_20m.jp2", - "T26EMU_20190414T110751_B04_20m.jp2", - "T26EMU_20190414T110751_B05_20m.jp2", - "T26EMU_20190414T110751_B06_20m.jp2", - "T26EMU_20190414T110751_B07_20m.jp2", - "T26EMU_20190414T110751_B11_20m.jp2", - "T26EMU_20190414T110751_B12_20m.jp2", - "T26EMU_20190414T110751_B8A_20m.jp2", - "T26EMU_20190414T110751_SCL_20m.jp2", - 
"T26EMU_20190414T110751_TCI_20m.jp2", - "T26EMU_20190414T110751_WVP_20m.jp2", + 'R20m': [ + 'T26EMU_20220414T110751_AOT_20m.jp2', + 'T26EMU_20220414T110751_B01_20m.jp2', + 'T26EMU_20220414T110751_B02_20m.jp2', + 'T26EMU_20220414T110751_B03_20m.jp2', + 'T26EMU_20220414T110751_B04_20m.jp2', + 'T26EMU_20220414T110751_B05_20m.jp2', + 'T26EMU_20220414T110751_B06_20m.jp2', + 'T26EMU_20220414T110751_B07_20m.jp2', + 'T26EMU_20220414T110751_B11_20m.jp2', + 'T26EMU_20220414T110751_B12_20m.jp2', + 'T26EMU_20220414T110751_B8A_20m.jp2', + 'T26EMU_20220414T110751_SCL_20m.jp2', + 'T26EMU_20220414T110751_TCI_20m.jp2', + 'T26EMU_20220414T110751_WVP_20m.jp2', + 'T26EMU_20190414T110751_AOT_20m.jp2', + 'T26EMU_20190414T110751_B01_20m.jp2', + 'T26EMU_20190414T110751_B02_20m.jp2', + 'T26EMU_20190414T110751_B03_20m.jp2', + 'T26EMU_20190414T110751_B04_20m.jp2', + 'T26EMU_20190414T110751_B05_20m.jp2', + 'T26EMU_20190414T110751_B06_20m.jp2', + 'T26EMU_20190414T110751_B07_20m.jp2', + 'T26EMU_20190414T110751_B11_20m.jp2', + 'T26EMU_20190414T110751_B12_20m.jp2', + 'T26EMU_20190414T110751_B8A_20m.jp2', + 'T26EMU_20190414T110751_SCL_20m.jp2', + 'T26EMU_20190414T110751_TCI_20m.jp2', + 'T26EMU_20190414T110751_WVP_20m.jp2', ], - "R60m": [ - "T26EMU_20220414T110751_AOT_60m.jp2", - "T26EMU_20220414T110751_B01_60m.jp2", - "T26EMU_20220414T110751_B02_60m.jp2", - "T26EMU_20220414T110751_B03_60m.jp2", - "T26EMU_20220414T110751_B04_60m.jp2", - "T26EMU_20220414T110751_B05_60m.jp2", - "T26EMU_20220414T110751_B06_60m.jp2", - "T26EMU_20220414T110751_B07_60m.jp2", - "T26EMU_20220414T110751_B09_60m.jp2", - "T26EMU_20220414T110751_B11_60m.jp2", - "T26EMU_20220414T110751_B12_60m.jp2", - "T26EMU_20220414T110751_B8A_60m.jp2", - "T26EMU_20220414T110751_SCL_60m.jp2", - "T26EMU_20220414T110751_TCI_60m.jp2", - "T26EMU_20220414T110751_WVP_60m.jp2", - "T26EMU_20190414T110751_AOT_60m.jp2", - "T26EMU_20190414T110751_B01_60m.jp2", - "T26EMU_20190414T110751_B02_60m.jp2", - "T26EMU_20190414T110751_B03_60m.jp2", - "T26EMU_20190414T110751_B04_60m.jp2", - "T26EMU_20190414T110751_B05_60m.jp2", - "T26EMU_20190414T110751_B06_60m.jp2", - "T26EMU_20190414T110751_B07_60m.jp2", - "T26EMU_20190414T110751_B09_60m.jp2", - "T26EMU_20190414T110751_B11_60m.jp2", - "T26EMU_20190414T110751_B12_60m.jp2", - "T26EMU_20190414T110751_B8A_60m.jp2", - "T26EMU_20190414T110751_SCL_60m.jp2", - "T26EMU_20190414T110751_TCI_60m.jp2", - "T26EMU_20190414T110751_WVP_60m.jp2", + 'R60m': [ + 'T26EMU_20220414T110751_AOT_60m.jp2', + 'T26EMU_20220414T110751_B01_60m.jp2', + 'T26EMU_20220414T110751_B02_60m.jp2', + 'T26EMU_20220414T110751_B03_60m.jp2', + 'T26EMU_20220414T110751_B04_60m.jp2', + 'T26EMU_20220414T110751_B05_60m.jp2', + 'T26EMU_20220414T110751_B06_60m.jp2', + 'T26EMU_20220414T110751_B07_60m.jp2', + 'T26EMU_20220414T110751_B09_60m.jp2', + 'T26EMU_20220414T110751_B11_60m.jp2', + 'T26EMU_20220414T110751_B12_60m.jp2', + 'T26EMU_20220414T110751_B8A_60m.jp2', + 'T26EMU_20220414T110751_SCL_60m.jp2', + 'T26EMU_20220414T110751_TCI_60m.jp2', + 'T26EMU_20220414T110751_WVP_60m.jp2', + 'T26EMU_20190414T110751_AOT_60m.jp2', + 'T26EMU_20190414T110751_B01_60m.jp2', + 'T26EMU_20190414T110751_B02_60m.jp2', + 'T26EMU_20190414T110751_B03_60m.jp2', + 'T26EMU_20190414T110751_B04_60m.jp2', + 'T26EMU_20190414T110751_B05_60m.jp2', + 'T26EMU_20190414T110751_B06_60m.jp2', + 'T26EMU_20190414T110751_B07_60m.jp2', + 'T26EMU_20190414T110751_B09_60m.jp2', + 'T26EMU_20190414T110751_B11_60m.jp2', + 'T26EMU_20190414T110751_B12_60m.jp2', + 'T26EMU_20190414T110751_B8A_60m.jp2', + 
'T26EMU_20190414T110751_SCL_60m.jp2', + 'T26EMU_20190414T110751_TCI_60m.jp2', + 'T26EMU_20190414T110751_WVP_60m.jp2', ], } } @@ -147,27 +147,27 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: res = 10 root, _ = os.path.splitext(path) - if root.endswith("m"): + if root.endswith('m'): res = int(root[-3:-1]) profile = {} - profile["driver"] = "JP2OpenJPEG" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = CRS.from_epsg(32616) - profile["transform"] = Affine(res, 0.0, 399960.0, 0.0, -res, 4500000.0) - profile["height"] = round(SIZE * 10 / res) - profile["width"] = round(SIZE * 10 / res) - - if "float" in profile["dtype"]: - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + profile['driver'] = 'JP2OpenJPEG' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = CRS.from_epsg(32616) + profile['transform'] = Affine(res, 0.0, 399960.0, 0.0, -res, 4500000.0) + profile['height'] = round(SIZE * 10 / res) + profile['width'] = round(SIZE * 10 / res) + + if 'float' in profile['dtype']: + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) else: Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: - for i in range(1, profile["count"] + 1): + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): src.write(Z, i) @@ -182,8 +182,8 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: # Base case for value in hierarchy: path = os.path.join(directory, value) - create_file(path, dtype="uint16", num_channels=1) + create_file(path, dtype='uint16', num_channels=1) -if __name__ == "__main__": - create_directory(".", filenames) +if __name__ == '__main__': + create_directory('.', filenames) diff --git a/tests/data/skippd/data.py b/tests/data/skippd/data.py index e717c2ce025..97a2670bd0c 100755 --- a/tests/data/skippd/data.py +++ b/tests/data/skippd/data.py @@ -20,18 +20,18 @@ np.random.seed(0) -tasks = ["nowcast", "forecast"] -data_file = "2017_2019_images_pv_processed_{}.hdf5" -splits = ["trainval", "test"] +tasks = ['nowcast', 'forecast'] +data_file = '2017_2019_images_pv_processed_{}.hdf5' +splits = ['trainval', 'test'] # Create dataset file data = { - "nowcast": np.random.randint( + 'nowcast': np.random.randint( RGB_MAX, size=(NUM_SAMPLES, SIZE, SIZE, NUM_CHANNELS), dtype=np.int16 ), - "forecast": np.random.randint( + 'forecast': np.random.randint( RGB_MAX, size=(NUM_SAMPLES, TIME_STEPS, SIZE, SIZE, NUM_CHANNELS), dtype=np.int16, @@ -40,38 +40,38 @@ labels = { - "nowcast": np.random.random(size=(NUM_SAMPLES)), - "forecast": np.random.random(size=(NUM_SAMPLES, TIME_STEPS)), + 'nowcast': np.random.random(size=(NUM_SAMPLES)), + 'forecast': np.random.random(size=(NUM_SAMPLES, TIME_STEPS)), } -if __name__ == "__main__": +if __name__ == '__main__': for task in tasks: - with h5py.File(data_file.format(task), "w") as f: + with h5py.File(data_file.format(task), 'w') as f: for split in splits: grp = f.create_group(split) - grp.create_dataset("images_log", data=data[task]) - grp.create_dataset("pv_log", data=labels[task]) + grp.create_dataset('images_log', data=data[task]) + grp.create_dataset('pv_log', data=labels[task]) # create time stamps for split in splits: time_stamps = np.array( [datetime.now() - timedelta(days=i) for i in range(NUM_SAMPLES)] ) - np.save(f"times_{split}_{task}.npy", 
time_stamps) + np.save(f'times_{split}_{task}.npy', time_stamps) # Compress data with zipfile.ZipFile( - data_file.format(task).replace(".hdf5", ".zip"), "w" + data_file.format(task).replace('.hdf5', '.zip'), 'w' ) as zip: for file in [ data_file.format(task), - f"times_trainval_{task}.npy", - f"times_test_{task}.npy", + f'times_trainval_{task}.npy', + f'times_test_{task}.npy', ]: zip.write(file, arcname=file) # Compute checksums - with open(data_file.format(task).replace(".hdf5", ".zip"), "rb") as f: + with open(data_file.format(task).replace('.hdf5', '.zip'), 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{task}: {md5}") + print(f'{task}: {md5}') diff --git a/tests/data/so2sat/data.py b/tests/data/so2sat/data.py index ef34aa68ddf..f94be235cb1 100755 --- a/tests/data/so2sat/data.py +++ b/tests/data/so2sat/data.py @@ -16,8 +16,8 @@ np.random.seed(0) -for split in ["training", "validation", "testing"]: - filename = f"{split}.h5" +for split in ['training', 'validation', 'testing']: + filename = f'{split}.h5' # Remove old data if os.path.exists(filename): @@ -33,17 +33,17 @@ sen2 = np.random.randint(2, size=(NUM_SAMPLES, SIZE, SIZE, 10), dtype=np.uint8) # Create datasets - with h5py.File(filename, "w") as f: - f.create_dataset("label", data=label, compression="gzip", compression_opts=9) - f.create_dataset("sen1", data=sen1, compression="gzip", compression_opts=9) - f.create_dataset("sen2", data=sen2, compression="gzip", compression_opts=9) + with h5py.File(filename, 'w') as f: + f.create_dataset('label', data=label, compression='gzip', compression_opts=9) + f.create_dataset('sen1', data=sen1, compression='gzip', compression_opts=9) + f.create_dataset('sen2', data=sen2, compression='gzip', compression_opts=9) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(repr(split.replace("ing", "")) + ":", repr(md5) + ",") + print(repr(split.replace('ing', '')) + ':', repr(md5) + ',') -for version in ["random", "block", "culture_10"]: +for version in ['random', 'block', 'culture_10']: os.makedirs(version, exist_ok=True) - shutil.copyfile("training.h5", os.path.join(version, "training.h5")) - shutil.copyfile("testing.h5", os.path.join(version, "testing.h5")) + shutil.copyfile('training.h5', os.path.join(version, 'training.h5')) + shutil.copyfile('testing.h5', os.path.join(version, 'testing.h5')) diff --git a/tests/data/south_africa_crop_type/data.py b/tests/data/south_africa_crop_type/data.py index dcb7a1d4b6d..8b30fc60217 100644 --- a/tests/data/south_africa_crop_type/data.py +++ b/tests/data/south_africa_crop_type/data.py @@ -21,7 +21,7 @@ def generate_test_data() -> str: Returns: md5 hash of created archive """ - paths = "south_africa_crop_type" + paths = 'south_africa_crop_type' dtype = np.uint8 dtype_max = np.iinfo(dtype).max @@ -29,46 +29,46 @@ def generate_test_data() -> str: np.random.seed(0) - s1_bands = ("VH", "VV") + s1_bands = ('VH', 'VV') s2_bands = ( - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', ) profile = { - "dtype": dtype, - "width": SIZE, - "height": SIZE, - "count": 1, - "crs": CRS.from_epsg(32634), - "transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0), + 'dtype': dtype, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(32634), + 'transform': Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 
3079680.0), } - train_imagery_s1_dir = os.path.join(paths, "train", "imagery", "s1") - train_imagery_s2_dir = os.path.join(paths, "train", "imagery", "s2") - train_labels_dir = os.path.join(paths, "train", "labels") + train_imagery_s1_dir = os.path.join(paths, 'train', 'imagery', 's1') + train_imagery_s2_dir = os.path.join(paths, 'train', 'imagery', 's2') + train_labels_dir = os.path.join(paths, 'train', 'labels') os.makedirs(train_imagery_s1_dir, exist_ok=True) os.makedirs(train_imagery_s2_dir, exist_ok=True) os.makedirs(train_labels_dir, exist_ok=True) - train_field_ids = ["12"] + train_field_ids = ['12'] - s1_timestamps = ["2017_04_01", "2017_07_28"] - s2_timestamps = ["2017_05_04", "2017_07_22"] + s1_timestamps = ['2017_04_01', '2017_07_28'] + s2_timestamps = ['2017_05_04', '2017_07_22'] def write_raster(path: str, arr: np.array) -> None: - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(arr, 1) for field_id in train_field_ids: @@ -77,19 +77,19 @@ def write_raster(path: str, arr: np.array) -> None: os.makedirs(s1_dir, exist_ok=True) for band in s1_bands: train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype) - path = os.path.join(s1_dir, f"{field_id}_{date}_{band}_10m.tif") + path = os.path.join(s1_dir, f'{field_id}_{date}_{band}_10m.tif') write_raster(path, train_arr) for date in s2_timestamps: s2_dir = os.path.join(train_imagery_s2_dir, field_id, date) os.makedirs(s2_dir, exist_ok=True) for band in s2_bands: train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype) - path = os.path.join(s2_dir, f"{field_id}_{date}_{band}_10m.tif") + path = os.path.join(s2_dir, f'{field_id}_{date}_{band}_10m.tif') write_raster(path, train_arr) - label_path = os.path.join(train_labels_dir, f"{field_id}.tif") + label_path = os.path.join(train_labels_dir, f'{field_id}.tif') label_arr = np.random.randint(9, size=(SIZE, SIZE), dtype=dtype) write_raster(label_path, label_arr) -if __name__ == "__main__": +if __name__ == '__main__': generate_test_data() diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index 63e93ebdc1d..9ee5760f971 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -15,45 +15,45 @@ np.random.seed(0) -files = ["South_America_Soybean_2002.tif", "South_America_Soybean_2021.tif"] +files = ['South_America_Soybean_2002.tif', 'South_America_Soybean_2021.tif'] def create_file(path: str, dtype: str): """Create the testing file.""" profile = { - "driver": "GTiff", - "dtype": dtype, - "count": 1, - "crs": CRS.from_epsg(32616), - "transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0), - "height": SIZE, - "width": SIZE, - "compress": "lzw", - "predictor": 2, + 'driver': 'GTiff', + 'dtype': dtype, + 'count': 1, + 'crs': CRS.from_epsg(32616), + 'transform': Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0), + 'height': SIZE, + 'width': SIZE, + 'compress': 'lzw', + 'predictor': 2, } allowed_values = [0, 1] Z = np.random.choice(allowed_values, size=(SIZE, SIZE)) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z, 1) -if __name__ == "__main__": - dir = os.path.join(os.getcwd(), "SouthAmericaSoybean") +if __name__ == '__main__': + dir = os.path.join(os.getcwd(), 'SouthAmericaSoybean') if os.path.exists(dir) and os.path.isdir(dir): shutil.rmtree(dir) os.makedirs(dir, exist_ok=True) for file in files: - create_file(os.path.join(dir, file), dtype="int8") + 
create_file(os.path.join(dir, file), dtype='int8') # Compress data - shutil.make_archive("SouthAmericaSoybean", "zip", ".", dir) + shutil.make_archive('SouthAmericaSoybean', 'zip', '.', dir) # Compute checksums - with open("SouthAmericaSoybean.zip", "rb") as f: + with open('SouthAmericaSoybean.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"SouthAmericaSoybean.zip: {md5}") + print(f'SouthAmericaSoybean.zip: {md5}') diff --git a/tests/data/spacenet/data.py b/tests/data/spacenet/data.py index f51485ce8d5..3e0e6faf497 100755 --- a/tests/data/spacenet/data.py +++ b/tests/data/spacenet/data.py @@ -30,34 +30,34 @@ crs = CRS.from_epsg(4326) img_count = { - "MS.tif": 8, - "PAN.tif": 1, - "PS-MS.tif": 8, - "PS-RGB.tif": 3, - "PS-RGBNIR.tif": 4, - "RGB.tif": 3, - "RGBNIR.tif": 4, - "SAR-Intensity.tif": 1, - "mosaic.tif": 3, - "8Band.tif": 8, + 'MS.tif': 8, + 'PAN.tif': 1, + 'PS-MS.tif': 8, + 'PS-RGB.tif': 3, + 'PS-RGBNIR.tif': 4, + 'RGB.tif': 3, + 'RGBNIR.tif': 4, + 'SAR-Intensity.tif': 1, + 'mosaic.tif': 3, + '8Band.tif': 8, } sn4_catalog = [ - "10300100023BC100", - "10300100036D5200", - "1030010003BDDC00", - "1030010003CD4300", + '10300100023BC100', + '10300100036D5200', + '1030010003BDDC00', + '1030010003CD4300', ] sn4_angles = [8, 30, 52, 53] -sn4_imgdirname = "sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-nadir{}_catid_{}" -sn4_lbldirname = "sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-labels" +sn4_imgdirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-nadir{}_catid_{}' +sn4_lbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-labels' sn4_emptyimgdirname = ( - "sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-nadir53_" - + "catid_1030010003CD4300" + 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-nadir53_' + + 'catid_1030010003CD4300' ) -sn4_emptylbldirname = "sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-labels" +sn4_emptylbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-labels' datasets = [SpaceNet1, SpaceNet2, SpaceNet3, SpaceNet4, SpaceNet5, SpaceNet6, SpaceNet7] @@ -75,12 +75,12 @@ def create_test_image(img_dir: str, imgs: list[str]) -> list[list[float]]: """ for img in imgs: imgpath = os.path.join(img_dir, img) - Z = np.arange(4, dtype="uint16").reshape(2, 2) + Z = np.arange(4, dtype='uint16').reshape(2, 2) count = img_count[img] with rasterio.open( imgpath, - "w", - driver="GTiff", + 'w', + driver='GTiff', height=Z.shape[0], width=Z.shape[1], count=count, @@ -117,57 +117,57 @@ def create_test_label( """ if empty: # Creates a new file - with open(os.path.join(lbldir, lblname), "w"): + with open(os.path.join(lbldir, lblname), 'w'): pass return - if det_type == "buildings": + if det_type == 'buildings': meta_properties = OrderedDict() - geom = "Polygon" + geom = 'Polygon' rec = { - "type": "Feature", - "id": "0", - "properties": OrderedDict(), - "geometry": {"type": "Polygon", "coordinates": [coords]}, + 'type': 'Feature', + 'id': '0', + 'properties': OrderedDict(), + 'geometry': {'type': 'Polygon', 'coordinates': [coords]}, } else: meta_properties = OrderedDict( [ - ("heading", "str"), - ("lane_number", "str"), - ("one_way_ty", "str"), - ("paved", "str"), - ("road_id", "int"), - ("road_type", "str"), - ("origarea", "int"), - ("origlen", "float"), - ("partialDec", "int"), - ("truncated", "int"), - ("bridge_type", "str"), - ("inferred_speed_mph", "float"), - ("inferred_speed_mps", "float"), + ('heading', 'str'), + ('lane_number', 'str'), + ('one_way_ty', 'str'), + ('paved', 'str'), + ('road_id', 
'int'), + ('road_type', 'str'), + ('origarea', 'int'), + ('origlen', 'float'), + ('partialDec', 'int'), + ('truncated', 'int'), + ('bridge_type', 'str'), + ('inferred_speed_mph', 'float'), + ('inferred_speed_mps', 'float'), ] ) - geom = "LineString" + geom = 'LineString' - dummy_vals = {"str": "a", "float": 45.0, "int": 0} + dummy_vals = {'str': 'a', 'float': 45.0, 'int': 0} ROAD_DICT = [(k, dummy_vals[v]) for k, v in meta_properties.items()] rec = { - "type": "Feature", - "id": "0", - "properties": OrderedDict(ROAD_DICT), - "geometry": {"type": "LineString", "coordinates": [coords[0], coords[2]]}, + 'type': 'Feature', + 'id': '0', + 'properties': OrderedDict(ROAD_DICT), + 'geometry': {'type': 'LineString', 'coordinates': [coords[0], coords[2]]}, } meta = { - "driver": "GeoJSON", - "schema": {"properties": meta_properties, "geometry": geom}, - "crs": {"init": "epsg:4326"}, + 'driver': 'GeoJSON', + 'schema': {'properties': meta_properties, 'geometry': geom}, + 'crs': {'init': 'epsg:4326'}, } if diff_crs: - meta["crs"] = {"init": "epsg:3857"} + meta['crs'] = {'init': 'epsg:3857'} out_file = os.path.join(lbldir, lblname) - with fiona.open(out_file, "w", **meta) as dst: + with fiona.open(out_file, 'w', **meta) as dst: dst.write(rec) @@ -178,25 +178,25 @@ def main() -> None: collections = list(dataset.collection_md5_dict.keys()) for collection in collections: dataset = cast(SpaceNet, dataset) - if dataset.dataset_id == "spacenet4": + if dataset.dataset_id == 'spacenet4': num_samples = 4 - elif collection == "sn5_AOI_7_Moscow" or collection not in [ - "sn5_AOI_8_Mumbai", - "sn7_test_source", + elif collection == 'sn5_AOI_7_Moscow' or collection not in [ + 'sn5_AOI_8_Mumbai', + 'sn7_test_source', ]: num_samples = 3 - elif collection == "sn5_AOI_8_Mumbai": + elif collection == 'sn5_AOI_8_Mumbai': num_samples = 3 else: num_samples = 1 for sample in range(num_samples): out_dir = os.path.join(ROOT_DIR, collection) - if collection == "sn6_AOI_11_Rotterdam": - out_dir = os.path.join(ROOT_DIR, "spacenet6", collection) + if collection == 'sn6_AOI_11_Rotterdam': + out_dir = os.path.join(ROOT_DIR, 'spacenet6', collection) # Create img dir - if dataset.dataset_id == "spacenet4": + if dataset.dataset_id == 'spacenet4': assert num_samples == 4 if sample != 3: imgdirname = sn4_imgdirname.format( @@ -209,8 +209,8 @@ def main() -> None: ) lbldirname = sn4_emptylbldirname else: - imgdirname = f"{collection}_img{sample + 1}" - lbldirname = f"{collection}_img{sample + 1}-labels" + imgdirname = f'{collection}_img{sample + 1}' + lbldirname = f'{collection}_img{sample + 1}-labels' imgdir = os.path.join(out_dir, imgdirname) os.makedirs(imgdir, exist_ok=True) @@ -219,8 +219,8 @@ def main() -> None: # Create lbl dir lbldir = os.path.join(out_dir, lbldirname) os.makedirs(lbldir, exist_ok=True) - det_type = "roads" if dataset in [SpaceNet3, SpaceNet5] else "buildings" - if dataset.dataset_id == "spacenet4" and sample == 3: + det_type = 'roads' if dataset in [SpaceNet3, SpaceNet5] else 'buildings' + if dataset.dataset_id == 'spacenet4' and sample == 3: # Creates an empty file create_test_label( lbldir, dataset.label_glob, bounds, det_type, empty=True @@ -228,7 +228,7 @@ def main() -> None: else: create_test_label(lbldir, dataset.label_glob, bounds, det_type) - if collection == "sn5_AOI_8_Mumbai": + if collection == 'sn5_AOI_8_Mumbai': if sample == 1: create_test_label( lbldir, dataset.label_glob, bounds, det_type, empty=True @@ -238,44 +238,44 @@ def main() -> None: lbldir, dataset.label_glob, bounds, det_type, 
diff_crs=True ) - if collection == "sn1_AOI_1_RIO" and sample == 1: + if collection == 'sn1_AOI_1_RIO' and sample == 1: create_test_label( lbldir, dataset.label_glob, bounds, det_type, diff_crs=True ) if collection not in [ - "sn2_AOI_2_Vegas", - "sn3_AOI_5_Khartoum", - "sn4_AOI_6_Atlanta", - "sn5_AOI_8_Mumbai", - "sn6_AOI_11_Rotterdam", - "sn7_train_source", + 'sn2_AOI_2_Vegas', + 'sn3_AOI_5_Khartoum', + 'sn4_AOI_6_Atlanta', + 'sn5_AOI_8_Mumbai', + 'sn6_AOI_11_Rotterdam', + 'sn7_train_source', ]: # Create collection.json with open( - os.path.join(ROOT_DIR, collection, "collection.json"), "w" + os.path.join(ROOT_DIR, collection, 'collection.json'), 'w' ): pass - if collection == "sn6_AOI_11_Rotterdam": + if collection == 'sn6_AOI_11_Rotterdam': # Create collection.json with open( os.path.join( - ROOT_DIR, "spacenet6", collection, "collection.json" + ROOT_DIR, 'spacenet6', collection, 'collection.json' ), - "w", + 'w', ): pass # Create archive - if collection == "sn6_AOI_11_Rotterdam": + if collection == 'sn6_AOI_11_Rotterdam': break archive_path = os.path.join(ROOT_DIR, collection) shutil.make_archive( - archive_path, "gztar", root_dir=ROOT_DIR, base_dir=collection + archive_path, 'gztar', root_dir=ROOT_DIR, base_dir=collection ) shutil.rmtree(out_dir) print(f'{collection}: {calculate_md5(f"{archive_path}.tar.gz")}') -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/tests/data/ssl4eo/l/data.py b/tests/data/ssl4eo/l/data.py index 1c9ada4411a..000087c10d5 100755 --- a/tests/data/ssl4eo/l/data.py +++ b/tests/data/ssl4eo/l/data.py @@ -17,99 +17,99 @@ np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] filenames: FILENAME_HIERARCHY = { - "ssl4eo_l_tm_toa": { - "0000002": { - "LT05_172034_20010526": ["all_bands.tif"], - "LT05_172034_20020310": ["all_bands.tif"], - "LT05_172034_20020902": ["all_bands.tif"], - "LT05_172034_20021121": ["all_bands.tif"], + 'ssl4eo_l_tm_toa': { + '0000002': { + 'LT05_172034_20010526': ['all_bands.tif'], + 'LT05_172034_20020310': ['all_bands.tif'], + 'LT05_172034_20020902': ['all_bands.tif'], + 'LT05_172034_20021121': ['all_bands.tif'], }, - "0000005": { - "LT05_223084_20010413": ["all_bands.tif"], - "LT05_223084_20011225": ["all_bands.tif"], - "LT05_223084_20020619": ["all_bands.tif"], - "LT5_223084_20020923": ["all_bands.tif"], + '0000005': { + 'LT05_223084_20010413': ['all_bands.tif'], + 'LT05_223084_20011225': ['all_bands.tif'], + 'LT05_223084_20020619': ['all_bands.tif'], + 'LT5_223084_20020923': ['all_bands.tif'], }, }, - "ssl4eo_l_etm_toa": { - "0000002": { - "LE07_172034_20010526": ["all_bands.tif"], - "LE07_172034_20020310": ["all_bands.tif"], - "LE07_172034_20020902": ["all_bands.tif"], - "LE07_172034_20021121": ["all_bands.tif"], + 'ssl4eo_l_etm_toa': { + '0000002': { + 'LE07_172034_20010526': ['all_bands.tif'], + 'LE07_172034_20020310': ['all_bands.tif'], + 'LE07_172034_20020902': ['all_bands.tif'], + 'LE07_172034_20021121': ['all_bands.tif'], }, - "0000005": { - "LE07_223084_20010413": ["all_bands.tif"], - "LE07_223084_20011225": ["all_bands.tif"], - "LE07_223084_20020619": ["all_bands.tif"], - "LE07_223084_20020923": ["all_bands.tif"], + '0000005': { + 'LE07_223084_20010413': ['all_bands.tif'], + 'LE07_223084_20011225': ['all_bands.tif'], + 'LE07_223084_20020619': ['all_bands.tif'], + 'LE07_223084_20020923': ['all_bands.tif'], }, }, - "ssl4eo_l_etm_sr": { - "0000002": { - "LE07_172034_20010526": ["all_bands.tif"], - 
"LE07_172034_20020310": ["all_bands.tif"], - "LE07_172034_20020902": ["all_bands.tif"], - "LE07_172034_20021121": ["all_bands.tif"], + 'ssl4eo_l_etm_sr': { + '0000002': { + 'LE07_172034_20010526': ['all_bands.tif'], + 'LE07_172034_20020310': ['all_bands.tif'], + 'LE07_172034_20020902': ['all_bands.tif'], + 'LE07_172034_20021121': ['all_bands.tif'], }, - "0000005": { - "LE07_223084_20010413": ["all_bands.tif"], - "LE07_223084_20011225": ["all_bands.tif"], - "LE07_223084_20020619": ["all_bands.tif"], - "LE07_223084_20020923": ["all_bands.tif"], + '0000005': { + 'LE07_223084_20010413': ['all_bands.tif'], + 'LE07_223084_20011225': ['all_bands.tif'], + 'LE07_223084_20020619': ['all_bands.tif'], + 'LE07_223084_20020923': ['all_bands.tif'], }, }, - "ssl4eo_l_oli_tirs_toa": { - "0000002": { - "LC08_172034_20210306": ["all_bands.tif"], - "LC08_172034_20210829": ["all_bands.tif"], - "LC08_172034_20211203": ["all_bands.tif"], - "LC08_172034_20220715": ["all_bands.tif"], + 'ssl4eo_l_oli_tirs_toa': { + '0000002': { + 'LC08_172034_20210306': ['all_bands.tif'], + 'LC08_172034_20210829': ['all_bands.tif'], + 'LC08_172034_20211203': ['all_bands.tif'], + 'LC08_172034_20220715': ['all_bands.tif'], }, - "0000005": { - "LC08_223084_20210412": ["all_bands.tif"], - "LC08_223084_20211005": ["all_bands.tif"], - "LC08_223084_20220618": ["all_bands.tif"], - "LC08_223084_20221211": ["all_bands.tif"], + '0000005': { + 'LC08_223084_20210412': ['all_bands.tif'], + 'LC08_223084_20211005': ['all_bands.tif'], + 'LC08_223084_20220618': ['all_bands.tif'], + 'LC08_223084_20221211': ['all_bands.tif'], }, }, - "ssl4eo_l_oli_sr": { - "0000002": { - "LC08_172034_20210306": ["all_bands.tif"], - "LC08_172034_20210829": ["all_bands.tif"], - "LC08_172034_20211203": ["all_bands.tif"], - "LC08_172034_20220715": ["all_bands.tif"], + 'ssl4eo_l_oli_sr': { + '0000002': { + 'LC08_172034_20210306': ['all_bands.tif'], + 'LC08_172034_20210829': ['all_bands.tif'], + 'LC08_172034_20211203': ['all_bands.tif'], + 'LC08_172034_20220715': ['all_bands.tif'], }, - "0000005": { - "LC08_223084_20210412": ["all_bands.tif"], - "LC08_223084_20211005": ["all_bands.tif"], - "LC08_223084_20220618": ["all_bands.tif"], - "LC08_223084_20221211": ["all_bands.tif"], + '0000005': { + 'LC08_223084_20210412': ['all_bands.tif'], + 'LC08_223084_20211005': ['all_bands.tif'], + 'LC08_223084_20220618': ['all_bands.tif'], + 'LC08_223084_20221211': ['all_bands.tif'], }, }, } num_bands = { - "ssl4eo_l_tm_toa": 7, - "ssl4eo_l_etm_toa": 9, - "ssl4eo_l_etm_sr": 6, - "ssl4eo_l_oli_tirs_toa": 11, - "ssl4eo_l_oli_sr": 7, + 'ssl4eo_l_tm_toa': 7, + 'ssl4eo_l_etm_toa': 9, + 'ssl4eo_l_etm_sr': 6, + 'ssl4eo_l_oli_tirs_toa': 11, + 'ssl4eo_l_oli_sr': 7, } def create_file(path: str) -> None: profile = { - "driver": "GTiff", - "dtype": "uint8", - "width": SIZE, - "height": SIZE, - "count": num_bands[path.split(os.sep)[1]], - "crs": CRS.from_epsg(4326), - "transform": Affine( + 'driver': 'GTiff', + 'dtype': 'uint8', + 'width': SIZE, + 'height': SIZE, + 'count': num_bands[path.split(os.sep)[1]], + 'crs': CRS.from_epsg(4326), + 'transform': Affine( 0.00033331040066238285, 0.0, 40.31409193350423, @@ -117,13 +117,13 @@ def create_file(path: str) -> None: -0.0002658855613264443, 37.60408425220701, ), - "compress": "lzw", - "predictor": 2, + 'compress': 'lzw', + 'predictor': 2, } - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: 
for i in src.indexes: src.write(Z, i) @@ -142,22 +142,22 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: create_file(path) -if __name__ == "__main__": - create_directory(".", filenames) +if __name__ == '__main__': + create_directory('.', filenames) directories = filenames.keys() for directory in directories: # Create tarball - shutil.make_archive(directory, "gztar", ".", directory) + shutil.make_archive(directory, 'gztar', '.', directory) # Split tarball - path = f"{directory}.tar.gz" + path = f'{directory}.tar.gz' paths = [] - with open(path, "rb") as f: - suffix = "a" + with open(path, 'rb') as f: + suffix = 'a' while chunk := f.read(CHUNK_SIZE): - split = f"{path}a{suffix}" - with open(split, "wb") as g: + split = f'{path}a{suffix}' + with open(split, 'wb') as g: g.write(chunk) suffix = chr(ord(suffix) + 1) paths.append(split) @@ -166,6 +166,6 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: # Compute checksums for path in paths: - with open(path, "rb") as f: + with open(path, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() print(path, md5) diff --git a/tests/data/ssl4eo/s12/data.py b/tests/data/ssl4eo/s12/data.py index 01f47f976aa..3ba0f09aefb 100755 --- a/tests/data/ssl4eo/s12/data.py +++ b/tests/data/ssl4eo/s12/data.py @@ -16,67 +16,67 @@ np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] -s1 = ["VH.tif", "VV.tif"] +s1 = ['VH.tif', 'VV.tif'] s2c = [ - "B1.tif", - "B2.tif", - "B3.tif", - "B4.tif", - "B5.tif", - "B6.tif", - "B7.tif", - "B8.tif", - "B8A.tif", - "B9.tif", - "B10.tif", - "B11.tif", - "B12.tif", + 'B1.tif', + 'B2.tif', + 'B3.tif', + 'B4.tif', + 'B5.tif', + 'B6.tif', + 'B7.tif', + 'B8.tif', + 'B8A.tif', + 'B9.tif', + 'B10.tif', + 'B11.tif', + 'B12.tif', ] s2a = s2c.copy() -s2a.remove("B10.tif") +s2a.remove('B10.tif') filenames: FILENAME_HIERARCHY = { - "s1": { - "0000000": { - "S1A_IW_GRDH_1SDV_20200329T001515_20200329T001540_031883_03AE27_9BAF": s1, - "S1A_IW_GRDH_1SDV_20201230T001523_20201230T001548_035908_04349D_C91E": s1, - "S1B_IW_GRDH_1SDV_20200627T001449_20200627T001514_022212_02A27E_2A09": s1, - "S1B_IW_GRDH_1SDV_20200928T120105_20200928T120130_023575_02CCB0_F035": s1, + 's1': { + '0000000': { + 'S1A_IW_GRDH_1SDV_20200329T001515_20200329T001540_031883_03AE27_9BAF': s1, + 'S1A_IW_GRDH_1SDV_20201230T001523_20201230T001548_035908_04349D_C91E': s1, + 'S1B_IW_GRDH_1SDV_20200627T001449_20200627T001514_022212_02A27E_2A09': s1, + 'S1B_IW_GRDH_1SDV_20200928T120105_20200928T120130_023575_02CCB0_F035': s1, }, - "0000001": { - "S1B_IW_GRDH_1SDV_20201101T091054_20201101T091119_024069_02DC0F_F189": s1, - "S1B_IW_GRDH_1SDV_20210205T091050_20210205T091115_025469_0308CB_AA25": s1, - "S1B_IW_GRDH_1SDV_20210430T091051_20210430T091116_026694_03303D_69B6": s1, - "S1B_IW_GRDH_1SDV_20210804T091057_20210804T091122_028094_0359FE_6D9D": s1, + '0000001': { + 'S1B_IW_GRDH_1SDV_20201101T091054_20201101T091119_024069_02DC0F_F189': s1, + 'S1B_IW_GRDH_1SDV_20210205T091050_20210205T091115_025469_0308CB_AA25': s1, + 'S1B_IW_GRDH_1SDV_20210430T091051_20210430T091116_026694_03303D_69B6': s1, + 'S1B_IW_GRDH_1SDV_20210804T091057_20210804T091122_028094_0359FE_6D9D': s1, }, }, - "s2c": { - "0000000": { - "20200323T162931_20200323T163750_T15QXA": s2c, - "20200621T162901_20200621T164746_T15QXA": s2c, - "20200924T162929_20200924T164434_T15QXA": s2c, - "20201228T163711_20201228T164519_T15QXA": s2c, + 's2c': { + '0000000': { + 
'20200323T162931_20200323T163750_T15QXA': s2c, + '20200621T162901_20200621T164746_T15QXA': s2c, + '20200924T162929_20200924T164434_T15QXA': s2c, + '20201228T163711_20201228T164519_T15QXA': s2c, }, - "0000001": { - "20201104T135121_20201104T135117_T21KXT": s2c, - "20210123T135111_20210123T135113_T21KXT": s2c, - "20210508T135109_20210508T135519_T21KXT": s2c, - "20210811T135121_20210811T135115_T21KXT": s2c, + '0000001': { + '20201104T135121_20201104T135117_T21KXT': s2c, + '20210123T135111_20210123T135113_T21KXT': s2c, + '20210508T135109_20210508T135519_T21KXT': s2c, + '20210811T135121_20210811T135115_T21KXT': s2c, }, }, - "s2a": { - "0000000": { - "20200323T162931_20200323T163750_T15QXA": s2a, - "20200621T162901_20200621T164746_T15QXA": s2a, - "20200924T162929_20200924T164434_T15QXA": s2a, - "20201228T163711_20201228T164519_T15QXA": s2a, + 's2a': { + '0000000': { + '20200323T162931_20200323T163750_T15QXA': s2a, + '20200621T162901_20200621T164746_T15QXA': s2a, + '20200924T162929_20200924T164434_T15QXA': s2a, + '20201228T163711_20201228T164519_T15QXA': s2a, }, - "0000001": { - "20201104T135121_20201104T135117_T21KXT": s2a, - "20210123T135111_20210123T135113_T21KXT": s2a, - "20210508T135109_20210508T135519_T21KXT": s2a, - "20210811T135121_20210811T135115_T21KXT": s2a, + '0000001': { + '20201104T135121_20201104T135117_T21KXT': s2a, + '20210123T135111_20210123T135113_T21KXT': s2a, + '20210508T135109_20210508T135519_T21KXT': s2a, + '20210811T135121_20210811T135115_T21KXT': s2a, }, }, } @@ -84,13 +84,13 @@ def create_file(path: str) -> None: profile = { - "driver": "GTiff", - "dtype": "uint16", - "width": SIZE, - "height": SIZE, - "count": 1, - "crs": CRS.from_epsg(4326), - "transform": Affine( + 'driver': 'GTiff', + 'dtype': 'uint16', + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(4326), + 'transform': Affine( 9.221577104649252e-05, 0.0, -91.84569595740037, @@ -100,12 +100,12 @@ def create_file(path: str) -> None: ), } - if path.endswith("VH.tif") or path.endswith("VV.tif"): - profile["dtype"] = "float32" + if path.endswith('VH.tif') or path.endswith('VV.tif'): + profile['dtype'] = 'float32' - if path.endswith("B1.tif") or path.endswith("B9.tif") or path.endswith("B10.tif"): - profile["width"] = profile["height"] = SIZE // 6 - profile["transform"] = Affine( + if path.endswith('B1.tif') or path.endswith('B9.tif') or path.endswith('B10.tif'): + profile['width'] = profile['height'] = SIZE // 6 + profile['transform'] = Affine( 0.0005532946262789551, 0.0, -91.84592649682799, @@ -114,15 +114,15 @@ def create_file(path: str) -> None: 18.588322889892943, ) elif ( - path.endswith("B5.tif") - or path.endswith("B6.tif") - or path.endswith("B7.tif") - or path.endswith("B8A.tif") - or path.endswith("B11.tif") - or path.endswith("B12.tif") + path.endswith('B5.tif') + or path.endswith('B6.tif') + or path.endswith('B7.tif') + or path.endswith('B8A.tif') + or path.endswith('B11.tif') + or path.endswith('B12.tif') ): - profile["width"] = profile["height"] = SIZE // 2 - profile["transform"] = Affine( + profile['width'] = profile['height'] = SIZE // 2 + profile['transform'] = Affine( 0.00018443154209298504, 0.0, -91.84574206528589, @@ -131,9 +131,9 @@ def create_file(path: str) -> None: 18.588146945880982, ) - Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"]) + Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z, 1) @@ -151,16 +151,16 @@ def create_directory(directory: 
str, hierarchy: FILENAME_HIERARCHY) -> None: create_file(path) -if __name__ == "__main__": - create_directory(".", filenames) +if __name__ == '__main__': + create_directory('.', filenames) - files = ["s1", "s2_l1c", "s2_l2a"] - directories = ["s1", "s2c", "s2a"] + files = ['s1', 's2_l1c', 's2_l2a'] + directories = ['s1', 's2c', 's2a'] for file, directory in zip(files, directories): # Create tarballs - shutil.make_archive(file, "gztar", ".", directory) + shutil.make_archive(file, 'gztar', '.', directory) # Compute checksums - with open(f"{file}.tar.gz", "rb") as f: + with open(f'{file}.tar.gz', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() print(file, md5) diff --git a/tests/data/ssl4eo_benchmark_landsat/data.py b/tests/data/ssl4eo_benchmark_landsat/data.py index 5afe54eb574..f800c70bec8 100755 --- a/tests/data/ssl4eo_benchmark_landsat/data.py +++ b/tests/data/ssl4eo_benchmark_landsat/data.py @@ -16,85 +16,85 @@ np.random.seed(0) -FILENAME_HIERARCHY = dict[str, "FILENAME_HIERARCHY"] | list[str] +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] filenames: FILENAME_HIERARCHY = { - "tm_toa": { - "0000001": {"LT05_172030_20010526": ["all_bands.tif"]}, - "0000002": {"LT05_223084_20010413": ["all_bands.tif"]}, - "0000003": {"LT05_172034_20020902": ["all_bands.tif"]}, - "0000004": {"LT05_172034_20020903": ["all_bands.tif"]}, - "0000005": {"LT05_172034_20020904": ["all_bands.tif"]}, - "0000006": {"LT05_172034_20020905": ["all_bands.tif"]}, - "0000007": {"LT05_172034_20020906": ["all_bands.tif"]}, - "0000008": {"LT05_172034_20020907": ["all_bands.tif"]}, - "0000009": {"LT05_172034_20020908": ["all_bands.tif"]}, - "0000010": {"LT05_172034_20020909": ["all_bands.tif"]}, + 'tm_toa': { + '0000001': {'LT05_172030_20010526': ['all_bands.tif']}, + '0000002': {'LT05_223084_20010413': ['all_bands.tif']}, + '0000003': {'LT05_172034_20020902': ['all_bands.tif']}, + '0000004': {'LT05_172034_20020903': ['all_bands.tif']}, + '0000005': {'LT05_172034_20020904': ['all_bands.tif']}, + '0000006': {'LT05_172034_20020905': ['all_bands.tif']}, + '0000007': {'LT05_172034_20020906': ['all_bands.tif']}, + '0000008': {'LT05_172034_20020907': ['all_bands.tif']}, + '0000009': {'LT05_172034_20020908': ['all_bands.tif']}, + '0000010': {'LT05_172034_20020909': ['all_bands.tif']}, }, - "etm_sr": { - "0000001": {"LE07_172030_20010526": ["all_bands.tif"]}, - "0000002": {"LE07_223084_20010413": ["all_bands.tif"]}, - "0000003": {"LE07_172034_20020902": ["all_bands.tif"]}, - "0000004": {"LE07_172034_20020903": ["all_bands.tif"]}, - "0000005": {"LE07_172034_20020904": ["all_bands.tif"]}, - "0000006": {"LE07_172034_20020905": ["all_bands.tif"]}, - "0000007": {"LE07_172034_20020906": ["all_bands.tif"]}, - "0000008": {"LE07_172034_20020907": ["all_bands.tif"]}, - "0000009": {"LE07_172034_20020908": ["all_bands.tif"]}, - "0000010": {"LE07_172034_20020909": ["all_bands.tif"]}, + 'etm_sr': { + '0000001': {'LE07_172030_20010526': ['all_bands.tif']}, + '0000002': {'LE07_223084_20010413': ['all_bands.tif']}, + '0000003': {'LE07_172034_20020902': ['all_bands.tif']}, + '0000004': {'LE07_172034_20020903': ['all_bands.tif']}, + '0000005': {'LE07_172034_20020904': ['all_bands.tif']}, + '0000006': {'LE07_172034_20020905': ['all_bands.tif']}, + '0000007': {'LE07_172034_20020906': ['all_bands.tif']}, + '0000008': {'LE07_172034_20020907': ['all_bands.tif']}, + '0000009': {'LE07_172034_20020908': ['all_bands.tif']}, + '0000010': {'LE07_172034_20020909': ['all_bands.tif']}, }, - "etm_toa": { - "0000001": {"LE07_172030_20010526": 
["all_bands.tif"]}, - "0000002": {"LE07_223084_20010413": ["all_bands.tif"]}, - "0000003": {"LE07_172034_20020902": ["all_bands.tif"]}, - "0000004": {"LE07_172034_20020903": ["all_bands.tif"]}, - "0000005": {"LE07_172034_20020904": ["all_bands.tif"]}, - "0000006": {"LE07_172034_20020905": ["all_bands.tif"]}, - "0000007": {"LE07_172034_20020906": ["all_bands.tif"]}, - "0000008": {"LE07_172034_20020907": ["all_bands.tif"]}, - "0000009": {"LE07_172034_20020908": ["all_bands.tif"]}, - "0000010": {"LE07_172034_20020909": ["all_bands.tif"]}, + 'etm_toa': { + '0000001': {'LE07_172030_20010526': ['all_bands.tif']}, + '0000002': {'LE07_223084_20010413': ['all_bands.tif']}, + '0000003': {'LE07_172034_20020902': ['all_bands.tif']}, + '0000004': {'LE07_172034_20020903': ['all_bands.tif']}, + '0000005': {'LE07_172034_20020904': ['all_bands.tif']}, + '0000006': {'LE07_172034_20020905': ['all_bands.tif']}, + '0000007': {'LE07_172034_20020906': ['all_bands.tif']}, + '0000008': {'LE07_172034_20020907': ['all_bands.tif']}, + '0000009': {'LE07_172034_20020908': ['all_bands.tif']}, + '0000010': {'LE07_172034_20020909': ['all_bands.tif']}, }, - "oli_tirs_toa": { - "0000001": {"LC08_172030_20010526": ["all_bands.tif"]}, - "0000002": {"LC08_223084_20010413": ["all_bands.tif"]}, - "0000003": {"LC08_172034_20020902": ["all_bands.tif"]}, - "0000004": {"LC08_172034_20020903": ["all_bands.tif"]}, - "0000005": {"LC08_172034_20020904": ["all_bands.tif"]}, - "0000006": {"LC08_172034_20020905": ["all_bands.tif"]}, - "0000007": {"LC08_172034_20020906": ["all_bands.tif"]}, - "0000008": {"LC08_172034_20020907": ["all_bands.tif"]}, - "0000009": {"LC08_172034_20020908": ["all_bands.tif"]}, - "0000010": {"LC08_172034_20020909": ["all_bands.tif"]}, + 'oli_tirs_toa': { + '0000001': {'LC08_172030_20010526': ['all_bands.tif']}, + '0000002': {'LC08_223084_20010413': ['all_bands.tif']}, + '0000003': {'LC08_172034_20020902': ['all_bands.tif']}, + '0000004': {'LC08_172034_20020903': ['all_bands.tif']}, + '0000005': {'LC08_172034_20020904': ['all_bands.tif']}, + '0000006': {'LC08_172034_20020905': ['all_bands.tif']}, + '0000007': {'LC08_172034_20020906': ['all_bands.tif']}, + '0000008': {'LC08_172034_20020907': ['all_bands.tif']}, + '0000009': {'LC08_172034_20020908': ['all_bands.tif']}, + '0000010': {'LC08_172034_20020909': ['all_bands.tif']}, }, - "oli_sr": { - "0000001": {"LC08_172030_20010526": ["all_bands.tif"]}, - "0000002": {"LC08_223084_20010413": ["all_bands.tif"]}, - "0000003": {"LC08_172034_20020902": ["all_bands.tif"]}, - "0000004": {"LC08_172034_20020903": ["all_bands.tif"]}, - "0000005": {"LC08_172034_20020904": ["all_bands.tif"]}, - "0000006": {"LC08_172034_20020905": ["all_bands.tif"]}, - "0000007": {"LC08_172034_20020906": ["all_bands.tif"]}, - "0000008": {"LC08_172034_20020907": ["all_bands.tif"]}, - "0000009": {"LC08_172034_20020908": ["all_bands.tif"]}, - "0000010": {"LC08_172034_20020909": ["all_bands.tif"]}, + 'oli_sr': { + '0000001': {'LC08_172030_20010526': ['all_bands.tif']}, + '0000002': {'LC08_223084_20010413': ['all_bands.tif']}, + '0000003': {'LC08_172034_20020902': ['all_bands.tif']}, + '0000004': {'LC08_172034_20020903': ['all_bands.tif']}, + '0000005': {'LC08_172034_20020904': ['all_bands.tif']}, + '0000006': {'LC08_172034_20020905': ['all_bands.tif']}, + '0000007': {'LC08_172034_20020906': ['all_bands.tif']}, + '0000008': {'LC08_172034_20020907': ['all_bands.tif']}, + '0000009': {'LC08_172034_20020908': ['all_bands.tif']}, + '0000010': {'LC08_172034_20020909': ['all_bands.tif']}, }, } -num_bands = 
{"tm_toa": 7, "etm_sr": 6, "etm_toa": 9, "oli_tirs_toa": 11, "oli_sr": 7} -years = {"tm": 2011, "etm": 2019, "oli": 2019} +num_bands = {'tm_toa': 7, 'etm_sr': 6, 'etm_toa': 9, 'oli_tirs_toa': 11, 'oli_sr': 7} +years = {'tm': 2011, 'etm': 2019, 'oli': 2019} def create_image(path: str) -> None: profile = { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "width": SIZE, - "height": SIZE, - "count": num_bands["_".join(path.split(os.sep)[1].split("_")[2:][:-1])], - "crs": CRS.from_epsg(4326), - "transform": Affine( + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'count': num_bands['_'.join(path.split(os.sep)[1].split('_')[2:][:-1])], + 'crs': CRS.from_epsg(4326), + 'transform': Affine( 0.00037672803497508636, 0.0, -109.07063613660262, @@ -102,29 +102,29 @@ def create_image(path: str) -> None: -0.0002554026278261721, 47.49838726154881, ), - "blockysize": 1, - "tiled": False, - "compress": "lzw", - "interleave": "pixel", + 'blockysize': 1, + 'tiled': False, + 'compress': 'lzw', + 'interleave': 'pixel', } Z = np.random.randint(low=0, high=255, size=(SIZE, SIZE)) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: for i in src.indexes: src.write(Z, i) def create_mask(path: str) -> None: profile = { - "driver": "GTiff", - "dtype": "uint8", - "nodata": None, - "width": SIZE, - "height": SIZE, - "count": 1, - "crs": CRS.from_epsg(4326), - "transform": Affine( + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(4326), + 'transform': Affine( 0.00037672803497508636, 0.0, -109.07063613660262, @@ -132,15 +132,15 @@ def create_mask(path: str) -> None: -0.0002554026278261721, 47.49838726154881, ), - "blockysize": 1, - "tiled": False, - "compress": "lzw", - "interleave": "band", + 'blockysize': 1, + 'tiled': False, + 'compress': 'lzw', + 'interleave': 'band', } Z = np.random.randint(low=0, high=10, size=(1, SIZE, SIZE)) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z) @@ -149,7 +149,7 @@ def create_img_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: # Recursive case for key, value in hierarchy.items(): if any([x in key for x in filenames.keys()]): - key = f"ssl4eo_l_{key}_benchmark" + key = f'ssl4eo_l_{key}_benchmark' path = os.path.join(directory, key) os.makedirs(path, exist_ok=True) create_img_directory(path, value) @@ -173,36 +173,36 @@ def create_mask_directory( # Base case for value in hierarchy: path = os.path.join(directory, value) - year = years[path.split(os.sep)[1].split("_")[2]] - create_mask(path.replace("all_bands", f"{mask_product}_{year}")) + year = years[path.split(os.sep)[1].split('_')[2]] + create_mask(path.replace('all_bands', f'{mask_product}_{year}')) def create_tarballs(directories) -> None: for directory in directories: # Create tarballs - shutil.make_archive(directory, "gztar", ".", directory) + shutil.make_archive(directory, 'gztar', '.', directory) # Compute checksums - with open(f"{directory}.tar.gz", "rb") as f: + with open(f'{directory}.tar.gz', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() print(directory, md5) -if __name__ == "__main__": +if __name__ == '__main__': # image directories - create_img_directory(".", filenames) + create_img_directory('.', filenames) directories = filenames.keys() - directories = [f"ssl4eo_l_{key}_benchmark" for key in directories] + directories = [f'ssl4eo_l_{key}_benchmark' 
for key in directories] create_tarballs(directories) # mask directory cdl - mask_keep = ["tm_toa", "etm_sr", "oli_sr"] + mask_keep = ['tm_toa', 'etm_sr', 'oli_sr'] mask_filenames = { f"ssl4eo_l_{key.split('_')[0]}_cdl": val for key, val in filenames.items() if key in mask_keep } - create_mask_directory(".", mask_filenames, "cdl") + create_mask_directory('.', mask_filenames, 'cdl') directories = mask_filenames.keys() create_tarballs(directories) @@ -212,6 +212,6 @@ def create_tarballs(directories) -> None: for key, val in filenames.items() if key in mask_keep } - create_mask_directory(".", mask_filenames, "nlcd") + create_mask_directory('.', mask_filenames, 'nlcd') directories = mask_filenames.keys() create_tarballs(directories) diff --git a/tests/data/sustainbench_crop_yield/data.py b/tests/data/sustainbench_crop_yield/data.py index 46b2e653ac6..27e0605bed8 100755 --- a/tests/data/sustainbench_crop_yield/data.py +++ b/tests/data/sustainbench_crop_yield/data.py @@ -15,27 +15,27 @@ np.random.seed(0) -countries = ["argentina", "brazil", "usa"] -splits = ["train", "dev", "test"] +countries = ['argentina', 'brazil', 'usa'] +splits = ['train', 'dev', 'test'] -root_dir = "soybeans" +root_dir = 'soybeans' def create_files(path: str, split: str) -> None: hist_img = np.random.random(size=(NUM_SAMPLES, SIZE, SIZE, NUM_BANDS)) - np.savez(os.path.join(path, f"{split}_hists.npz"), data=hist_img) + np.savez(os.path.join(path, f'{split}_hists.npz'), data=hist_img) target = np.random.random(size=(NUM_SAMPLES, 1)) - np.savez(os.path.join(path, f"{split}_yields.npz"), data=target) + np.savez(os.path.join(path, f'{split}_yields.npz'), data=target) ndvi = np.random.random(size=(NUM_SAMPLES, SIZE)) - np.savez(os.path.join(path, f"{split}_ndvi.npz"), data=ndvi) + np.savez(os.path.join(path, f'{split}_ndvi.npz'), data=ndvi) - year = np.array(["2009"] * NUM_SAMPLES, dtype=" None: for split in splits: create_files(dir, split) - filename = root_dir + ".zip" + filename = root_dir + '.zip' # Compress data - shutil.make_archive(filename.replace(".zip", ""), "zip", ".", root_dir) + shutil.make_archive(filename.replace('.zip', ''), 'zip', '.', root_dir) # Compute checksums - with open(filename, "rb") as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{filename}: {md5}") + print(f'{filename}: {md5}') diff --git a/tests/data/usavars/data.py b/tests/data/usavars/data.py index 17a4dca7c0c..f989e3c2a22 100755 --- a/tests/data/usavars/data.py +++ b/tests/data/usavars/data.py @@ -12,44 +12,44 @@ import pandas as pd import rasterio -data_dir = "uar" +data_dir = 'uar' labels = [ - "elevation", - "population", - "treecover", - "income", - "nightlights", - "housing", - "roads", + 'elevation', + 'population', + 'treecover', + 'income', + 'nightlights', + 'housing', + 'roads', ] -splits = ["train", "val", "test"] +splits = ['train', 'val', 'test'] SIZE = 3 def create_file(path: str, dtype: str, num_channels: int) -> None: profile = {} - profile["driver"] = "GTiff" - profile["dtype"] = dtype - profile["count"] = num_channels - profile["crs"] = "epsg:4326" - profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) - profile["height"] = SIZE - profile["width"] = SIZE - profile["compress"] = "lzw" - profile["predictor"] = 2 + profile['driver'] = 'GTiff' + profile['dtype'] = dtype + profile['count'] = num_channels + profile['crs'] = 'epsg:4326' + profile['transform'] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1) + profile['height'] = SIZE + profile['width'] = SIZE + 
profile['compress'] = 'lzw' + profile['predictor'] = 2 Z = np.random.randint( - np.iinfo(profile["dtype"]).max, size=(4, SIZE, SIZE), dtype=profile["dtype"] + np.iinfo(profile['dtype']).max, size=(4, SIZE, SIZE), dtype=profile['dtype'] ) - with rasterio.open(path, "w", **profile) as src: + with rasterio.open(path, 'w', **profile) as src: src.write(Z) # Remove old data -filename = f"{data_dir}.zip" -csvs = glob.glob("*.csv") -txts = glob.glob("*.txt") +filename = f'{data_dir}.zip' +csvs = glob.glob('*.csv') +txts = glob.glob('*.txt') for csv in csvs: os.remove(csv) @@ -62,32 +62,32 @@ def create_file(path: str, dtype: str, num_channels: int) -> None: # Create tifs: os.makedirs(data_dir) -create_file(os.path.join(data_dir, "tile_0,0.tif"), np.uint8, 4) -create_file(os.path.join(data_dir, "tile_0,1.tif"), np.uint8, 4) +create_file(os.path.join(data_dir, 'tile_0,0.tif'), np.uint8, 4) +create_file(os.path.join(data_dir, 'tile_0,1.tif'), np.uint8, 4) # Create labels: -columns = [["ID", "lon", "lat", lab] for lab in labels] -fake_vals = [["0,0", 0.0, 0.0, 0.0], ["0,1", 0.1, 0.1, 1.0]] +columns = [['ID', 'lon', 'lat', lab] for lab in labels] +fake_vals = [['0,0', 0.0, 0.0, 0.0], ['0,1', 0.1, 0.1, 1.0]] for lab, cols in zip(labels, columns): df = pd.DataFrame(fake_vals, columns=cols) - df.to_csv(lab + ".csv") + df.to_csv(lab + '.csv') # Create splits: -with open("train_split.txt", "w") as f: - f.write("tile_0,0.tif" + "\n") - f.write("tile_0,0.tif" + "\n") - f.write("tile_0,0.tif" + "\n") -with open("val_split.txt", "w") as f: - f.write("tile_0,1.tif" + "\n") - f.write("tile_0,1.tif" + "\n") -with open("test_split.txt", "w") as f: - f.write("tile_0,0.tif" + "\n") +with open('train_split.txt', 'w') as f: + f.write('tile_0,0.tif' + '\n') + f.write('tile_0,0.tif' + '\n') + f.write('tile_0,0.tif' + '\n') +with open('val_split.txt', 'w') as f: + f.write('tile_0,1.tif' + '\n') + f.write('tile_0,1.tif' + '\n') +with open('test_split.txt', 'w') as f: + f.write('tile_0,0.tif' + '\n') # Compress data -shutil.make_archive(data_dir, "zip", ".", data_dir) +shutil.make_archive(data_dir, 'zip', '.', data_dir) # Compute checksums -filename = f"{data_dir}.zip" -with open(filename, "rb") as f: +filename = f'{data_dir}.zip' +with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(repr(filename) + ": " + repr(md5) + ",") + print(repr(filename) + ': ' + repr(md5) + ',') diff --git a/tests/data/vector/data.py b/tests/data/vector/data.py index abf63c3ee66..8f09021d135 100755 --- a/tests/data/vector/data.py +++ b/tests/data/vector/data.py @@ -20,35 +20,35 @@ # * outside the dataset bounding box geojson = { - "type": "FeatureCollection", - "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, - "features": [ + 'type': 'FeatureCollection', + 'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}}, + 'features': [ { - "type": "Feature", - "properties": {"label_id": 1}, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'type': 'Feature', + 'properties': {'label_id': 1}, + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] ], }, }, { - "type": "Feature", - "properties": {"label_id": 2}, - "geometry": { - "type": "Polygon", - "coordinates": [ + 'type': 'Feature', + 'properties': {'label_id': 2}, + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]] ], }, }, { - "type": "Feature", - "properties": {"label_id": 3}, - 
"geometry": { - "type": "Polygon", - "coordinates": [ + 'type': 'Feature', + 'properties': {'label_id': 3}, + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ [[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]] ], }, @@ -56,5 +56,5 @@ ], } -with open("vector_2024.geojson", "w") as f: +with open('vector_2024.geojson', 'w') as f: json.dump(geojson, f) diff --git a/tests/data/vhr10/data.py b/tests/data/vhr10/data.py index 63e4855a529..44e60966c3b 100755 --- a/tests/data/vhr10/data.py +++ b/tests/data/vhr10/data.py @@ -10,7 +10,7 @@ from PIL import Image from torchvision.datasets.utils import calculate_md5 -ANNOTATION_FILE = {"images": [], "annotations": []} +ANNOTATION_FILE = {'images': [], 'annotations': []} def write_data(path: str, img: np.ndarray) -> None: @@ -20,11 +20,11 @@ def write_data(path: str, img: np.ndarray) -> None: def generate_test_data(root: str, n_imgs: int = 3) -> str: - folder_path = os.path.join(root, "NWPU VHR-10 dataset") - pos_img_dir = os.path.join(folder_path, "positive image set") - neg_img_dir = os.path.join(folder_path, "negative image set") - ann_file = os.path.join(folder_path, "annotations.json") - ann_file2 = os.path.join(root, "annotations.json") + folder_path = os.path.join(root, 'NWPU VHR-10 dataset') + pos_img_dir = os.path.join(folder_path, 'positive image set') + neg_img_dir = os.path.join(folder_path, 'negative image set') + ann_file = os.path.join(folder_path, 'annotations.json') + ann_file2 = os.path.join(root, 'annotations.json') if not os.path.exists(pos_img_dir): os.makedirs(pos_img_dir) @@ -32,53 +32,53 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str: os.makedirs(neg_img_dir) for img_id in range(1, n_imgs + 1): - pos_img_name = os.path.join(pos_img_dir, f"00{img_id}.jpg") - neg_img_name = os.path.join(neg_img_dir, f"00{img_id}.jpg") + pos_img_name = os.path.join(pos_img_dir, f'00{img_id}.jpg') + neg_img_name = os.path.join(neg_img_dir, f'00{img_id}.jpg') - img = np.random.randint(255, size=(8, 8), dtype=np.dtype("uint8")) + img = np.random.randint(255, size=(8, 8), dtype=np.dtype('uint8')) write_data(pos_img_name, img) write_data(neg_img_name, img) img_name = os.path.basename(pos_img_name) - ANNOTATION_FILE["images"].append( - {"file_name": img_name, "height": 8, "width": 8, "id": img_id - 1} + ANNOTATION_FILE['images'].append( + {'file_name': img_name, 'height': 8, 'width': 8, 'id': img_id - 1} ) ann = 0 - for _, img in enumerate(ANNOTATION_FILE["images"]): + for _, img in enumerate(ANNOTATION_FILE['images']): annot = { - "id": ann, - "image_id": img["id"], - "category_id": 1, - "area": 4.0, - "bbox": [4, 4, 2, 2], - "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], - "iscrowd": 0, + 'id': ann, + 'image_id': img['id'], + 'category_id': 1, + 'area': 4.0, + 'bbox': [4, 4, 2, 2], + 'segmentation': [[1, 1, 2, 2, 3, 3, 4, 5, 5]], + 'iscrowd': 0, } - ANNOTATION_FILE["annotations"].append(annot) + ANNOTATION_FILE['annotations'].append(annot) ann += 1 - with open(ann_file, "w") as j: + with open(ann_file, 'w') as j: json.dump(ANNOTATION_FILE, j) - with open(ann_file2, "w") as j: + with open(ann_file2, 'w') as j: json.dump(ANNOTATION_FILE, j) # Create rar file subprocess.run( - ["rar", "a", "NWPU VHR-10 dataset.rar", "-m5", "NWPU VHR-10 dataset"], + ['rar', 'a', 'NWPU VHR-10 dataset.rar', '-m5', 'NWPU VHR-10 dataset'], capture_output=True, check=True, ) annotations_md5 = calculate_md5(ann_file) - archive_md5 = calculate_md5("NWPU VHR-10 dataset.rar") + archive_md5 = calculate_md5('NWPU VHR-10 dataset.rar') 
shutil.rmtree(folder_path) - return f"archive md5: {archive_md5}, annotation md5: {annotations_md5}" + return f'archive md5: {archive_md5}, annotation md5: {annotations_md5}' -if __name__ == "__main__": +if __name__ == '__main__': md5 = generate_test_data(os.getcwd(), 5) print(md5) diff --git a/tests/data/western_usa_live_fuel_moisture/data.py b/tests/data/western_usa_live_fuel_moisture/data.py index d64c445ed41..44fc8717b47 100755 --- a/tests/data/western_usa_live_fuel_moisture/data.py +++ b/tests/data/western_usa_live_fuel_moisture/data.py @@ -11,198 +11,198 @@ NUM_SAMPLES = 3 -data_dir = "su_sar_moisture_content" +data_dir = 'su_sar_moisture_content' LABELS = { - "type": "Feature", - "properties": { - "percent(t)": 132.6666667, - "site": "Blackstone", - "date": "6/30/15", - "slope(t)": 0.599961042, - "elevation(t)": 1522.0, - "canopy_height(t)": 0.0, - "forest_cover(t)": 130.0, - "silt(t)": 36.0, - "sand(t)": 38.0, - "clay(t)": 26.0, - "vv(t)": -12.80108143, - "vh(t)": -20.86413967, - "red(t)": 2007.5, - "green(t)": 1669.5, - "blue(t)": 1234.5, - "swir(t)": 3226.5, - "nir(t)": 2764.5, - "ndvi(t)": 0.158611467, - "ndwi(t)": -0.07713057, - "nirv(t)": 438.5596345, - "vv_red(t)": -0.006376628, - "vv_green(t)": -0.007667614, - "vv_blue(t)": -0.010369446, - "vv_swir(t)": -0.003967482, - "vv_nir(t)": -0.004630523, - "vv_ndvi(t)": -80.70716267, - "vv_ndwi(t)": 165.9663796, - "vv_nirv(t)": -0.029188919, - "vh_red(t)": -0.010393096, - "vh_green(t)": -0.012497238, - "vh_blue(t)": -0.016900883, - "vh_swir(t)": -0.006466493, - "vh_nir(t)": -0.007547166, - "vh_ndvi(t)": -131.5424422, - "vh_ndwi(t)": 270.5041557, - "vh_nirv(t)": -0.047574236, - "vh_vv(t)": -8.063058239, - "slope(t-1)": 0.599961042, - "elevation(t-1)": 1522.0, - "canopy_height(t-1)": 0.0, - "forest_cover(t-1)": 130.0, - "silt(t-1)": 36.0, - "sand(t-1)": 38.0, - "clay(t-1)": 26.0, - "vv(t-1)": -12.93716855, - "vh(t-1)": -20.92368901, - "red(t-1)": 1792.0, - "green(t-1)": 1490.0, - "blue(t-1)": 1102.5, - "swir(t-1)": 3047.0, - "nir(t-1)": 2574.0, - "ndvi(t-1)": 0.179116009, - "ndwi(t-1)": -0.084146807, - "nirv(t-1)": 461.0691997, - "vv_red(t-1)": -0.007219402, - "vv_green(t-1)": -0.008682663, - "vv_blue(t-1)": -0.011734393, - "vv_swir(t-1)": -0.004245871, - "vv_nir(t-1)": -0.005026095, - "vv_ndvi(t-1)": -72.22787422, - "vv_ndwi(t-1)": 153.7452097, - "vv_nirv(t-1)": -0.02805906, - "vh_red(t-1)": -0.011676166, - "vh_green(t-1)": -0.014042744, - "vh_blue(t-1)": -0.018978403, - "vh_swir(t-1)": -0.00686698, - "vh_nir(t-1)": -0.008128861, - "vh_ndvi(t-1)": -116.8164094, - "vh_ndwi(t-1)": 248.6569562, - "vh_nirv(t-1)": -0.0453808, - "vh_vv(t-1)": -7.986520458, - "slope(t-2)": 0.599961042, - "elevation(t-2)": 1522.0, - "canopy_height(t-2)": 0.0, - "forest_cover(t-2)": 130.0, - "silt(t-2)": 36.0, - "sand(t-2)": 38.0, - "clay(t-2)": 26.0, - "vv(t-2)": -13.07325567, - "vh(t-2)": -20.98323835, - "red(t-2)": 1721.5, - "green(t-2)": 1432.0, - "blue(t-2)": 1056.5, - "swir(t-2)": 2950.0, - "nir(t-2)": 2476.0, - "ndvi(t-2)": 0.179768568, - "ndwi(t-2)": -0.087357002, - "nirv(t-2)": 445.0984812, - "vv_red(t-2)": -0.007594107, - "vv_green(t-2)": -0.009129368, - "vv_blue(t-2)": -0.012374118, - "vv_swir(t-2)": -0.004431612, - "vv_nir(t-2)": -0.00527999, - "vv_ndvi(t-2)": -72.72270011, - "vv_ndwi(t-2)": 149.6532084, - "vv_nirv(t-2)": -0.029371603, - "vh_red(t-2)": -0.012188927, - "vh_green(t-2)": -0.014653099, - "vh_blue(t-2)": -0.019861087, - "vh_swir(t-2)": -0.007112962, - "vh_nir(t-2)": -0.008474652, - "vh_ndvi(t-2)": -116.7236217, - "vh_ndwi(t-2)": 
240.2009889, - "vh_nirv(t-2)": -0.047142912, - "vh_vv(t-2)": -7.909982677, - "slope(t-3)": 0.599961042, - "elevation(t-3)": 1522.0, - "canopy_height(t-3)": 0.0, - "forest_cover(t-3)": 130.0, - "silt(t-3)": 36.0, - "sand(t-3)": 38.0, - "clay(t-3)": 26.0, - "vv(t-3)": -12.35794964, - "vh(t-3)": -20.25746909, - "red(t-3)": 1367.333333, - "green(t-3)": 1151.0, - "blue(t-3)": 827.3333333, - "swir(t-3)": 2349.333333, - "nir(t-3)": 2051.0, - "ndvi(t-3)": 0.216978329, - "ndwi(t-3)": -0.050717071, - "nirv(t-3)": 413.3885932, - "vv_red(t-3)": -0.009037993, - "vv_green(t-3)": -0.010736707, - "vv_blue(t-3)": -0.014937087, - "vv_swir(t-3)": -0.005260194, - "vv_nir(t-3)": -0.006025329, - "vv_ndvi(t-3)": -56.95476465, - "vv_ndwi(t-3)": 243.6644995, - "vv_nirv(t-3)": -0.029894269, - "vh_red(t-3)": -0.014815311, - "vh_green(t-3)": -0.017599886, - "vh_blue(t-3)": -0.024485257, - "vh_swir(t-3)": -0.008622646, - "vh_nir(t-3)": -0.009876874, - "vh_ndvi(t-3)": -93.36171601, - "vh_ndwi(t-3)": 399.4211186, - "vh_nirv(t-3)": -0.049003454, - "vh_vv(t-3)": -7.899519455, + 'type': 'Feature', + 'properties': { + 'percent(t)': 132.6666667, + 'site': 'Blackstone', + 'date': '6/30/15', + 'slope(t)': 0.599961042, + 'elevation(t)': 1522.0, + 'canopy_height(t)': 0.0, + 'forest_cover(t)': 130.0, + 'silt(t)': 36.0, + 'sand(t)': 38.0, + 'clay(t)': 26.0, + 'vv(t)': -12.80108143, + 'vh(t)': -20.86413967, + 'red(t)': 2007.5, + 'green(t)': 1669.5, + 'blue(t)': 1234.5, + 'swir(t)': 3226.5, + 'nir(t)': 2764.5, + 'ndvi(t)': 0.158611467, + 'ndwi(t)': -0.07713057, + 'nirv(t)': 438.5596345, + 'vv_red(t)': -0.006376628, + 'vv_green(t)': -0.007667614, + 'vv_blue(t)': -0.010369446, + 'vv_swir(t)': -0.003967482, + 'vv_nir(t)': -0.004630523, + 'vv_ndvi(t)': -80.70716267, + 'vv_ndwi(t)': 165.9663796, + 'vv_nirv(t)': -0.029188919, + 'vh_red(t)': -0.010393096, + 'vh_green(t)': -0.012497238, + 'vh_blue(t)': -0.016900883, + 'vh_swir(t)': -0.006466493, + 'vh_nir(t)': -0.007547166, + 'vh_ndvi(t)': -131.5424422, + 'vh_ndwi(t)': 270.5041557, + 'vh_nirv(t)': -0.047574236, + 'vh_vv(t)': -8.063058239, + 'slope(t-1)': 0.599961042, + 'elevation(t-1)': 1522.0, + 'canopy_height(t-1)': 0.0, + 'forest_cover(t-1)': 130.0, + 'silt(t-1)': 36.0, + 'sand(t-1)': 38.0, + 'clay(t-1)': 26.0, + 'vv(t-1)': -12.93716855, + 'vh(t-1)': -20.92368901, + 'red(t-1)': 1792.0, + 'green(t-1)': 1490.0, + 'blue(t-1)': 1102.5, + 'swir(t-1)': 3047.0, + 'nir(t-1)': 2574.0, + 'ndvi(t-1)': 0.179116009, + 'ndwi(t-1)': -0.084146807, + 'nirv(t-1)': 461.0691997, + 'vv_red(t-1)': -0.007219402, + 'vv_green(t-1)': -0.008682663, + 'vv_blue(t-1)': -0.011734393, + 'vv_swir(t-1)': -0.004245871, + 'vv_nir(t-1)': -0.005026095, + 'vv_ndvi(t-1)': -72.22787422, + 'vv_ndwi(t-1)': 153.7452097, + 'vv_nirv(t-1)': -0.02805906, + 'vh_red(t-1)': -0.011676166, + 'vh_green(t-1)': -0.014042744, + 'vh_blue(t-1)': -0.018978403, + 'vh_swir(t-1)': -0.00686698, + 'vh_nir(t-1)': -0.008128861, + 'vh_ndvi(t-1)': -116.8164094, + 'vh_ndwi(t-1)': 248.6569562, + 'vh_nirv(t-1)': -0.0453808, + 'vh_vv(t-1)': -7.986520458, + 'slope(t-2)': 0.599961042, + 'elevation(t-2)': 1522.0, + 'canopy_height(t-2)': 0.0, + 'forest_cover(t-2)': 130.0, + 'silt(t-2)': 36.0, + 'sand(t-2)': 38.0, + 'clay(t-2)': 26.0, + 'vv(t-2)': -13.07325567, + 'vh(t-2)': -20.98323835, + 'red(t-2)': 1721.5, + 'green(t-2)': 1432.0, + 'blue(t-2)': 1056.5, + 'swir(t-2)': 2950.0, + 'nir(t-2)': 2476.0, + 'ndvi(t-2)': 0.179768568, + 'ndwi(t-2)': -0.087357002, + 'nirv(t-2)': 445.0984812, + 'vv_red(t-2)': -0.007594107, + 'vv_green(t-2)': -0.009129368, + 'vv_blue(t-2)': 
-0.012374118, + 'vv_swir(t-2)': -0.004431612, + 'vv_nir(t-2)': -0.00527999, + 'vv_ndvi(t-2)': -72.72270011, + 'vv_ndwi(t-2)': 149.6532084, + 'vv_nirv(t-2)': -0.029371603, + 'vh_red(t-2)': -0.012188927, + 'vh_green(t-2)': -0.014653099, + 'vh_blue(t-2)': -0.019861087, + 'vh_swir(t-2)': -0.007112962, + 'vh_nir(t-2)': -0.008474652, + 'vh_ndvi(t-2)': -116.7236217, + 'vh_ndwi(t-2)': 240.2009889, + 'vh_nirv(t-2)': -0.047142912, + 'vh_vv(t-2)': -7.909982677, + 'slope(t-3)': 0.599961042, + 'elevation(t-3)': 1522.0, + 'canopy_height(t-3)': 0.0, + 'forest_cover(t-3)': 130.0, + 'silt(t-3)': 36.0, + 'sand(t-3)': 38.0, + 'clay(t-3)': 26.0, + 'vv(t-3)': -12.35794964, + 'vh(t-3)': -20.25746909, + 'red(t-3)': 1367.333333, + 'green(t-3)': 1151.0, + 'blue(t-3)': 827.3333333, + 'swir(t-3)': 2349.333333, + 'nir(t-3)': 2051.0, + 'ndvi(t-3)': 0.216978329, + 'ndwi(t-3)': -0.050717071, + 'nirv(t-3)': 413.3885932, + 'vv_red(t-3)': -0.009037993, + 'vv_green(t-3)': -0.010736707, + 'vv_blue(t-3)': -0.014937087, + 'vv_swir(t-3)': -0.005260194, + 'vv_nir(t-3)': -0.006025329, + 'vv_ndvi(t-3)': -56.95476465, + 'vv_ndwi(t-3)': 243.6644995, + 'vv_nirv(t-3)': -0.029894269, + 'vh_red(t-3)': -0.014815311, + 'vh_green(t-3)': -0.017599886, + 'vh_blue(t-3)': -0.024485257, + 'vh_swir(t-3)': -0.008622646, + 'vh_nir(t-3)': -0.009876874, + 'vh_ndvi(t-3)': -93.36171601, + 'vh_ndwi(t-3)': 399.4211186, + 'vh_nirv(t-3)': -0.049003454, + 'vh_vv(t-3)': -7.899519455, }, - "geometry": {"type": "Point", "coordinates": [-115.8855556, 42.44111111]}, + 'geometry': {'type': 'Point', 'coordinates': [-115.8855556, 42.44111111]}, } STAC = { - "assets": { - "documentation": { - "href": "../_common/documentation.pdf", - "type": "application/pdf", + 'assets': { + 'documentation': { + 'href': '../_common/documentation.pdf', + 'type': 'application/pdf', }, - "labels": {"href": "labels.geojson", "type": "application/geo+json"}, - "training_features_descriptions": { - "href": "../_common/training_features_descriptions.csv", - "title": "Training Features Descriptions", - "type": "text/csv", + 'labels': {'href': 'labels.geojson', 'type': 'application/geo+json'}, + 'training_features_descriptions': { + 'href': '../_common/training_features_descriptions.csv', + 'title': 'Training Features Descriptions', + 'type': 'text/csv', }, }, - "bbox": [-115.8855556, 42.44111111, -115.8855556, 42.44111111], - "collection": "su_sar_moisture_content", - "geometry": {"coordinates": [-115.8855556, 42.44111111], "type": "Point"}, - "id": "su_sar_moisture_content_0001", - "links": [ - {"href": "../collection.json", "rel": "collection"}, - {"href": "../collection.json", "rel": "parent"}, + 'bbox': [-115.8855556, 42.44111111, -115.8855556, 42.44111111], + 'collection': 'su_sar_moisture_content', + 'geometry': {'coordinates': [-115.8855556, 42.44111111], 'type': 'Point'}, + 'id': 'su_sar_moisture_content_0001', + 'links': [ + {'href': '../collection.json', 'rel': 'collection'}, + {'href': '../collection.json', 'rel': 'parent'}, ], - "properties": { - "datetime": "2015-06-30T00:00:00Z", - "label:description": "", - "label:properties": ["percent(t)"], - "label:type": "vector", + 'properties': { + 'datetime': '2015-06-30T00:00:00Z', + 'label:description': '', + 'label:properties': ['percent(t)'], + 'label:type': 'vector', }, - "stac_extensions": ["label"], - "stac_version": "1.0.0-beta.2", - "type": "Feature", + 'stac_extensions': ['label'], + 'stac_version': '1.0.0-beta.2', + 'type': 'Feature', } def create_file(path: str) -> None: - label_path = os.path.join(path, "labels.geojson") 
- with open(label_path, "w") as f: + label_path = os.path.join(path, 'labels.geojson') + with open(label_path, 'w') as f: json.dump(LABELS, f) - stac_path = os.path.join(path, "stac.json") - with open(stac_path, "w") as f: + stac_path = os.path.join(path, 'stac.json') + with open(stac_path, 'w') as f: json.dump(STAC, f) -if __name__ == "__main__": +if __name__ == '__main__': # Remove old data if os.path.isdir(data_dir): shutil.rmtree(data_dir) @@ -210,14 +210,14 @@ def create_file(path: str) -> None: os.makedirs(os.path.join(os.getcwd(), data_dir)) for i in range(NUM_SAMPLES): - sample_dir = os.path.join(data_dir, data_dir + f"_{i}") + sample_dir = os.path.join(data_dir, data_dir + f'_{i}') os.makedirs(sample_dir) create_file(sample_dir) # Compress data - shutil.make_archive(data_dir, "gztar", ".", data_dir) + shutil.make_archive(data_dir, 'gztar', '.', data_dir) # Compute checksums - with open(data_dir + ".tar.gz", "rb") as f: + with open(data_dir + '.tar.gz', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f"{data_dir}.tar.gz: {md5}") + print(f'{data_dir}.tar.gz: {md5}') diff --git a/tests/data/zuericrop/data.py b/tests/data/zuericrop/data.py index afe518a48c3..efaa580a4cd 100755 --- a/tests/data/zuericrop/data.py +++ b/tests/data/zuericrop/data.py @@ -19,8 +19,8 @@ np.random.seed(0) -data_file = "ZueriCrop.hdf5" -labels_file = "labels.csv" +data_file = 'ZueriCrop.hdf5' +labels_file = 'labels.csv' # Remove old data if os.path.exists(data_file): @@ -41,7 +41,7 @@ NUM_CLASSES, size=(NUM_SAMPLES, SIZE, SIZE, 1), dtype=np.int32 ) -with h5py.File(data_file, "w") as f: - f.create_dataset("data", data=data) - f.create_dataset("gt", data=gt) - f.create_dataset("gt_instance", data=gt_instance) +with h5py.File(data_file, 'w') as f: + f.create_dataset('data', data=data) + f.create_dataset('gt', data=gt) + f.create_dataset('gt_instance', data=gt_instance) diff --git a/tests/datamodules/test_chesapeake.py b/tests/datamodules/test_chesapeake.py index ecf30775544..cd87af18244 100644 --- a/tests/datamodules/test_chesapeake.py +++ b/tests/datamodules/test_chesapeake.py @@ -10,12 +10,12 @@ class TestChesapeakeCVPRDataModule: def test_invalid_param_config(self) -> None: - with pytest.raises(ValueError, match="The pre-generated prior labels"): + with pytest.raises(ValueError, match='The pre-generated prior labels'): ChesapeakeCVPRDataModule( - root=os.path.join("tests", "data", "chesapeake", "cvpr"), - train_splits=["de-test"], - val_splits=["de-test"], - test_splits=["de-test"], + root=os.path.join('tests', 'data', 'chesapeake', 'cvpr'), + train_splits=['de-test'], + val_splits=['de-test'], + test_splits=['de-test'], batch_size=2, patch_size=32, length=4, diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index 1a577695d4e..4aa1e16d846 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -12,31 +12,31 @@ class TestFAIR1MDataModule: @pytest.fixture def datamodule(self) -> FAIR1MDataModule: - root = os.path.join("tests", "data", "fair1m") + root = os.path.join('tests', 'data', 'fair1m') batch_size = 2 num_workers = 0 dm = FAIR1MDataModule(root=root, batch_size=batch_size, num_workers=num_workers) return dm def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None: - datamodule.setup("fit") + datamodule.setup('fit') next(iter(datamodule.train_dataloader())) def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') 
next(iter(datamodule.val_dataloader())) def test_predict_dataloader(self, datamodule: FAIR1MDataModule) -> None: - datamodule.setup("predict") + datamodule.setup('predict') next(iter(datamodule.predict_dataloader())) def test_plot(self, datamodule: FAIR1MDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') batch = next(iter(datamodule.val_dataloader())) sample = { - "image": batch["image"][0], - "boxes": batch["boxes"][0], - "label": batch["label"][0], + 'image': batch['image'][0], + 'boxes': batch['boxes'][0], + 'label': batch['label'][0], } datamodule.plot(sample) plt.close() diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index ef856eddee3..8380ce242b8 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -23,7 +23,7 @@ class CustomGeoDataset(GeoDataset): def __init__( - self, split: str = "train", length: int = 1, download: bool = False + self, split: str = 'train', length: int = 1, download: bool = False ) -> None: super().__init__() for i in range(length): @@ -32,7 +32,7 @@ def __init__( def __getitem__(self, query: BoundingBox) -> dict[str, Any]: image = torch.arange(3 * 2 * 2).view(3, 2, 2) - return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query} + return {'image': image, 'crs': CRS.from_epsg(4326), 'bbox': query} def plot(self, *args: Any, **kwargs: Any) -> Figure: return plt.figure() @@ -63,12 +63,12 @@ def setup(self, stage: str) -> None: class CustomNonGeoDataset(NonGeoDataset): def __init__( - self, split: str = "train", length: int = 1, download: bool = False + self, split: str = 'train', length: int = 1, download: bool = False ) -> None: self.length = length def __getitem__(self, index: int) -> dict[str, Tensor]: - return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)} + return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)} def __len__(self) -> int: return self.length @@ -84,7 +84,7 @@ def __init__(self) -> None: def setup(self, stage: str) -> None: super().setup(stage) - if stage in ["predict"]: + if stage in ['predict']: self.predict_dataset = CustomNonGeoDataset() @@ -92,83 +92,83 @@ class TestGeoDataModule: @pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule]) def datamodule(self, request: SubRequest) -> CustomGeoDataModule: dm: CustomGeoDataModule = request.param() - dm.trainer = Trainer(accelerator="cpu", max_epochs=1) + dm.trainer = Trainer(accelerator='cpu', max_epochs=1) return dm - @pytest.mark.parametrize("stage", ["fit", "validate", "test"]) + @pytest.mark.parametrize('stage', ['fit', 'validate', 'test']) def test_setup(self, stage: str) -> None: dm = CustomGeoDataModule() dm.prepare_data() dm.setup(stage) def test_train(self, datamodule: CustomGeoDataModule) -> None: - datamodule.setup("fit") + datamodule.setup('fit') if datamodule.trainer: datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) - batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) + batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1) batch = datamodule.on_after_batch_transfer(batch, 0) def test_val(self, datamodule: CustomGeoDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') if datamodule.trainer: datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) - batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) + batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1) batch = 
datamodule.on_after_batch_transfer(batch, 0) def test_test(self, datamodule: CustomGeoDataModule) -> None: - datamodule.setup("test") + datamodule.setup('test') if datamodule.trainer: datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) - batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) + batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1) batch = datamodule.on_after_batch_transfer(batch, 0) def test_predict(self, datamodule: CustomGeoDataModule) -> None: - datamodule.setup("predict") + datamodule.setup('predict') if datamodule.trainer: datamodule.trainer.predicting = True batch = next(iter(datamodule.predict_dataloader())) - batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) + batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1) batch = datamodule.on_after_batch_transfer(batch, 0) def test_plot(self, datamodule: CustomGeoDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') datamodule.plot() plt.close() def test_no_datasets(self) -> None: dm = CustomGeoDataModule() - msg = r"CustomGeoDataModule\.setup must define one of " + msg = r'CustomGeoDataModule\.setup must define one of ' msg += r"\('{0}_dataset', 'dataset'\)\." - with pytest.raises(MisconfigurationException, match=msg.format("train")): + with pytest.raises(MisconfigurationException, match=msg.format('train')): dm.train_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("val")): + with pytest.raises(MisconfigurationException, match=msg.format('val')): dm.val_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("test")): + with pytest.raises(MisconfigurationException, match=msg.format('test')): dm.test_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("predict")): + with pytest.raises(MisconfigurationException, match=msg.format('predict')): dm.predict_dataloader() def test_no_samplers(self) -> None: dm = CustomGeoDataModule() dm.dataset = CustomGeoDataset() - msg = r"CustomGeoDataModule\.setup must define one of " + msg = r'CustomGeoDataModule\.setup must define one of ' msg += r"\('{0}_batch_sampler', '{0}_sampler', 'batch_sampler', 'sampler'\)\." - with pytest.raises(MisconfigurationException, match=msg.format("train")): + with pytest.raises(MisconfigurationException, match=msg.format('train')): dm.train_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("val")): + with pytest.raises(MisconfigurationException, match=msg.format('val')): dm.val_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("test")): + with pytest.raises(MisconfigurationException, match=msg.format('test')): dm.test_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("predict")): + with pytest.raises(MisconfigurationException, match=msg.format('predict')): dm.predict_dataloader() def test_zero_length_dataset(self) -> None: dm = CustomGeoDataModule() dm.dataset = CustomGeoDataset(length=0) - msg = r"CustomGeoDataModule\.dataset has length 0." + msg = r'CustomGeoDataModule\.dataset has length 0.' with pytest.raises(MisconfigurationException, match=msg): dm.train_dataloader() with pytest.raises(MisconfigurationException, match=msg): @@ -183,7 +183,7 @@ def test_zero_length_sampler(self) -> None: dm.dataset = CustomGeoDataset() dm.sampler = RandomGeoSampler(dm.dataset, 1, 1) dm.sampler.length = 0 - msg = r"CustomGeoDataModule\.sampler has length 0." 
+ msg = r'CustomGeoDataModule\.sampler has length 0.' with pytest.raises(MisconfigurationException, match=msg): dm.train_dataloader() with pytest.raises(MisconfigurationException, match=msg): @@ -198,65 +198,65 @@ class TestNonGeoDataModule: @pytest.fixture def datamodule(self) -> CustomNonGeoDataModule: dm = CustomNonGeoDataModule() - dm.trainer = Trainer(accelerator="cpu", max_epochs=1) + dm.trainer = Trainer(accelerator='cpu', max_epochs=1) return dm - @pytest.mark.parametrize("stage", ["fit", "validate", "test", "predict"]) + @pytest.mark.parametrize('stage', ['fit', 'validate', 'test', 'predict']) def test_setup(self, stage: str) -> None: dm = CustomNonGeoDataModule() dm.prepare_data() dm.setup(stage) def test_train(self, datamodule: CustomNonGeoDataModule) -> None: - datamodule.setup("fit") + datamodule.setup('fit') if datamodule.trainer: datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) def test_val(self, datamodule: CustomNonGeoDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') if datamodule.trainer: datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) def test_test(self, datamodule: CustomNonGeoDataModule) -> None: - datamodule.setup("test") + datamodule.setup('test') if datamodule.trainer: datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) def test_predict(self, datamodule: CustomNonGeoDataModule) -> None: - datamodule.setup("predict") + datamodule.setup('predict') if datamodule.trainer: datamodule.trainer.predicting = True batch = next(iter(datamodule.predict_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) def test_plot(self, datamodule: CustomNonGeoDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') datamodule.plot() plt.close() def test_no_datasets(self) -> None: dm = CustomNonGeoDataModule() - msg = r"CustomNonGeoDataModule\.setup must define one of " + msg = r'CustomNonGeoDataModule\.setup must define one of ' msg += r"\('{0}_dataset', 'dataset'\)\." - with pytest.raises(MisconfigurationException, match=msg.format("train")): + with pytest.raises(MisconfigurationException, match=msg.format('train')): dm.train_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("val")): + with pytest.raises(MisconfigurationException, match=msg.format('val')): dm.val_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("test")): + with pytest.raises(MisconfigurationException, match=msg.format('test')): dm.test_dataloader() - with pytest.raises(MisconfigurationException, match=msg.format("predict")): + with pytest.raises(MisconfigurationException, match=msg.format('predict')): dm.predict_dataloader() def test_zero_length_dataset(self) -> None: dm = CustomNonGeoDataModule() dm.dataset = CustomNonGeoDataset(length=0) - msg = r"CustomNonGeoDataModule\.dataset has length 0." + msg = r'CustomNonGeoDataModule\.dataset has length 0.' 
with pytest.raises(MisconfigurationException, match=msg): dm.train_dataloader() with pytest.raises(MisconfigurationException, match=msg): diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 69deb82eb56..4a026f62caf 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -22,20 +22,20 @@ def download_url(url: str, root: str, *args: str) -> None: def transforms(sample: dict[str, Tensor]) -> dict[str, Tensor]: - sample["image1"] = F.resize( - sample["image1"], + sample['image1'] = F.resize( + sample['image1'], size=[1024, 1024], antialias=True, interpolation=InterpolationMode.BILINEAR, ) - sample["image2"] = F.resize( - sample["image2"], + sample['image2'] = F.resize( + sample['image2'], size=[1024, 1024], antialias=True, interpolation=InterpolationMode.BILINEAR, ) - sample["mask"] = F.resize( - sample["mask"].unsqueeze(dim=0), + sample['mask'] = F.resize( + sample['mask'].unsqueeze(dim=0), size=[1024, 1024], interpolation=InterpolationMode.NEAREST, ) @@ -47,11 +47,11 @@ class TestLEVIRCDPlusDataModule: def datamodule( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> LEVIRCDPlusDataModule: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - md5 = "0ccca34310bfe7096dadfbf05b0d180f" - monkeypatch.setattr(LEVIRCDPlus, "md5", md5) - url = os.path.join("tests", "data", "levircd", "levircdplus", "LEVIR-CD+.zip") - monkeypatch.setattr(LEVIRCDPlus, "url", url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + md5 = '0ccca34310bfe7096dadfbf05b0d180f' + monkeypatch.setattr(LEVIRCDPlus, 'md5', md5) + url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip') + monkeypatch.setattr(LEVIRCDPlus, 'url', url) root = str(tmp_path) dm = LEVIRCDPlusDataModule( @@ -63,77 +63,77 @@ def datamodule( transforms=transforms, ) dm.prepare_data() - dm.trainer = Trainer(accelerator="cpu", max_epochs=1) + dm.trainer = Trainer(accelerator='cpu', max_epochs=1) return dm def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: - datamodule.setup("fit") + datamodule.setup('fit') if datamodule.trainer: datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (256, 256) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 8 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (256, 256) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 8 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') if datamodule.trainer: datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: assert ( - batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) + batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) ) - assert 
batch["image1"].shape[0] == batch["mask"].shape[0] == 1 + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 assert ( - batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) + batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) ) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: - datamodule.setup("test") + datamodule.setup('test') if datamodule.trainer: datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 class TestLEVIRCDDataModule: @pytest.fixture def datamodule(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LEVIRCDDataModule: - directory = os.path.join("tests", "data", "levircd", "levircd") + directory = os.path.join('tests', 'data', 'levircd', 'levircd') splits = { - "train": { - "url": os.path.join(directory, "train.zip"), - "filename": "train.zip", - "md5": "7c2e24b3072095519f1be7eb01fae4ff", + 'train': { + 'url': os.path.join(directory, 'train.zip'), + 'filename': 'train.zip', + 'md5': '7c2e24b3072095519f1be7eb01fae4ff', }, - "val": { - "url": os.path.join(directory, "val.zip"), - "filename": "val.zip", - "md5": "5c320223ba88b6fc8ff9d1feebc3b84e", + 'val': { + 'url': os.path.join(directory, 'val.zip'), + 'filename': 'val.zip', + 'md5': '5c320223ba88b6fc8ff9d1feebc3b84e', }, - "test": { - "url": os.path.join(directory, "test.zip"), - "filename": "test.zip", - "md5": "021db72d4486726d6a0702563a617b32", + 'test': { + 'url': os.path.join(directory, 'test.zip'), + 'filename': 'test.zip', + 'md5': '021db72d4486726d6a0702563a617b32', }, } - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - monkeypatch.setattr(LEVIRCD, "splits", splits) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + monkeypatch.setattr(LEVIRCD, 'splits', splits) root = str(tmp_path) dm = LEVIRCDDataModule( @@ -144,44 +144,44 @@ def datamodule(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LEVIRCDDataMod transforms=transforms, ) dm.prepare_data() - dm.trainer = Trainer(accelerator="cpu", max_epochs=1) + dm.trainer = Trainer(accelerator='cpu', max_epochs=1) return dm def test_train_dataloader(self, datamodule: LEVIRCDDataModule) -> None: - datamodule.setup("fit") + datamodule.setup('fit') if datamodule.trainer: datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == 
batch["mask"].shape[-2:] == (256, 256) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8 - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (256, 256) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 8 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (256, 256) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 8 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 def test_val_dataloader(self, datamodule: LEVIRCDDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') if datamodule.trainer: datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 def test_test_dataloader(self, datamodule: LEVIRCDDataModule) -> None: - datamodule.setup("test") + datamodule.setup('test') if datamodule.trainer: datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (1024, 1024) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (1024, 1024) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 0c009a8ec2f..e67bd6d5678 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -15,7 +15,7 @@ class TestOSCDDataModule: @pytest.fixture(params=[OSCD.all_bands, OSCD.rgb_bands]) def datamodule(self, request: SubRequest) -> OSCDDataModule: bands = request.param - root = os.path.join("tests", "data", "oscd") + root = os.path.join('tests', 'data', 'oscd') dm = OSCDDataModule( root=root, download=True, @@ -26,57 +26,57 @@ def datamodule(self, request: SubRequest) -> OSCDDataModule: num_workers=0, ) dm.prepare_data() - dm.trainer = Trainer(accelerator="cpu", max_epochs=1) + dm.trainer = Trainer(accelerator='cpu', max_epochs=1) return dm def test_train_dataloader(self, datamodule: OSCDDataModule) -> 
None: - datamodule.setup("fit") + datamodule.setup('fit') if datamodule.trainer: datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 if datamodule.bands == OSCD.all_bands: - assert batch["image1"].shape[1] == 13 - assert batch["image2"].shape[1] == 13 + assert batch['image1'].shape[1] == 13 + assert batch['image2'].shape[1] == 13 else: - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') if datamodule.trainer: datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 if datamodule.bands == OSCD.all_bands: - assert batch["image1"].shape[1] == 13 - assert batch["image2"].shape[1] == 13 + assert batch['image1'].shape[1] == 13 + assert batch['image2'].shape[1] == 13 else: - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[1] == 3 + assert batch['image2'].shape[1] == 3 def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: - datamodule.setup("test") + datamodule.setup('test') if datamodule.trainer: datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1 - assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1 + assert batch['image1'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) + assert batch['image1'].shape[0] == batch['mask'].shape[0] == 1 + assert batch['image2'].shape[-2:] == batch['mask'].shape[-2:] == (2, 2) + assert batch['image2'].shape[0] == batch['mask'].shape[0] == 1 if datamodule.bands == OSCD.all_bands: - assert batch["image1"].shape[1] == 13 - assert batch["image2"].shape[1] == 13 + assert batch['image1'].shape[1] == 13 + assert batch['image2'].shape[1] == 13 else: - assert batch["image1"].shape[1] == 3 - assert batch["image2"].shape[1] == 3 + assert batch['image1'].shape[1] == 3 + assert 
batch['image2'].shape[1] == 3 diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 41b1939efe7..004750c2840 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -14,8 +14,8 @@ class TestUSAVarsDataModule: @pytest.fixture def datamodule(self, request: SubRequest) -> USAVarsDataModule: - pytest.importorskip("pandas", minversion="1.1.3") - root = os.path.join("tests", "data", "usavars") + pytest.importorskip('pandas', minversion='1.1.3') + root = os.path.join('tests', 'data', 'usavars') batch_size = 1 num_workers = 0 @@ -26,25 +26,25 @@ def datamodule(self, request: SubRequest) -> USAVarsDataModule: return dm def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None: - datamodule.setup("fit") + datamodule.setup('fit') assert len(datamodule.train_dataloader()) == 3 batch = next(iter(datamodule.train_dataloader())) - assert batch["image"].shape[0] == datamodule.batch_size + assert batch['image'].shape[0] == datamodule.batch_size def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') assert len(datamodule.val_dataloader()) == 2 batch = next(iter(datamodule.val_dataloader())) - assert batch["image"].shape[0] == datamodule.batch_size + assert batch['image'].shape[0] == datamodule.batch_size def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: - datamodule.setup("test") + datamodule.setup('test') assert len(datamodule.test_dataloader()) == 1 batch = next(iter(datamodule.test_dataloader())) - assert batch["image"].shape[0] == datamodule.batch_size + assert batch['image'].shape[0] == datamodule.batch_size def test_plot(self, datamodule: USAVarsDataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] datamodule.plot(sample) diff --git a/tests/datamodules/test_utils.py b/tests/datamodules/test_utils.py index b96c6032c67..5b454c3f408 100644 --- a/tests/datamodules/test_utils.py +++ b/tests/datamodules/test_utils.py @@ -33,20 +33,20 @@ def test_group_shuffle_split() -> None: train_indices = [0, 2, 5, 6, 7, 8, 9, 10, 11, 13, 14] test_indices = [1, 3, 4, 12] np.random.seed(0) - alphabet = np.array(list("abc")) + alphabet = np.array(list('abc')) groups = np.random.randint(0, 3, size=(15)) groups = alphabet[groups] - with pytest.raises(ValueError, match="You must specify `train_size` *"): + with pytest.raises(ValueError, match='You must specify `train_size` *'): group_shuffle_split(groups, train_size=None, test_size=None) - with pytest.raises(ValueError, match="`train_size` and `test_size` must sum to 1."): + with pytest.raises(ValueError, match='`train_size` and `test_size` must sum to 1.'): group_shuffle_split(groups, train_size=0.2, test_size=1.0) with pytest.raises( ValueError, - match=re.escape("`train_size` and `test_size` must be in the range (0,1)."), + match=re.escape('`train_size` and `test_size` must be in the range (0,1).'), ): group_shuffle_split(groups, train_size=-0.2, test_size=1.2) - with pytest.raises(ValueError, match="3 groups were found, however the current *"): + with pytest.raises(ValueError, match='3 groups were found, however the current *'): group_shuffle_split(groups, train_size=None, test_size=0.999) test_cases = [(None, 0.2, 42), (0.8, None, 42)] diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 1a0e158b366..53a230b9627 100644 --- 
a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -13,7 +13,7 @@ class TestXView2DataModule: @pytest.fixture def datamodule(self) -> XView2DataModule: - root = os.path.join("tests", "data", "xview2") + root = os.path.join('tests', 'data', 'xview2') batch_size = 1 num_workers = 0 dm = XView2DataModule( @@ -23,19 +23,19 @@ def datamodule(self) -> XView2DataModule: return dm def test_train_dataloader(self, datamodule: XView2DataModule) -> None: - datamodule.setup("fit") + datamodule.setup('fit') next(iter(datamodule.train_dataloader())) def test_val_dataloader(self, datamodule: XView2DataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') next(iter(datamodule.val_dataloader())) def test_test_dataloader(self, datamodule: XView2DataModule) -> None: - datamodule.setup("test") + datamodule.setup('test') next(iter(datamodule.test_dataloader())) def test_plot(self, datamodule: XView2DataModule) -> None: - datamodule.setup("validate") + datamodule.setup('validate') batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] datamodule.plot(sample) diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index bcbaba2500f..efa3ac538f8 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -24,15 +24,15 @@ def download_url(url: str, root: str, *args: str) -> None: class TestADVANCE: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - data_dir = os.path.join("tests", "data", "advance") + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'advance') urls = [ - os.path.join(data_dir, "ADVANCE_vision.zip"), - os.path.join(data_dir, "ADVANCE_sound.zip"), + os.path.join(data_dir, 'ADVANCE_vision.zip'), + os.path.join(data_dir, 'ADVANCE_sound.zip'), ] - md5s = ["43acacecebecd17a82bc2c1e719fd7e4", "039b7baa47879a8a4e32b9dd8287f6ad"] - monkeypatch.setattr(ADVANCE, "urls", urls) - monkeypatch.setattr(ADVANCE, "md5s", md5s) + md5s = ['43acacecebecd17a82bc2c1e719fd7e4', '039b7baa47879a8a4e32b9dd8287f6ad'] + monkeypatch.setattr(ADVANCE, 'urls', urls) + monkeypatch.setattr(ADVANCE, 'md5s', md5s) root = str(tmp_path) transforms = nn.Identity() return ADVANCE(root, transforms, download=True, checksum=True) @@ -42,24 +42,24 @@ def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == "scipy.io": + if name == 'scipy.io': raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) def test_getitem(self, dataset: ADVANCE) -> None: - pytest.importorskip("scipy", minversion="1.6.2") + pytest.importorskip('scipy', minversion='1.6.2') x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["audio"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["image"].shape[0] == 3 - assert x["image"].ndim == 3 - assert x["audio"].shape[0] == 1 - assert x["audio"].ndim == 2 - assert x["label"].ndim == 0 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['audio'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['image'].shape[0] == 3 + assert x['image'].ndim == 3 + assert 
x['audio'].shape[0] == 1 + assert x['audio'].ndim == 2 + assert x['label'].ndim == 0 def test_len(self, dataset: ADVANCE) -> None: assert len(dataset) == 2 @@ -68,7 +68,7 @@ def test_already_downloaded(self, dataset: ADVANCE) -> None: ADVANCE(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ADVANCE(str(tmp_path)) def test_mock_missing_module( @@ -76,17 +76,17 @@ def test_mock_missing_module( ) -> None: with pytest.raises( ImportError, - match="scipy is not installed and is required to use this dataset", + match='scipy is not installed and is required to use this dataset', ): dataset[0] def test_plot(self, dataset: ADVANCE) -> None: - pytest.importorskip("scipy", minversion="1.6.2") + pytest.importorskip('scipy', minversion='1.6.2') x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_agb_live_woody_density.py b/tests/datasets/test_agb_live_woody_density.py index 3e0bbbc2dc7..fee1ae4dae8 100644 --- a/tests/datasets/test_agb_live_woody_density.py +++ b/tests/datasets/test_agb_live_woody_density.py @@ -32,15 +32,15 @@ def dataset( ) -> AbovegroundLiveWoodyBiomassDensity: transforms = nn.Identity() monkeypatch.setattr( - torchgeo.datasets.agb_live_woody_density, "download_url", download_url + torchgeo.datasets.agb_live_woody_density, 'download_url', download_url ) url = os.path.join( - "tests", - "data", - "agb_live_woody_density", - "Aboveground_Live_Woody_Biomass_Density.geojson", + 'tests', + 'data', + 'agb_live_woody_density', + 'Aboveground_Live_Woody_Biomass_Density.geojson', ) - monkeypatch.setattr(AbovegroundLiveWoodyBiomassDensity, "url", url) + monkeypatch.setattr(AbovegroundLiveWoodyBiomassDensity, 'url', url) root = str(tmp_path) return AbovegroundLiveWoodyBiomassDensity( @@ -50,11 +50,11 @@ def dataset( def test_getitem(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_no_dataset(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): AbovegroundLiveWoodyBiomassDensity(str(tmp_path)) def test_already_downloaded( @@ -73,12 +73,12 @@ def test_or(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: def test_plot(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_agrifieldnet.py b/tests/datasets/test_agrifieldnet.py index cdb539671f9..c3eb794f8e4 100644 --- a/tests/datasets/test_agrifieldnet.py +++ b/tests/datasets/test_agrifieldnet.py @@ -23,16 +23,16 @@ class 
TestAgriFieldNet: @pytest.fixture def dataset(self) -> AgriFieldNet: - path = os.path.join("tests", "data", "agrifieldnet") + path = os.path.join('tests', 'data', 'agrifieldnet') transforms = nn.Identity() return AgriFieldNet(paths=path, transforms=transforms) def test_getitem(self, dataset: AgriFieldNet) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: AgriFieldNet) -> None: ds = dataset & dataset @@ -46,32 +46,32 @@ def test_already_downloaded(self, dataset: AgriFieldNet) -> None: AgriFieldNet(paths=dataset.paths) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): AgriFieldNet(str(tmp_path)) def test_plot(self, dataset: AgriFieldNet) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: AgriFieldNet) -> None: x = dataset[dataset.bounds] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_invalid_query(self, dataset: AgriFieldNet) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_rgb_bands_absent_plot(self, dataset: AgriFieldNet) -> None: with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - ds = AgriFieldNet(dataset.paths, bands=["B01", "B02", "B05"]) + ds = AgriFieldNet(dataset.paths, bands=['B01', 'B02', 'B05']) x = ds[ds.bounds] - ds.plot(x, suptitle="Test") + ds.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_airphen.py b/tests/datasets/test_airphen.py index ff5ce9602d7..3c60fb090f7 100644 --- a/tests/datasets/test_airphen.py +++ b/tests/datasets/test_airphen.py @@ -23,8 +23,8 @@ class TestAirphen: @pytest.fixture def dataset(self) -> Airphen: - paths = os.path.join("tests", "data", "airphen") - bands = ["B1", "B3", "B4"] + paths = os.path.join('tests', 'data', 'airphen') + bands = ['B1', 'B3', 'B4'] transforms = nn.Identity() return Airphen(paths, bands=bands, transforms=transforms) @@ -34,8 +34,8 @@ def test_len(self, dataset: Airphen) -> None: def test_getitem(self, dataset: Airphen) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) def test_and(self, dataset: Airphen) -> None: ds = dataset & dataset @@ -47,25 +47,25 @@ def test_or(self, dataset: Airphen) -> None: def test_plot(self, dataset: Airphen) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Airphen(str(tmp_path)) def 
test_invalid_query(self, dataset: Airphen) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_plot_wrong_bands(self, dataset: Airphen) -> None: - bands = ("B1", "B2", "B3") + bands = ('B1', 'B2', 'B3') ds = Airphen(dataset.paths, bands=bands) x = dataset[dataset.bounds] with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): ds.plot(x) diff --git a/tests/datasets/test_astergdem.py b/tests/datasets/test_astergdem.py index dfd41e40409..a838e0e7daf 100644 --- a/tests/datasets/test_astergdem.py +++ b/tests/datasets/test_astergdem.py @@ -23,8 +23,8 @@ class TestAsterGDEM: @pytest.fixture def dataset(self, tmp_path: Path) -> AsterGDEM: - zipfile = os.path.join("tests", "data", "astergdem", "astergdem.zip") - shutil.unpack_archive(zipfile, tmp_path, "zip") + zipfile = os.path.join('tests', 'data', 'astergdem', 'astergdem.zip') + shutil.unpack_archive(zipfile, tmp_path, 'zip') root = str(tmp_path) transforms = nn.Identity() return AsterGDEM(root, transforms=transforms) @@ -32,14 +32,14 @@ def dataset(self, tmp_path: Path) -> AsterGDEM: def test_datasetmissing(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): AsterGDEM(str(tmp_path)) def test_getitem(self, dataset: AsterGDEM) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: AsterGDEM) -> None: ds = dataset & dataset @@ -52,19 +52,19 @@ def test_or(self, dataset: AsterGDEM) -> None: def test_plot(self, dataset: AsterGDEM) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: AsterGDEM) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_invalid_query(self, dataset: AsterGDEM) -> None: query = BoundingBox(100, 100, 100, 100, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index e5e65aa0d01..1e960527a84 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -22,7 +22,7 @@ class Collection: def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz") + glob_path = os.path.join('tests', 'data', 'ts_cashew_benin', '*.tar.gz') for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -36,13 +36,13 @@ class TestBeninSmallHolderCashews: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> BeninSmallHolderCashews: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) - source_md5 = 
"255efff0f03bc6322470949a09bc76db" - labels_md5 = "ed2195d93ca6822d48eb02bc3e81c127" - monkeypatch.setitem(BeninSmallHolderCashews.image_meta, "md5", source_md5) - monkeypatch.setitem(BeninSmallHolderCashews.target_meta, "md5", labels_md5) - monkeypatch.setattr(BeninSmallHolderCashews, "dates", ("2019_11_05",)) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) + source_md5 = '255efff0f03bc6322470949a09bc76db' + labels_md5 = 'ed2195d93ca6822d48eb02bc3e81c127' + monkeypatch.setitem(BeninSmallHolderCashews.image_meta, 'md5', source_md5) + monkeypatch.setitem(BeninSmallHolderCashews.target_meta, 'md5', labels_md5) + monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('2019_11_05',)) root = str(tmp_path) transforms = nn.Identity() bands = BeninSmallHolderCashews.all_bands @@ -52,7 +52,7 @@ def dataset( transforms=transforms, bands=bands, download=True, - api_key="", + api_key='', checksum=True, verbose=True, ) @@ -60,10 +60,10 @@ def dataset( def test_getitem(self, dataset: BeninSmallHolderCashews) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert isinstance(x["x"], torch.Tensor) - assert isinstance(x["y"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert isinstance(x['x'], torch.Tensor) + assert isinstance(x['y'], torch.Tensor) def test_len(self, dataset: BeninSmallHolderCashews) -> None: assert len(dataset) == 72 @@ -74,33 +74,33 @@ def test_add(self, dataset: BeninSmallHolderCashews) -> None: assert len(ds) == 144 def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None: - BeninSmallHolderCashews(root=dataset.root, download=True, api_key="") + BeninSmallHolderCashews(root=dataset.root, download=True, api_key='') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): BeninSmallHolderCashews(str(tmp_path)) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - BeninSmallHolderCashews(bands=["B01", "B02"]) # type: ignore[arg-type] + BeninSmallHolderCashews(bands=['B01', 'B02']) # type: ignore[arg-type] - with pytest.raises(ValueError, match="is an invalid band name."): - BeninSmallHolderCashews(bands=("foo", "bar")) + with pytest.raises(ValueError, match='is an invalid band name.'): + BeninSmallHolderCashews(bands=('foo', 'bar')) def test_plot(self, dataset: BeninSmallHolderCashews) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"].clone() + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() def test_failed_plot(self, dataset: BeninSmallHolderCashews) -> None: - single_band_dataset = BeninSmallHolderCashews(root=dataset.root, bands=("B01",)) + single_band_dataset = BeninSmallHolderCashews(root=dataset.root, bands=('B01',)) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): x = single_band_dataset[0].copy() - single_band_dataset.plot(x, suptitle="Test") + single_band_dataset.plot(x, suptitle='Test') diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 
a0e93952244..82a3655626f 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -22,46 +22,46 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestBigEarthNet: @pytest.fixture( - params=zip(["all", "s1", "s2"], [43, 19, 19], ["train", "val", "test"]) + params=zip(['all', 's1', 's2'], [43, 19, 19], ['train', 'val', 'test']) ) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> BigEarthNet: - monkeypatch.setattr(torchgeo.datasets.bigearthnet, "download_url", download_url) - data_dir = os.path.join("tests", "data", "bigearthnet") + monkeypatch.setattr(torchgeo.datasets.bigearthnet, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'bigearthnet') metadata = { - "s1": { - "url": os.path.join(data_dir, "BigEarthNet-S1-v1.0.tar.gz"), - "md5": "5a64e9ce38deb036a435a7b59494924c", - "filename": "BigEarthNet-S1-v1.0.tar.gz", - "directory": "BigEarthNet-S1-v1.0", + 's1': { + 'url': os.path.join(data_dir, 'BigEarthNet-S1-v1.0.tar.gz'), + 'md5': '5a64e9ce38deb036a435a7b59494924c', + 'filename': 'BigEarthNet-S1-v1.0.tar.gz', + 'directory': 'BigEarthNet-S1-v1.0', }, - "s2": { - "url": os.path.join(data_dir, "BigEarthNet-S2-v1.0.tar.gz"), - "md5": "ef5f41129b8308ca178b04d7538dbacf", - "filename": "BigEarthNet-S2-v1.0.tar.gz", - "directory": "BigEarthNet-v1.0", + 's2': { + 'url': os.path.join(data_dir, 'BigEarthNet-S2-v1.0.tar.gz'), + 'md5': 'ef5f41129b8308ca178b04d7538dbacf', + 'filename': 'BigEarthNet-S2-v1.0.tar.gz', + 'directory': 'BigEarthNet-v1.0', }, } splits_metadata = { - "train": { - "url": os.path.join(data_dir, "bigearthnet-train.csv"), - "filename": "bigearthnet-train.csv", - "md5": "167ac4d5de8dde7b5aeaa812f42031e7", + 'train': { + 'url': os.path.join(data_dir, 'bigearthnet-train.csv'), + 'filename': 'bigearthnet-train.csv', + 'md5': '167ac4d5de8dde7b5aeaa812f42031e7', }, - "val": { - "url": os.path.join(data_dir, "bigearthnet-val.csv"), - "filename": "bigearthnet-val.csv", - "md5": "aff594ba256a52e839a3b5fefeb9ef42", + 'val': { + 'url': os.path.join(data_dir, 'bigearthnet-val.csv'), + 'filename': 'bigearthnet-val.csv', + 'md5': 'aff594ba256a52e839a3b5fefeb9ef42', }, - "test": { - "url": os.path.join(data_dir, "bigearthnet-test.csv"), - "filename": "bigearthnet-test.csv", - "md5": "851a6bdda484d47f60e121352dcb1bf5", + 'test': { + 'url': os.path.join(data_dir, 'bigearthnet-test.csv'), + 'filename': 'bigearthnet-test.csv', + 'md5': '851a6bdda484d47f60e121352dcb1bf5', }, } - monkeypatch.setattr(BigEarthNet, "metadata", metadata) - monkeypatch.setattr(BigEarthNet, "splits_metadata", splits_metadata) + monkeypatch.setattr(BigEarthNet, 'metadata', metadata) + monkeypatch.setattr(BigEarthNet, 'splits_metadata', splits_metadata) bands, num_classes, split = request.param root = str(tmp_path) transforms = nn.Identity() @@ -72,23 +72,23 @@ def dataset( def test_getitem(self, dataset: BigEarthNet) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["label"].shape == (dataset.num_classes,) - assert x["image"].dtype == torch.float32 - assert x["label"].dtype == torch.int64 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['label'].shape == (dataset.num_classes,) + assert x['image'].dtype == torch.float32 + assert x['label'].dtype == torch.int64 - if dataset.bands == "all": - assert x["image"].shape == (14, 120, 120) - elif 
dataset.bands == "s1": - assert x["image"].shape == (2, 120, 120) + if dataset.bands == 'all': + assert x['image'].shape == (14, 120, 120) + elif dataset.bands == 's1': + assert x['image'].shape == (2, 120, 120) else: - assert x["image"].shape == (12, 120, 120) + assert x['image'].shape == (12, 120, 120) def test_len(self, dataset: BigEarthNet) -> None: - if dataset.split == "train": + if dataset.split == 'train': assert len(dataset) == 2 - elif dataset.split == "val": + elif dataset.split == 'val': assert len(dataset) == 1 else: assert len(dataset) == 1 @@ -105,25 +105,25 @@ def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None: def test_already_downloaded_not_extracted( self, dataset: BigEarthNet, tmp_path: Path ) -> None: - if dataset.bands == "all": + if dataset.bands == 'all': shutil.rmtree( - os.path.join(dataset.root, dataset.metadata["s1"]["directory"]) + os.path.join(dataset.root, dataset.metadata['s1']['directory']) ) shutil.rmtree( - os.path.join(dataset.root, dataset.metadata["s2"]["directory"]) + os.path.join(dataset.root, dataset.metadata['s2']['directory']) ) - download_url(dataset.metadata["s1"]["url"], root=str(tmp_path)) - download_url(dataset.metadata["s2"]["url"], root=str(tmp_path)) - elif dataset.bands == "s1": + download_url(dataset.metadata['s1']['url'], root=str(tmp_path)) + download_url(dataset.metadata['s2']['url'], root=str(tmp_path)) + elif dataset.bands == 's1': shutil.rmtree( - os.path.join(dataset.root, dataset.metadata["s1"]["directory"]) + os.path.join(dataset.root, dataset.metadata['s1']['directory']) ) - download_url(dataset.metadata["s1"]["url"], root=str(tmp_path)) + download_url(dataset.metadata['s1']['url'], root=str(tmp_path)) else: shutil.rmtree( - os.path.join(dataset.root, dataset.metadata["s2"]["directory"]) + os.path.join(dataset.root, dataset.metadata['s2']['directory']) ) - download_url(dataset.metadata["s2"]["url"], root=str(tmp_path)) + download_url(dataset.metadata['s2']['url'], root=str(tmp_path)) BigEarthNet( root=str(tmp_path), @@ -134,15 +134,15 @@ def test_already_downloaded_not_extracted( ) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): BigEarthNet(str(tmp_path)) def test_plot(self, dataset: BigEarthNet) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_biomassters.py b/tests/datasets/test_biomassters.py index 17dab2df03c..8a853145da8 100644 --- a/tests/datasets/test_biomassters.py +++ b/tests/datasets/test_biomassters.py @@ -15,10 +15,10 @@ class TestBioMassters: @pytest.fixture( - params=product(["train", "test"], [["S1"], ["S2"], ["S1", "S2"]], [True, False]) + params=product(['train', 'test'], [['S1'], ['S2'], ['S1', 'S2']], [True, False]) ) def dataset(self, request: SubRequest) -> BioMassters: - root = os.path.join("tests", "data", "biomassters") + root = os.path.join('tests', 'data', 'biomassters') split, sensors, as_time_series = request.param return BioMassters( root, split=split, sensors=sensors, as_time_series=as_time_series @@ -29,22 +29,22 @@ def test_len_of_ds(self, dataset: BioMassters) -> None: def test_invalid_split(self, dataset: BioMassters) -> None: with pytest.raises(AssertionError): - 
BioMassters(dataset.root, split="foo") + BioMassters(dataset.root, split='foo') def test_invalid_bands(self, dataset: BioMassters) -> None: with pytest.raises(AssertionError): - BioMassters(dataset.root, sensors=["S3"]) + BioMassters(dataset.root, sensors=['S3']) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): BioMassters(str(tmp_path)) def test_plot(self, dataset: BioMassters) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - if dataset.split == "train": - sample["prediction"] = sample["label"] + if dataset.split == 'train': + sample['prediction'] = sample['label'] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py index f53023925b5..de3dea399dd 100644 --- a/tests/datasets/test_cbf.py +++ b/tests/datasets/test_cbf.py @@ -31,16 +31,16 @@ class TestCanadianBuildingFootprints: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> CanadianBuildingFootprints: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) monkeypatch.setattr( - CanadianBuildingFootprints, "provinces_territories", ["Alberta"] + CanadianBuildingFootprints, 'provinces_territories', ['Alberta'] ) monkeypatch.setattr( - CanadianBuildingFootprints, "md5s", ["25091d1f051baa30d8f2026545cfb696"] + CanadianBuildingFootprints, 'md5s', ['25091d1f051baa30d8f2026545cfb696'] ) - url = os.path.join("tests", "data", "cbf") + os.sep - monkeypatch.setattr(CanadianBuildingFootprints, "url", url) - monkeypatch.setattr(plt, "show", lambda *args: None) + url = os.path.join('tests', 'data', 'cbf') + os.sep + monkeypatch.setattr(CanadianBuildingFootprints, 'url', url) + monkeypatch.setattr(plt, 'show', lambda *args: None) root = str(tmp_path) transforms = nn.Identity() return CanadianBuildingFootprints( @@ -50,8 +50,8 @@ def dataset( def test_getitem(self, dataset: CanadianBuildingFootprints) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: CanadianBuildingFootprints) -> None: ds = dataset & dataset @@ -67,21 +67,21 @@ def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None: def test_plot(self, dataset: CanadianBuildingFootprints) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') def test_plot_prediction(self, dataset: CanadianBuildingFootprints) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CanadianBuildingFootprints(str(tmp_path)) def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None: query = BoundingBox(2, 2, 2, 2, 2, 2) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff 
--git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py index 2cf989d7975..6f9758c2569 100644 --- a/tests/datasets/test_cdl.py +++ b/tests/datasets/test_cdl.py @@ -31,16 +31,16 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestCDL: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL: - monkeypatch.setattr(torchgeo.datasets.cdl, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.cdl, 'download_url', download_url) md5s = { - 2023: "3fbd3eecf92b8ce1ae35060ada463c6d", - 2022: "826c6fd639d9cdd94a44302fbc5b76c3", + 2023: '3fbd3eecf92b8ce1ae35060ada463c6d', + 2022: '826c6fd639d9cdd94a44302fbc5b76c3', } - monkeypatch.setattr(CDL, "md5s", md5s) - url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip") - monkeypatch.setattr(CDL, "url", url) - monkeypatch.setattr(plt, "show", lambda *args: None) + monkeypatch.setattr(CDL, 'md5s', md5s) + url = os.path.join('tests', 'data', 'cdl', '{}_30m_cdls.zip') + monkeypatch.setattr(CDL, 'url', url) + monkeypatch.setattr(plt, 'show', lambda *args: None) root = str(tmp_path) transforms = nn.Identity() return CDL( @@ -54,15 +54,15 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL: def test_getitem(self, dataset: CDL) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_classes(self) -> None: - root = os.path.join("tests", "data", "cdl") + root = os.path.join('tests', 'data', 'cdl') classes = list(CDL.cmap.keys())[:5] ds = CDL(root, years=[2023], classes=classes) sample = ds[ds.bounds] - mask = sample["mask"] + mask = sample['mask'] assert mask.max() < len(classes) def test_and(self, dataset: CDL) -> None: @@ -83,7 +83,7 @@ def test_already_extracted(self, dataset: CDL) -> None: CDL(dataset.paths, years=[2023, 2022]) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip") + pathname = os.path.join('tests', 'data', 'cdl', '*_30m_cdls.zip') root = str(tmp_path) for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) @@ -92,7 +92,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_invalid_year(self, tmp_path: Path) -> None: with pytest.raises( AssertionError, - match="CDL data product only exists for the following years:", + match='CDL data product only exists for the following years:', ): CDL(str(tmp_path), years=[1996]) @@ -106,23 +106,23 @@ def test_invalid_classes(self) -> None: def test_plot(self, dataset: CDL) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: CDL) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CDL(str(tmp_path)) def test_invalid_query(self, dataset: CDL) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff 
--git a/tests/datasets/test_chabud.py b/tests/datasets/test_chabud.py index d7e31b05d9e..955104a53fb 100644 --- a/tests/datasets/test_chabud.py +++ b/tests/datasets/test_chabud.py @@ -17,7 +17,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import ChaBuD, DatasetNotFoundError -pytest.importorskip("h5py", minversion="3") +pytest.importorskip('h5py', minversion='3') def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None: @@ -25,16 +25,16 @@ def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) class TestChaBuD: - @pytest.fixture(params=zip([ChaBuD.all_bands, ChaBuD.rgb_bands], ["train", "val"])) + @pytest.fixture(params=zip([ChaBuD.all_bands, ChaBuD.rgb_bands], ['train', 'val'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> ChaBuD: - monkeypatch.setattr(torchgeo.datasets.chabud, "download_url", download_url) - data_dir = os.path.join("tests", "data", "chabud") - url = os.path.join(data_dir, "train_eval.hdf5") - md5 = "1bec048beeb87a865c53f40ab418aa75" - monkeypatch.setattr(ChaBuD, "url", url) - monkeypatch.setattr(ChaBuD, "md5", md5) + monkeypatch.setattr(torchgeo.datasets.chabud, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'chabud') + url = os.path.join(data_dir, 'train_eval.hdf5') + md5 = '1bec048beeb87a865c53f40ab418aa75' + monkeypatch.setattr(ChaBuD, 'url', url) + monkeypatch.setattr(ChaBuD, 'md5', md5) bands, split = request.param root = str(tmp_path) transforms = nn.Identity() @@ -52,28 +52,28 @@ def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == "h5py": + if name == 'h5py': raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) def test_getitem(self, dataset: ChaBuD) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) # Image tests - assert x["image"].ndim == 3 + assert x['image'].ndim == 3 if dataset.bands == ChaBuD.rgb_bands: - assert x["image"].shape[0] == 2 * 3 + assert x['image'].shape[0] == 2 * 3 elif dataset.bands == ChaBuD.all_bands: - assert x["image"].shape[0] == 2 * 12 + assert x['image'].shape[0] == 2 * 12 # Mask tests: - assert x["mask"].ndim == 2 + assert x['mask'].ndim == 2 def test_len(self, dataset: ChaBuD) -> None: assert len(dataset) == 4 @@ -82,7 +82,7 @@ def test_already_downloaded(self, dataset: ChaBuD) -> None: ChaBuD(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ChaBuD(str(tmp_path)) def test_mock_missing_module( @@ -90,24 +90,24 @@ def test_mock_missing_module( ) -> None: with pytest.raises( ImportError, - match="h5py is not installed and is required to use this dataset", + match='h5py is not installed and is required to use this dataset', ): ChaBuD(dataset.root, download=True, checksum=True) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - ChaBuD(bands=["OK", "BK"]) + ChaBuD(bands=['OK', 'BK']) def test_plot(self, dataset: ChaBuD) -> None: - dataset.plot(dataset[0], suptitle="Test") + 
dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["mask"].clone() - dataset.plot(sample, suptitle="prediction") + sample['prediction'] = sample['mask'].clone() + dataset.plot(sample, suptitle='prediction') plt.close() def test_plot_rgb(self, dataset: ChaBuD) -> None: - dataset = ChaBuD(root=dataset.root, bands=["B02"]) + dataset = ChaBuD(root=dataset.root, bands=['B02']) with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"): - dataset.plot(dataset[0], suptitle="Single Band") + dataset.plot(dataset[0], suptitle='Single Band') diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 0692e0be00f..459c6bcdbdf 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -29,18 +29,18 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestChesapeake13: - pytest.importorskip("zipfile_deflate64") + pytest.importorskip('zipfile_deflate64') @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13: - monkeypatch.setattr(torchgeo.datasets.chesapeake, "download_url", download_url) - md5 = "fe35a615b8e749b21270472aa98bb42c" - monkeypatch.setattr(Chesapeake13, "md5", md5) + monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url) + md5 = 'fe35a615b8e749b21270472aa98bb42c' + monkeypatch.setattr(Chesapeake13, 'md5', md5) url = os.path.join( - "tests", "data", "chesapeake", "BAYWIDE", "Baywide_13Class_20132014.zip" + 'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip' ) - monkeypatch.setattr(Chesapeake13, "url", url) - monkeypatch.setattr(plt, "show", lambda *args: None) + monkeypatch.setattr(Chesapeake13, 'url', url) + monkeypatch.setattr(plt, 'show', lambda *args: None) root = str(tmp_path) transforms = nn.Identity() return Chesapeake13(root, transforms=transforms, download=True, checksum=True) @@ -48,8 +48,8 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13: def test_getitem(self, dataset: Chesapeake13) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: Chesapeake13) -> None: ds = dataset & dataset @@ -64,37 +64,37 @@ def test_already_extracted(self, dataset: Chesapeake13) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: url = os.path.join( - "tests", "data", "chesapeake", "BAYWIDE", "Baywide_13Class_20132014.zip" + 'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip' ) root = str(tmp_path) shutil.copy(url, root) Chesapeake13(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Chesapeake13(str(tmp_path), checksum=True) def test_plot(self, dataset: Chesapeake13) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: Chesapeake13) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_url(self) -> None: - ds = Chesapeake13(os.path.join("tests", "data", 
"chesapeake", "BAYWIDE")) - assert "cicwebresources.blob.core.windows.net" in ds.url + ds = Chesapeake13(os.path.join('tests', 'data', 'chesapeake', 'BAYWIDE')) + assert 'cicwebresources.blob.core.windows.net' in ds.url def test_invalid_query(self, dataset: Chesapeake13) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] @@ -102,54 +102,54 @@ def test_invalid_query(self, dataset: Chesapeake13) -> None: class TestChesapeakeCVPR: @pytest.fixture( params=[ - ("naip-new", "naip-old", "nlcd"), - ("landsat-leaf-on", "landsat-leaf-off", "lc"), - ("naip-new", "landsat-leaf-on", "lc", "nlcd", "buildings"), - ("naip-new", "prior_from_cooccurrences_101_31_no_osm_no_buildings"), + ('naip-new', 'naip-old', 'nlcd'), + ('landsat-leaf-on', 'landsat-leaf-off', 'lc'), + ('naip-new', 'landsat-leaf-on', 'lc', 'nlcd', 'buildings'), + ('naip-new', 'prior_from_cooccurrences_101_31_no_osm_no_buildings'), ] ) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> ChesapeakeCVPR: - monkeypatch.setattr(torchgeo.datasets.chesapeake, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url) monkeypatch.setattr( ChesapeakeCVPR, - "md5s", + 'md5s', { - "base": "882d18b1f15ea4498bf54e674aecd5d4", - "prior_extension": "677446c486f3145787938b14ee3da13f", + 'base': '882d18b1f15ea4498bf54e674aecd5d4', + 'prior_extension': '677446c486f3145787938b14ee3da13f', }, ) monkeypatch.setattr( ChesapeakeCVPR, - "urls", + 'urls', { - "base": os.path.join( - "tests", - "data", - "chesapeake", - "cvpr", - "cvpr_chesapeake_landcover.zip", + 'base': os.path.join( + 'tests', + 'data', + 'chesapeake', + 'cvpr', + 'cvpr_chesapeake_landcover.zip', ), - "prior_extension": os.path.join( - "tests", - "data", - "chesapeake", - "cvpr", - "cvpr_chesapeake_landcover_prior_extension.zip", + 'prior_extension': os.path.join( + 'tests', + 'data', + 'chesapeake', + 'cvpr', + 'cvpr_chesapeake_landcover_prior_extension.zip', ), }, ) monkeypatch.setattr( ChesapeakeCVPR, - "_files", - ["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"], + '_files', + ['de_1m_2013_extended-debuffered-test_tiles', 'spatial_index.geojson'], ) root = str(tmp_path) transforms = nn.Identity() return ChesapeakeCVPR( root, - splits=["de-test"], + splits=['de-test'], layers=request.param, transforms=transforms, download=True, @@ -159,8 +159,8 @@ def dataset( def test_getitem(self, dataset: ChesapeakeCVPR) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: ChesapeakeCVPR) -> None: ds = dataset & dataset @@ -177,48 +177,48 @@ def test_already_downloaded(self, tmp_path: Path) -> None: root = str(tmp_path) shutil.copy( os.path.join( - "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip" + 'tests', 'data', 'chesapeake', 'cvpr', 'cvpr_chesapeake_landcover.zip' ), root, ) shutil.copy( os.path.join( - "tests", - "data", - "chesapeake", - "cvpr", - "cvpr_chesapeake_landcover_prior_extension.zip", + 'tests', + 'data', + 'chesapeake', + 'cvpr', + 'cvpr_chesapeake_landcover_prior_extension.zip', ), root, ) ChesapeakeCVPR(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with 
pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ChesapeakeCVPR(str(tmp_path), checksum=True) def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None: ds = ChesapeakeCVPR( - root=dataset.root, splits=["de-train", "de-test"], layers=dataset.layers + root=dataset.root, splits=['de-train', 'de-test'], layers=dataset.layers ) with pytest.raises( - IndexError, match="query: .* spans multiple tiles which is not valid" + IndexError, match='query: .* spans multiple tiles which is not valid' ): ds[dataset.bounds] def test_plot(self, dataset: ChesapeakeCVPR) -> None: x = dataset[dataset.bounds].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2) + x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2) dataset.plot(x) plt.close() diff --git a/tests/datasets/test_cloud_cover.py b/tests/datasets/test_cloud_cover.py index 98f7952d48e..7bc0e12afc7 100644 --- a/tests/datasets/test_cloud_cover.py +++ b/tests/datasets/test_cloud_cover.py @@ -22,7 +22,7 @@ class Collection: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( - "tests", "data", "ref_cloud_cover_detection_challenge_v1", "*.tar.gz" + 'tests', 'data', 'ref_cloud_cover_detection_challenge_v1', '*.tar.gz' ) for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -35,23 +35,23 @@ def fetch(dataset_id: str, **kwargs: str) -> Collection: class TestCloudCoverDetection: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetection: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) test_image_meta = { - "filename": "ref_cloud_cover_detection_challenge_v1_test_source.tar.gz", - "md5": "542e64a6e39b53c84c6462ec1b989e43", + 'filename': 'ref_cloud_cover_detection_challenge_v1_test_source.tar.gz', + 'md5': '542e64a6e39b53c84c6462ec1b989e43', } - monkeypatch.setitem(CloudCoverDetection.image_meta, "test", test_image_meta) + monkeypatch.setitem(CloudCoverDetection.image_meta, 'test', test_image_meta) test_target_meta = { - "filename": "ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz", - "md5": "e8d41de08744a9845e74fca1eee3d1d3", + 'filename': 'ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz', + 'md5': 'e8d41de08744a9845e74fca1eee3d1d3', } - monkeypatch.setitem(CloudCoverDetection.target_meta, "test", test_target_meta) + monkeypatch.setitem(CloudCoverDetection.target_meta, 'test', test_target_meta) root = str(tmp_path) - split = "test" + split = 'test' transforms = nn.Identity() return CloudCoverDetection( @@ -59,55 +59,55 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetecti transforms=transforms, split=split, download=True, - api_key="", + api_key='', checksum=True, ) def test_invalid_band(self, dataset: CloudCoverDetection) -> None: - invalid_bands = ["B09"] + invalid_bands = ['B09'] with 
pytest.raises(ValueError): CloudCoverDetection( root=dataset.root, - split="test", + split='test', download=False, - api_key="", + api_key='', bands=invalid_bands, ) def test_get_item(self, dataset: CloudCoverDetection) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_add(self, dataset: CloudCoverDetection) -> None: assert len(dataset) == 1 def test_already_downloaded(self, dataset: CloudCoverDetection) -> None: - CloudCoverDetection(root=dataset.root, split="test", download=True, api_key="") + CloudCoverDetection(root=dataset.root, split='test', download=True, api_key='') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CloudCoverDetection(str(tmp_path)) def test_plot(self, dataset: CloudCoverDetection) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["mask"].clone() - dataset.plot(sample, suptitle="Pred") + sample['prediction'] = sample['mask'].clone() + dataset.plot(sample, suptitle='Pred') plt.close() def test_plot_rgb(self, dataset: CloudCoverDetection) -> None: dataset = CloudCoverDetection( root=dataset.root, - split="test", - bands=list(["B08"]), + split='test', + bands=list(['B08']), download=True, - api_key="", + api_key='', ) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - dataset.plot(dataset[0], suptitle="Single Band") + dataset.plot(dataset[0], suptitle='Single Band') diff --git a/tests/datasets/test_cms_mangrove_canopy.py b/tests/datasets/test_cms_mangrove_canopy.py index 1ebd33a1095..3c3bf9d4b2f 100644 --- a/tests/datasets/test_cms_mangrove_canopy.py +++ b/tests/datasets/test_cms_mangrove_canopy.py @@ -29,15 +29,15 @@ class TestCMSGlobalMangroveCanopy: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> CMSGlobalMangroveCanopy: - zipfile = "CMS_Global_Map_Mangrove_Canopy_1665.zip" - monkeypatch.setattr(CMSGlobalMangroveCanopy, "zipfile", zipfile) + zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip' + monkeypatch.setattr(CMSGlobalMangroveCanopy, 'zipfile', zipfile) - md5 = "d6894fa6293cc9c0f3f95a810e842de5" - monkeypatch.setattr(CMSGlobalMangroveCanopy, "md5", md5) + md5 = 'd6894fa6293cc9c0f3f95a810e842de5' + monkeypatch.setattr(CMSGlobalMangroveCanopy, 'md5', md5) - root = os.path.join("tests", "data", "cms_mangrove_canopy") + root = os.path.join('tests', 'data', 'cms_mangrove_canopy') transforms = nn.Identity() - country = "Angola" + country = 'Angola' return CMSGlobalMangroveCanopy( root, country=country, transforms=transforms, checksum=True @@ -46,39 +46,39 @@ def dataset( def test_getitem(self, dataset: CMSGlobalMangroveCanopy) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_no_dataset(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CMSGlobalMangroveCanopy(str(tmp_path)) def 
test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( - "tests", - "data", - "cms_mangrove_canopy", - "CMS_Global_Map_Mangrove_Canopy_1665.zip", + 'tests', + 'data', + 'cms_mangrove_canopy', + 'CMS_Global_Map_Mangrove_Canopy_1665.zip', ) root = str(tmp_path) shutil.copy(pathname, root) - CMSGlobalMangroveCanopy(root, country="Angola") + CMSGlobalMangroveCanopy(root, country='Angola') def test_corrupted(self, tmp_path: Path) -> None: with open( - os.path.join(tmp_path, "CMS_Global_Map_Mangrove_Canopy_1665.zip"), "w" + os.path.join(tmp_path, 'CMS_Global_Map_Mangrove_Canopy_1665.zip'), 'w' ) as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): - CMSGlobalMangroveCanopy(str(tmp_path), country="Angola", checksum=True) + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): + CMSGlobalMangroveCanopy(str(tmp_path), country='Angola', checksum=True) def test_invalid_country(self) -> None: with pytest.raises(AssertionError): - CMSGlobalMangroveCanopy(country="fakeCountry") + CMSGlobalMangroveCanopy(country='fakeCountry') def test_invalid_measurement(self) -> None: with pytest.raises(AssertionError): - CMSGlobalMangroveCanopy(measurement="wrongMeasurement") + CMSGlobalMangroveCanopy(measurement='wrongMeasurement') def test_and(self, dataset: CMSGlobalMangroveCanopy) -> None: ds = dataset & dataset @@ -91,12 +91,12 @@ def test_or(self, dataset: CMSGlobalMangroveCanopy) -> None: def test_plot(self, dataset: CMSGlobalMangroveCanopy) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: CMSGlobalMangroveCanopy) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py index 19f448f5a27..f454569d5b7 100644 --- a/tests/datasets/test_cowc.py +++ b/tests/datasets/test_cowc.py @@ -28,24 +28,24 @@ def test_not_implemented(self) -> None: class TestCOWCCounting: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> COWC: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - base_url = os.path.join("tests", "data", "cowc_counting") + os.sep - monkeypatch.setattr(COWCCounting, "base_url", base_url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + base_url = os.path.join('tests', 'data', 'cowc_counting') + os.sep + monkeypatch.setattr(COWCCounting, 'base_url', base_url) md5s = [ - "7d0c6d1fb548d3ea3a182a56ce231f97", - "2e9a806b19b21f9d796c7393ad8f51ee", - "39453c0627effd908e773c5c1f8aecc9", - "67190b3e0ca8aa1fc93250aa5383a8f3", - "575aead6a0c92aba37d613895194da7c", - "e7c2279040d3ce31b9c925c45d0c61e2", - "f159e23d52bd0b5656fe296f427b98e1", - "0a4daed8c5f6c4e20faa6e38636e4346", + '7d0c6d1fb548d3ea3a182a56ce231f97', + '2e9a806b19b21f9d796c7393ad8f51ee', + '39453c0627effd908e773c5c1f8aecc9', + '67190b3e0ca8aa1fc93250aa5383a8f3', + '575aead6a0c92aba37d613895194da7c', + 'e7c2279040d3ce31b9c925c45d0c61e2', + 'f159e23d52bd0b5656fe296f427b98e1', + '0a4daed8c5f6c4e20faa6e38636e4346', ] - monkeypatch.setattr(COWCCounting, "md5s", md5s) + monkeypatch.setattr(COWCCounting, 'md5s', md5s) root = 
str(tmp_path) split = request.param transforms = nn.Identity() @@ -54,8 +54,8 @@ def dataset( def test_getitem(self, dataset: COWC) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) def test_len(self, dataset: COWC) -> None: assert len(dataset) in [6, 12] @@ -74,42 +74,42 @@ def test_out_of_bounds(self, dataset: COWC) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - COWCCounting(split="foo") + COWCCounting(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): COWCCounting(str(tmp_path)) def test_plot(self, dataset: COWCCounting) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() class TestCOWCDetection: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> COWC: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - base_url = os.path.join("tests", "data", "cowc_detection") + os.sep - monkeypatch.setattr(COWCDetection, "base_url", base_url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + base_url = os.path.join('tests', 'data', 'cowc_detection') + os.sep + monkeypatch.setattr(COWCDetection, 'base_url', base_url) md5s = [ - "6bbbdb36ee4922e879f66ed9234cb8ab", - "09e4af08c6e6553afe5098b328ce9749", - "12a2708ab7644766e43f5aae34aa7f2a", - "a896433398a0c58263c0d266cfc93bc4", - "911ed42c104db60f7a7d03a5b36bc1ab", - "4cdb4fefab6a2951591e7840c11a229d", - "dd315cfb48dfa7ddb8230c942682bc37", - "dccc2257e9c4a9dde2b4f84769804046", + '6bbbdb36ee4922e879f66ed9234cb8ab', + '09e4af08c6e6553afe5098b328ce9749', + '12a2708ab7644766e43f5aae34aa7f2a', + 'a896433398a0c58263c0d266cfc93bc4', + '911ed42c104db60f7a7d03a5b36bc1ab', + '4cdb4fefab6a2951591e7840c11a229d', + 'dd315cfb48dfa7ddb8230c942682bc37', + 'dccc2257e9c4a9dde2b4f84769804046', ] - monkeypatch.setattr(COWCDetection, "md5s", md5s) + monkeypatch.setattr(COWCDetection, 'md5s', md5s) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -118,8 +118,8 @@ def dataset( def test_getitem(self, dataset: COWC) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) def test_len(self, dataset: COWC) -> None: assert len(dataset) in [6, 12] @@ -138,18 +138,18 @@ def test_out_of_bounds(self, dataset: COWC) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - COWCDetection(split="foo") + COWCDetection(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): COWCDetection(str(tmp_path)) def test_plot(self, dataset: COWCDetection) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') 
plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py index 274b6b726a1..f478cdf53ad 100644 --- a/tests/datasets/test_cropharvest.py +++ b/tests/datasets/test_cropharvest.py @@ -16,7 +16,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import CropHarvest, DatasetNotFoundError -pytest.importorskip("h5py", minversion="3") +pytest.importorskip('h5py', minversion='3') def download_url(url: str, root: str, filename: str, md5: str) -> None: @@ -29,30 +29,30 @@ def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == "h5py": + if name == 'h5py': raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: - monkeypatch.setattr(torchgeo.datasets.cropharvest, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url) monkeypatch.setitem( - CropHarvest.file_dict["features"], "md5", "ef6f4f00c0b3b50ed8380b0044928572" + CropHarvest.file_dict['features'], 'md5', 'ef6f4f00c0b3b50ed8380b0044928572' ) monkeypatch.setitem( - CropHarvest.file_dict["labels"], "md5", "1d93b6bfcec7b6797b75acbd9d284b92" + CropHarvest.file_dict['labels'], 'md5', '1d93b6bfcec7b6797b75acbd9d284b92' ) monkeypatch.setitem( - CropHarvest.file_dict["features"], - "url", - os.path.join("tests", "data", "cropharvest", "features.tar.gz"), + CropHarvest.file_dict['features'], + 'url', + os.path.join('tests', 'data', 'cropharvest', 'features.tar.gz'), ) monkeypatch.setitem( - CropHarvest.file_dict["labels"], - "url", - os.path.join("tests", "data", "cropharvest", "labels.geojson"), + CropHarvest.file_dict['labels'], + 'url', + os.path.join('tests', 'data', 'cropharvest', 'labels.geojson'), ) root = str(tmp_path) @@ -64,11 +64,11 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: def test_getitem(self, dataset: CropHarvest) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["array"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["array"].shape == (12, 18) + assert isinstance(x['array'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['array'].shape == (12, 18) y = dataset[2] - assert y["label"] == 1 + assert y['label'] == 1 def test_len(self, dataset: CropHarvest) -> None: assert len(dataset) == 5 @@ -77,17 +77,17 @@ def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None: CropHarvest(root=str(tmp_path), download=False) def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None: - feature_path = os.path.join(tmp_path, "features") + feature_path = os.path.join(tmp_path, 'features') shutil.rmtree(feature_path) CropHarvest(root=str(tmp_path), download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CropHarvest(str(tmp_path)) def test_plot(self, dataset: CropHarvest) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') 
plt.close() def test_mock_missing_module( @@ -95,6 +95,6 @@ def test_mock_missing_module( ) -> None: with pytest.raises( ImportError, - match="h5py is not installed and is required to use this dataset", + match='h5py is not installed and is required to use this dataset', ): CropHarvest(root=str(tmp_path), download=True)[0] diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index c5886060814..ad0e26ed03d 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -23,7 +23,7 @@ class Collection: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( - "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz" + 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' ) for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -36,23 +36,23 @@ def fetch(dataset_id: str, **kwargs: str) -> Collection: class TestCV4AKenyaCropType: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) - source_md5 = "7f4dcb3f33743dddd73f453176308bfb" - labels_md5 = "95fc59f1d94a85ec00931d4d1280bec9" - monkeypatch.setitem(CV4AKenyaCropType.image_meta, "md5", source_md5) - monkeypatch.setitem(CV4AKenyaCropType.target_meta, "md5", labels_md5) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) + source_md5 = '7f4dcb3f33743dddd73f453176308bfb' + labels_md5 = '95fc59f1d94a85ec00931d4d1280bec9' + monkeypatch.setitem(CV4AKenyaCropType.image_meta, 'md5', source_md5) + monkeypatch.setitem(CV4AKenyaCropType.target_meta, 'md5', labels_md5) monkeypatch.setattr( - CV4AKenyaCropType, "tile_names", ["ref_african_crops_kenya_02_tile_00"] + CV4AKenyaCropType, 'tile_names', ['ref_african_crops_kenya_02_tile_00'] ) - monkeypatch.setattr(CV4AKenyaCropType, "dates", ["20190606"]) + monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606']) root = str(tmp_path) transforms = nn.Identity() return CV4AKenyaCropType( root, transforms=transforms, download=True, - api_key="", + api_key='', checksum=True, verbose=True, ) @@ -60,10 +60,10 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType def test_getitem(self, dataset: CV4AKenyaCropType) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert isinstance(x["x"], torch.Tensor) - assert isinstance(x["y"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert isinstance(x['x'], torch.Tensor) + assert isinstance(x['y'], torch.Tensor) def test_len(self, dataset: CV4AKenyaCropType) -> None: assert len(dataset) == 345 @@ -85,41 +85,41 @@ def test_get_splits(self, dataset: CV4AKenyaCropType) -> None: assert 4793 not in train_field_ids def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None: - CV4AKenyaCropType(root=dataset.root, download=True, api_key="") + CV4AKenyaCropType(root=dataset.root, download=True, api_key='') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CV4AKenyaCropType(str(tmp_path)) def test_invalid_tile(self, dataset: 
CV4AKenyaCropType) -> None: with pytest.raises(AssertionError): - dataset._load_label_tile("foo") + dataset._load_label_tile('foo') with pytest.raises(AssertionError): - dataset._load_all_image_tiles("foo", ("B01", "B02")) + dataset._load_all_image_tiles('foo', ('B01', 'B02')) with pytest.raises(AssertionError): - dataset._load_single_image_tile("foo", "20190606", ("B01", "B02")) + dataset._load_single_image_tile('foo', '20190606', ('B01', 'B02')) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - CV4AKenyaCropType(bands=["B01", "B02"]) # type: ignore[arg-type] + CV4AKenyaCropType(bands=['B01', 'B02']) # type: ignore[arg-type] - with pytest.raises(ValueError, match="is an invalid band name."): - CV4AKenyaCropType(bands=("foo", "bar")) + with pytest.raises(ValueError, match='is an invalid band name.'): + CV4AKenyaCropType(bands=('foo', 'bar')) def test_plot(self, dataset: CV4AKenyaCropType) -> None: - dataset.plot(dataset[0], time_step=0, suptitle="Test") + dataset.plot(dataset[0], time_step=0, suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["mask"].clone() - dataset.plot(sample, time_step=0, suptitle="Pred") + sample['prediction'] = sample['mask'].clone() + dataset.plot(sample, time_step=0, suptitle='Pred') plt.close() def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None: - dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(["B01"])) + dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(['B01'])) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - dataset.plot(dataset[0], time_step=0, suptitle="Single Band") + dataset.plot(dataset[0], time_step=0, suptitle='Single Band') diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index 6ab894c1fb7..d165b064a90 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -19,7 +19,7 @@ class Collection: def download(self, output_dir: str, **kwargs: str) -> None: - for tarball in glob.iglob(os.path.join("tests", "data", "cyclone", "*.tar.gz")): + for tarball in glob.iglob(os.path.join('tests', 'data', 'cyclone', '*.tar.gz')): shutil.copy(tarball, output_dir) @@ -28,41 +28,41 @@ def fetch(collection_id: str, **kwargs: str) -> Collection: class TestTropicalCyclone: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> TropicalCyclone: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) md5s = { - "train": { - "source": "2b818e0a0873728dabf52c7054a0ce4c", - "labels": "c3c2b6d02c469c5519f4add4f9132712", + 'train': { + 'source': '2b818e0a0873728dabf52c7054a0ce4c', + 'labels': 'c3c2b6d02c469c5519f4add4f9132712', }, - "test": { - "source": "bc07c519ddf3ce88857435ddddf98a16", - "labels": "3ca4243eff39b87c73e05ec8db1824bf", + 'test': { + 'source': 'bc07c519ddf3ce88857435ddddf98a16', + 'labels': '3ca4243eff39b87c73e05ec8db1824bf', }, } - monkeypatch.setattr(TropicalCyclone, "md5s", md5s) - monkeypatch.setattr(TropicalCyclone, "size", 1) + monkeypatch.setattr(TropicalCyclone, 'md5s', md5s) + monkeypatch.setattr(TropicalCyclone, 'size', 1) root = str(tmp_path) split = 
request.param transforms = nn.Identity() return TropicalCyclone( - root, split, transforms, download=True, api_key="", checksum=True + root, split, transforms, download=True, api_key='', checksum=True ) - @pytest.mark.parametrize("index", [0, 1]) + @pytest.mark.parametrize('index', [0, 1]) def test_getitem(self, dataset: TropicalCyclone, index: int) -> None: x = dataset[index] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["storm_id"], str) - assert isinstance(x["relative_time"], int) - assert isinstance(x["ocean"], int) - assert isinstance(x["label"], torch.Tensor) - assert x["image"].shape == (3, dataset.size, dataset.size) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['storm_id'], str) + assert isinstance(x['relative_time'], int) + assert isinstance(x['ocean'], int) + assert isinstance(x['label'], torch.Tensor) + assert x['image'].shape == (3, dataset.size, dataset.size) def test_len(self, dataset: TropicalCyclone) -> None: assert len(dataset) == 5 @@ -73,21 +73,21 @@ def test_add(self, dataset: TropicalCyclone) -> None: assert len(ds) == 10 def test_already_downloaded(self, dataset: TropicalCyclone) -> None: - TropicalCyclone(root=dataset.root, download=True, api_key="") + TropicalCyclone(root=dataset.root, download=True, api_key='') def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - TropicalCyclone(split="foo") + TropicalCyclone(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): TropicalCyclone(str(tmp_path)) def test_plot(self, dataset: TropicalCyclone) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["label"] + sample['prediction'] = sample['label'] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_deepglobelandcover.py b/tests/datasets/test_deepglobelandcover.py index 1ab9b70b2d1..5e845958668 100644 --- a/tests/datasets/test_deepglobelandcover.py +++ b/tests/datasets/test_deepglobelandcover.py @@ -16,13 +16,13 @@ class TestDeepGlobeLandCover: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset( self, monkeypatch: MonkeyPatch, request: SubRequest ) -> DeepGlobeLandCover: - md5 = "2cbd68d36b1485f09f32d874dde7c5c5" - monkeypatch.setattr(DeepGlobeLandCover, "md5", md5) - root = os.path.join("tests", "data", "deepglobelandcover") + md5 = '2cbd68d36b1485f09f32d874dde7c5c5' + monkeypatch.setattr(DeepGlobeLandCover, 'md5', md5) + root = os.path.join('tests', 'data', 'deepglobelandcover') split = request.param transforms = nn.Identity() return DeepGlobeLandCover(root, split, transforms, checksum=True) @@ -30,40 +30,40 @@ def dataset( def test_getitem(self, dataset: DeepGlobeLandCover) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_len(self, dataset: DeepGlobeLandCover) -> None: assert len(dataset) == 3 def test_extract(self, tmp_path: Path) -> None: - root = os.path.join("tests", "data", "deepglobelandcover") - filename = "data.zip" + root = os.path.join('tests', 'data', 'deepglobelandcover') + filename = 'data.zip' shutil.copyfile( os.path.join(root, filename), 
os.path.join(str(tmp_path), filename) ) DeepGlobeLandCover(root=str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "data.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'data.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): DeepGlobeLandCover(root=str(tmp_path), checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - DeepGlobeLandCover(split="foo") + DeepGlobeLandCover(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): DeepGlobeLandCover(str(tmp_path)) def test_plot(self, dataset: DeepGlobeLandCover) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"].clone() + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_dfc2022.py b/tests/datasets/test_dfc2022.py index 22caebcfb7b..d353da5e274 100644 --- a/tests/datasets/test_dfc2022.py +++ b/tests/datasets/test_dfc2022.py @@ -16,20 +16,20 @@ class TestDFC2022: - @pytest.fixture(params=["train", "train-unlabeled", "val"]) + @pytest.fixture(params=['train', 'train-unlabeled', 'val']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> DFC2022: monkeypatch.setitem( - DFC2022.metadata["train"], "md5", "6e380c4fa659d05ca93be71b50cacd90" + DFC2022.metadata['train'], 'md5', '6e380c4fa659d05ca93be71b50cacd90' ) monkeypatch.setitem( - DFC2022.metadata["train-unlabeled"], - "md5", - "b2bf3839323d4eae636f198921442945", + DFC2022.metadata['train-unlabeled'], + 'md5', + 'b2bf3839323d4eae636f198921442945', ) monkeypatch.setitem( - DFC2022.metadata["val"], "md5", "e018dc6865bd3086738038fff27b818a" + DFC2022.metadata['val'], 'md5', 'e018dc6865bd3086738038fff27b818a' ) - root = os.path.join("tests", "data", "dfc2022") + root = os.path.join('tests', 'data', 'dfc2022') split = request.param transforms = nn.Identity() return DFC2022(root, split, transforms, checksum=True) @@ -37,57 +37,57 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> DFC2022: def test_getitem(self, dataset: DFC2022) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].ndim == 3 - assert x["image"].shape[0] == 4 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].ndim == 3 + assert x['image'].shape[0] == 4 - if dataset.split == "train": - assert isinstance(x["mask"], torch.Tensor) - assert x["mask"].ndim == 2 + if dataset.split == 'train': + assert isinstance(x['mask'], torch.Tensor) + assert x['mask'].ndim == 2 def test_len(self, dataset: DFC2022) -> None: assert len(dataset) == 2 def test_extract(self, tmp_path: Path) -> None: shutil.copyfile( - os.path.join("tests", "data", "dfc2022", "labeled_train.zip"), - os.path.join(tmp_path, "labeled_train.zip"), + os.path.join('tests', 'data', 'dfc2022', 'labeled_train.zip'), + os.path.join(tmp_path, 'labeled_train.zip'), ) shutil.copyfile( - os.path.join("tests", "data", "dfc2022", "unlabeled_train.zip"), - os.path.join(tmp_path, "unlabeled_train.zip"), + os.path.join('tests', 'data', 'dfc2022', 'unlabeled_train.zip'), + os.path.join(tmp_path, 'unlabeled_train.zip'), 
) shutil.copyfile( - os.path.join("tests", "data", "dfc2022", "val.zip"), - os.path.join(tmp_path, "val.zip"), + os.path.join('tests', 'data', 'dfc2022', 'val.zip'), + os.path.join(tmp_path, 'val.zip'), ) DFC2022(root=str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "labeled_train.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'labeled_train.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): DFC2022(root=str(tmp_path), checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - DFC2022(split="foo") + DFC2022(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): DFC2022(str(tmp_path)) def test_plot(self, dataset: DFC2022) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - if dataset.split == "train": - x["prediction"] = x["mask"].clone() + if dataset.split == 'train': + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() - del x["mask"] + del x['mask'] dataset.plot(x) plt.close() diff --git a/tests/datasets/test_eddmaps.py b/tests/datasets/test_eddmaps.py index a15adbeecaf..364e988aba3 100644 --- a/tests/datasets/test_eddmaps.py +++ b/tests/datasets/test_eddmaps.py @@ -16,9 +16,9 @@ class TestEDDMapS: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> EDDMapS: - root = os.path.join("tests", "data", "eddmaps") + root = os.path.join('tests', 'data', 'eddmaps') return EDDMapS(root) def test_getitem(self, dataset: EDDMapS) -> None: @@ -37,12 +37,12 @@ def test_or(self, dataset: EDDMapS) -> None: assert isinstance(ds, UnionDataset) def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): EDDMapS(str(tmp_path)) def test_invalid_query(self, dataset: EDDMapS) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_enviroatlas.py b/tests/datasets/test_enviroatlas.py index da7641b47a8..b6e249f702b 100644 --- a/tests/datasets/test_enviroatlas.py +++ b/tests/datasets/test_enviroatlas.py @@ -31,25 +31,25 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestEnviroAtlas: @pytest.fixture( params=[ - (("naip", "prior", "lc"), False), - (("naip", "prior", "buildings", "lc"), True), - (("naip", "prior"), False), + (('naip', 'prior', 'lc'), False), + (('naip', 'prior', 'buildings', 'lc'), True), + (('naip', 'prior'), False), ] ) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> EnviroAtlas: - monkeypatch.setattr(torchgeo.datasets.enviroatlas, "download_url", download_url) - monkeypatch.setattr(EnviroAtlas, "md5", "071ec65c611e1d4915a5247bffb5ad87") + monkeypatch.setattr(torchgeo.datasets.enviroatlas, 'download_url', download_url) + monkeypatch.setattr(EnviroAtlas, 'md5', '071ec65c611e1d4915a5247bffb5ad87') monkeypatch.setattr( EnviroAtlas, - "url", - os.path.join("tests", "data", 
"enviroatlas", "enviroatlas_lotp.zip"), + 'url', + os.path.join('tests', 'data', 'enviroatlas', 'enviroatlas_lotp.zip'), ) monkeypatch.setattr( EnviroAtlas, - "_files", - ["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"], + '_files', + ['pittsburgh_pa-2010_1m-train_tiles-debuffered', 'spatial_index.geojson'], ) root = str(tmp_path) transforms = nn.Identity() @@ -67,8 +67,8 @@ def test_getitem(self, dataset: EnviroAtlas) -> None: bb = next(iter(sampler)) x = dataset[bb] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: EnviroAtlas) -> None: ds = dataset & dataset @@ -84,29 +84,29 @@ def test_already_extracted(self, dataset: EnviroAtlas) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: root = str(tmp_path) shutil.copy( - os.path.join("tests", "data", "enviroatlas", "enviroatlas_lotp.zip"), root + os.path.join('tests', 'data', 'enviroatlas', 'enviroatlas_lotp.zip'), root ) EnviroAtlas(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): EnviroAtlas(str(tmp_path), checksum=True) def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_multiple_hits_query(self, dataset: EnviroAtlas) -> None: ds = EnviroAtlas( root=dataset.root, - splits=["pittsburgh_pa-2010_1m-train", "austin_tx-2012_1m-test"], + splits=['pittsburgh_pa-2010_1m-train', 'austin_tx-2012_1m-test'], layers=dataset.layers, ) with pytest.raises( - IndexError, match="query: .* spans multiple tiles which is not valid" + IndexError, match='query: .* spans multiple tiles which is not valid' ): ds[dataset.bounds] @@ -114,14 +114,14 @@ def test_plot(self, dataset: EnviroAtlas) -> None: sampler = RandomGeoSampler(dataset, size=16, length=1) bb = next(iter(sampler)) x = dataset[bb] - if "naip" not in dataset.layers or "lc" not in dataset.layers: + if 'naip' not in dataset.layers or 'lc' not in dataset.layers: with pytest.raises(ValueError, match="The 'naip' and"): dataset.plot(x) else: - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"][0].clone() + x['prediction'] = x['mask'][0].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py index 1e01e0ac11d..f0d28614247 100644 --- a/tests/datasets/test_esri2020.py +++ b/tests/datasets/test_esri2020.py @@ -29,19 +29,19 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestEsri2020: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Esri2020: - monkeypatch.setattr(torchgeo.datasets.esri2020, "download_url", download_url) - zipfile = "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip" - monkeypatch.setattr(Esri2020, "zipfile", zipfile) + monkeypatch.setattr(torchgeo.datasets.esri2020, 'download_url', download_url) + zipfile = 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip' + monkeypatch.setattr(Esri2020, 'zipfile', zipfile) - md5 = "34aec55538694171c7b605b0cc0d0138" - 
monkeypatch.setattr(Esri2020, "md5", md5) + md5 = '34aec55538694171c7b605b0cc0d0138' + monkeypatch.setattr(Esri2020, 'md5', md5) url = os.path.join( - "tests", - "data", - "esri2020", - "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip", + 'tests', + 'data', + 'esri2020', + 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip', ) - monkeypatch.setattr(Esri2020, "url", url) + monkeypatch.setattr(Esri2020, 'url', url) root = str(tmp_path) transforms = nn.Identity() return Esri2020(root, transforms=transforms, download=True, checksum=True) @@ -49,24 +49,24 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Esri2020: def test_getitem(self, dataset: Esri2020) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_already_extracted(self, dataset: Esri2020) -> None: Esri2020(dataset.paths, download=True) def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join( - "tests", - "data", - "esri2020", - "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip", + 'tests', + 'data', + 'esri2020', + 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip', ) shutil.copy(url, tmp_path) Esri2020(str(tmp_path)) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Esri2020(str(tmp_path), checksum=True) def test_and(self, dataset: Esri2020) -> None: @@ -80,23 +80,23 @@ def test_or(self, dataset: Esri2020) -> None: def test_plot(self, dataset: Esri2020) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: Esri2020) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_url(self) -> None: - ds = Esri2020(os.path.join("tests", "data", "esri2020")) - assert "ai4edataeuwest.blob.core.windows.net" in ds.url + ds = Esri2020(os.path.join('tests', 'data', 'esri2020')) + assert 'ai4edataeuwest.blob.core.windows.net' in ds.url def test_invalid_query(self, dataset: Esri2020) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index 8ee695bbcab..0cf4029921d 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -21,33 +21,33 @@ def download_url(url: str, root: str, *args: str) -> None: class TestETCI2021: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> ETCI2021: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - data_dir = os.path.join("tests", "data", "etci2021") + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'etci2021') metadata = { - "train": { - "filename": "train.zip", - "md5": "bd55f2116e43a35d5b94a765938be2aa", - "directory": "train", - 
"url": os.path.join(data_dir, "train.zip"), + 'train': { + 'filename': 'train.zip', + 'md5': 'bd55f2116e43a35d5b94a765938be2aa', + 'directory': 'train', + 'url': os.path.join(data_dir, 'train.zip'), }, - "val": { - "filename": "val_with_ref_labels.zip", - "md5": "96ed69904043e514c13c14ffd3ec45cd", - "directory": "test", - "url": os.path.join(data_dir, "val_with_ref_labels.zip"), + 'val': { + 'filename': 'val_with_ref_labels.zip', + 'md5': '96ed69904043e514c13c14ffd3ec45cd', + 'directory': 'test', + 'url': os.path.join(data_dir, 'val_with_ref_labels.zip'), }, - "test": { - "filename": "test_without_ref_labels.zip", - "md5": "1b66d85e22c8f5b0794b3542c5ea09ef", - "directory": "test_internal", - "url": os.path.join(data_dir, "test_without_ref_labels.zip"), + 'test': { + 'filename': 'test_without_ref_labels.zip', + 'md5': '1b66d85e22c8f5b0794b3542c5ea09ef', + 'directory': 'test_internal', + 'url': os.path.join(data_dir, 'test_without_ref_labels.zip'), }, } - monkeypatch.setattr(ETCI2021, "metadata", metadata) + monkeypatch.setattr(ETCI2021, 'metadata', metadata) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -56,15 +56,15 @@ def dataset( def test_getitem(self, dataset: ETCI2021) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert x["image"].shape[0] == 6 - assert x["image"].shape[-2:] == x["mask"].shape[-2:] + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].shape[0] == 6 + assert x['image'].shape[-2:] == x['mask'].shape[-2:] - if dataset.split != "test": - assert x["mask"].shape[0] == 2 + if dataset.split != 'test': + assert x['mask'].shape[0] == 2 else: - assert x["mask"].shape[0] == 1 + assert x['mask'].shape[0] == 1 def test_len(self, dataset: ETCI2021) -> None: assert len(dataset) == 3 @@ -74,18 +74,18 @@ def test_already_downloaded(self, dataset: ETCI2021) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - ETCI2021(split="foo") + ETCI2021(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ETCI2021(str(tmp_path)) def test_plot(self, dataset: ETCI2021) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"][0].clone() + x['prediction'] = x['mask'][0].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_eudem.py b/tests/datasets/test_eudem.py index 9816ea3c84d..5c5f7eb9119 100644 --- a/tests/datasets/test_eudem.py +++ b/tests/datasets/test_eudem.py @@ -24,9 +24,9 @@ class TestEUDEM: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> EUDEM: - md5s = {"eu_dem_v11_E30N10.zip": "ef148466c02197a08be169eaad186591"} - monkeypatch.setattr(EUDEM, "md5s", md5s) - zipfile = os.path.join("tests", "data", "eudem", "eu_dem_v11_E30N10.zip") + md5s = {'eu_dem_v11_E30N10.zip': 'ef148466c02197a08be169eaad186591'} + monkeypatch.setattr(EUDEM, 'md5s', md5s) + zipfile = os.path.join('tests', 'data', 'eudem', 'eu_dem_v11_E30N10.zip') shutil.copy(zipfile, tmp_path) root = str(tmp_path) transforms = nn.Identity() @@ -35,25 +35,25 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> EUDEM: def test_getitem(self, dataset: EUDEM) -> None: x = 
dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_extracted_already(self, dataset: EUDEM) -> None: assert isinstance(dataset.paths, str) - zipfile = os.path.join(dataset.paths, "eu_dem_v11_E30N10.zip") - shutil.unpack_archive(zipfile, dataset.paths, "zip") + zipfile = os.path.join(dataset.paths, 'eu_dem_v11_E30N10.zip') + shutil.unpack_archive(zipfile, dataset.paths, 'zip') EUDEM(dataset.paths) def test_no_dataset(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): EUDEM(str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "eu_dem_v11_E30N10.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'eu_dem_v11_E30N10.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): EUDEM(str(tmp_path), checksum=True) def test_and(self, dataset: EUDEM) -> None: @@ -67,19 +67,19 @@ def test_or(self, dataset: EUDEM) -> None: def test_plot(self, dataset: EUDEM) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: EUDEM) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_invalid_query(self, dataset: EUDEM) -> None: query = BoundingBox(100, 100, 100, 100, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_eurocrops.py b/tests/datasets/test_eurocrops.py index 0477b8617ac..5846c8df52f 100644 --- a/tests/datasets/test_eurocrops.py +++ b/tests/datasets/test_eurocrops.py @@ -28,20 +28,20 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestEuroCrops: - @pytest.fixture(params=[None, ["1000000010"], ["1000000000"], ["2000000000"]]) + @pytest.fixture(params=[None, ['1000000010'], ['1000000000'], ['2000000000']]) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> EuroCrops: classes = request.param - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - monkeypatch.setattr(torchgeo.datasets.eurocrops, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + monkeypatch.setattr(torchgeo.datasets.eurocrops, 'download_url', download_url) monkeypatch.setattr( - EuroCrops, "zenodo_files", [("AA.zip", "b2ef5cac231294731c1dfea47cba544d")] + EuroCrops, 'zenodo_files', [('AA.zip', 'b2ef5cac231294731c1dfea47cba544d')] ) - monkeypatch.setattr(EuroCrops, "hcat_md5", "22d61cf3b316c8babfd209ae81419d8f") - base_url = os.path.join("tests", "data", "eurocrops") + os.sep - monkeypatch.setattr(EuroCrops, "base_url", base_url) - monkeypatch.setattr(plt, "show", lambda *args: None) + monkeypatch.setattr(EuroCrops, 'hcat_md5', '22d61cf3b316c8babfd209ae81419d8f') + base_url = os.path.join('tests', 'data', 'eurocrops') + os.sep + 
monkeypatch.setattr(EuroCrops, 'base_url', base_url) + monkeypatch.setattr(plt, 'show', lambda *args: None) root = str(tmp_path) transforms = nn.Identity() return EuroCrops( @@ -51,8 +51,8 @@ def dataset( def test_getitem(self, dataset: EuroCrops) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: EuroCrops) -> None: ds = dataset & dataset @@ -68,25 +68,25 @@ def test_already_downloaded(self, dataset: EuroCrops) -> None: def test_plot(self, dataset: EuroCrops) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') def test_plot_prediction(self, dataset: EuroCrops) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): EuroCrops(str(tmp_path)) def test_invalid_query(self, dataset: EuroCrops) -> None: query = BoundingBox(200, 200, 200, 200, 2, 2) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_integrity_error(self, dataset: EuroCrops) -> None: - dataset.zenodo_files = [("AA.zip", "invalid")] + dataset.zenodo_files = [('AA.zip', 'invalid')] assert not dataset._check_integrity() diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index 3f8173b6656..b6ec2a283d8 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -28,34 +28,34 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestEuroSAT: - @pytest.fixture(params=product([EuroSAT, EuroSAT100], ["train", "val", "test"])) + @pytest.fixture(params=product([EuroSAT, EuroSAT100], ['train', 'val', 'test'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> EuroSAT: base_class: type[EuroSAT] = request.param[0] split: str = request.param[1] - monkeypatch.setattr(torchgeo.datasets.eurosat, "download_url", download_url) - md5 = "aa051207b0547daba0ac6af57808d68e" - monkeypatch.setattr(base_class, "md5", md5) - url = os.path.join("tests", "data", "eurosat", "EuroSATallBands.zip") - monkeypatch.setattr(base_class, "url", url) - monkeypatch.setattr(base_class, "filename", "EuroSATallBands.zip") + monkeypatch.setattr(torchgeo.datasets.eurosat, 'download_url', download_url) + md5 = 'aa051207b0547daba0ac6af57808d68e' + monkeypatch.setattr(base_class, 'md5', md5) + url = os.path.join('tests', 'data', 'eurosat', 'EuroSATallBands.zip') + monkeypatch.setattr(base_class, 'url', url) + monkeypatch.setattr(base_class, 'filename', 'EuroSATallBands.zip') monkeypatch.setattr( base_class, - "split_urls", + 'split_urls', { - "train": os.path.join("tests", "data", "eurosat", "eurosat-train.txt"), - "val": os.path.join("tests", "data", "eurosat", "eurosat-val.txt"), - "test": os.path.join("tests", "data", "eurosat", "eurosat-test.txt"), + 'train': os.path.join('tests', 'data', 'eurosat', 'eurosat-train.txt'), + 'val': os.path.join('tests', 'data', 'eurosat', 'eurosat-val.txt'), + 'test': os.path.join('tests', 'data', 
'eurosat', 'eurosat-test.txt'), }, ) monkeypatch.setattr( base_class, - "split_md5s", + 'split_md5s', { - "train": "4af60a00fdfdf8500572ae5360694b71", - "val": "4af60a00fdfdf8500572ae5360694b71", - "test": "4af60a00fdfdf8500572ae5360694b71", + 'train': '4af60a00fdfdf8500572ae5360694b71', + 'val': '4af60a00fdfdf8500572ae5360694b71', + 'test': '4af60a00fdfdf8500572ae5360694b71', }, ) root = str(tmp_path) @@ -67,16 +67,16 @@ def dataset( def test_getitem(self, dataset: EuroSAT) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - EuroSAT(split="foo") + EuroSAT(split='foo') def test_invalid_bands(self) -> None: with pytest.raises(ValueError): - EuroSAT(bands=("OK", "BK")) + EuroSAT(bands=('OK', 'BK')) def test_len(self, dataset: EuroSAT) -> None: assert len(dataset) == 2 @@ -97,22 +97,22 @@ def test_already_downloaded_not_extracted( EuroSAT(root=str(tmp_path), download=False) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): EuroSAT(str(tmp_path)) def test_plot(self, dataset: EuroSAT) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None: - dataset = EuroSAT(root=str(tmp_path), bands=("B03",)) + dataset = EuroSAT(root=str(tmp_path), bands=('B03',)) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - dataset.plot(dataset[0], suptitle="Single Band") + dataset.plot(dataset[0], suptitle='Single Band') diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index fcb7d4f7711..38db23974d3 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -22,49 +22,49 @@ def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) class TestFAIR1M: - test_root = os.path.join("tests", "data", "fair1m") + test_root = os.path.join('tests', 'data', 'fair1m') - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> FAIR1M: - monkeypatch.setattr(torchgeo.datasets.fair1m, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.fair1m, 'download_url', download_url) urls = { - "train": ( - os.path.join(self.test_root, "train", "part1", "images.zip"), - os.path.join(self.test_root, "train", "part1", "labelXml.zip"), - os.path.join(self.test_root, "train", "part2", "images.zip"), - os.path.join(self.test_root, "train", "part2", "labelXmls.zip"), + 'train': ( + os.path.join(self.test_root, 'train', 'part1', 'images.zip'), + os.path.join(self.test_root, 'train', 'part1', 'labelXml.zip'), + os.path.join(self.test_root, 'train', 'part2', 'images.zip'), + os.path.join(self.test_root, 'train', 'part2', 'labelXmls.zip'), ), - "val": ( - os.path.join(self.test_root, "validation", "images.zip"), - os.path.join(self.test_root, 
"validation", "labelXmls.zip"), + 'val': ( + os.path.join(self.test_root, 'validation', 'images.zip'), + os.path.join(self.test_root, 'validation', 'labelXmls.zip'), ), - "test": ( - os.path.join(self.test_root, "test", "images0.zip"), - os.path.join(self.test_root, "test", "images1.zip"), - os.path.join(self.test_root, "test", "images2.zip"), + 'test': ( + os.path.join(self.test_root, 'test', 'images0.zip'), + os.path.join(self.test_root, 'test', 'images1.zip'), + os.path.join(self.test_root, 'test', 'images2.zip'), ), } md5s = { - "train": ( - "ffbe9329e51ae83161ce24b5b46dc934", - "2db6fbe64be6ebb0a03656da6c6effe7", - "401b0f1d75d9d23f2e088bfeaf274cfa", - "d62b18eae8c3201f6112c2e9db84d605", + 'train': ( + 'ffbe9329e51ae83161ce24b5b46dc934', + '2db6fbe64be6ebb0a03656da6c6effe7', + '401b0f1d75d9d23f2e088bfeaf274cfa', + 'd62b18eae8c3201f6112c2e9db84d605', ), - "val": ( - "83d2f06574fc7158ded0eb1fb256c8fe", - "316490b200503c54cf43835a341b6dbe", + 'val': ( + '83d2f06574fc7158ded0eb1fb256c8fe', + '316490b200503c54cf43835a341b6dbe', ), - "test": ( - "3c02845752667b96a5749c90c7fdc994", - "9359107f1d0abac6a5b98725f4064bc0", - "d7bc2985c625ffd47d86cdabb2a9d2bc", + 'test': ( + '3c02845752667b96a5749c90c7fdc994', + '9359107f1d0abac6a5b98725f4064bc0', + 'd7bc2985c625ffd47d86cdabb2a9d2bc', ), } - monkeypatch.setattr(FAIR1M, "urls", urls) - monkeypatch.setattr(FAIR1M, "md5s", md5s) + monkeypatch.setattr(FAIR1M, 'urls', urls) + monkeypatch.setattr(FAIR1M, 'md5s', md5s) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -73,17 +73,17 @@ def dataset( def test_getitem(self, dataset: FAIR1M) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].shape[0] == 3 - if dataset.split != "test": - assert isinstance(x["boxes"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["boxes"].shape[-2:] == (5, 2) - assert x["label"].ndim == 1 + if dataset.split != 'test': + assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['boxes'].shape[-2:] == (5, 2) + assert x['label'].ndim == 1 def test_len(self, dataset: FAIR1M) -> None: - if dataset.split == "train": + if dataset.split == 'train': assert len(dataset) == 8 else: assert len(dataset) == 4 @@ -105,7 +105,7 @@ def test_already_downloaded_not_extracted( FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True) def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None: - md5s = tuple(["randomhash"] * len(FAIR1M.md5s[dataset.split])) + md5s = tuple(['randomhash'] * len(FAIR1M.md5s[dataset.split])) FAIR1M.md5s[dataset.split] = md5s shutil.rmtree(dataset.root) for filepath, url in zip( @@ -115,22 +115,22 @@ def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None: os.makedirs(os.path.dirname(output), exist_ok=True) download_url(url, root=os.path.dirname(output), filename=output) - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True) def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None: shutil.rmtree(str(tmp_path)) - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): FAIR1M(root=str(tmp_path), split=dataset.split) def test_plot(self, dataset: FAIR1M) -> None: x = 
dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - if dataset.split != "test": - x["prediction_boxes"] = x["boxes"].clone() + if dataset.split != 'test': + x['prediction_boxes'] = x['boxes'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_fire_risk.py b/tests/datasets/test_fire_risk.py index 76689bf9e82..e3f235c464d 100644 --- a/tests/datasets/test_fire_risk.py +++ b/tests/datasets/test_fire_risk.py @@ -21,15 +21,15 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestFireRisk: - @pytest.fixture(params=["train", "val"]) + @pytest.fixture(params=['train', 'val']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> FireRisk: - monkeypatch.setattr(torchgeo.datasets.fire_risk, "download_url", download_url) - url = os.path.join("tests", "data", "fire_risk", "FireRisk.zip") - md5 = "db22106d61b10d855234b4a74db921ac" - monkeypatch.setattr(FireRisk, "md5", md5) - monkeypatch.setattr(FireRisk, "url", url) + monkeypatch.setattr(torchgeo.datasets.fire_risk, 'download_url', download_url) + url = os.path.join('tests', 'data', 'fire_risk', 'FireRisk.zip') + md5 = 'db22106d61b10d855234b4a74db921ac' + monkeypatch.setattr(FireRisk, 'md5', md5) + monkeypatch.setattr(FireRisk, 'url', url) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -38,9 +38,9 @@ def dataset( def test_getitem(self, dataset: FireRisk) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['image'].shape[0] == 3 def test_len(self, dataset: FireRisk) -> None: assert len(dataset) == 5 @@ -56,15 +56,15 @@ def test_already_downloaded_not_extracted( FireRisk(root=str(tmp_path), download=False) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): FireRisk(str(tmp_path)) def test_plot(self, dataset: FireRisk) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index 47caaebe5e3..39aae73026a 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -22,15 +22,15 @@ def download_url(url: str, root: str, *args: str) -> None: class TestForestDamage: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ForestDamage: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - data_dir = os.path.join("tests", "data", "forestdamage") + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'forestdamage') - url = os.path.join(data_dir, "Data_Set_Larch_Casebearer.zip") + url = os.path.join(data_dir, 'Data_Set_Larch_Casebearer.zip') - md5 = "52d82ac38899e6e6bb40aacda643ee15" + md5 = '52d82ac38899e6e6bb40aacda643ee15' - monkeypatch.setattr(ForestDamage, "url", url) - monkeypatch.setattr(ForestDamage, "md5", md5) + monkeypatch.setattr(ForestDamage, 
'url', url) + monkeypatch.setattr(ForestDamage, 'md5', md5) root = str(tmp_path) transforms = nn.Identity() return ForestDamage( @@ -43,39 +43,39 @@ def test_already_downloaded(self, dataset: ForestDamage) -> None: def test_getitem(self, dataset: ForestDamage) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert isinstance(x["boxes"], torch.Tensor) - assert x["image"].shape[0] == 3 - assert x["image"].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert isinstance(x['boxes'], torch.Tensor) + assert x['image'].shape[0] == 3 + assert x['image'].ndim == 3 def test_len(self, dataset: ForestDamage) -> None: assert len(dataset) == 2 def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join( - "tests", "data", "forestdamage", "Data_Set_Larch_Casebearer.zip" + 'tests', 'data', 'forestdamage', 'Data_Set_Larch_Casebearer.zip' ) shutil.copy(url, tmp_path) ForestDamage(root=str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "Data_Set_Larch_Casebearer.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'Data_Set_Larch_Casebearer.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): ForestDamage(root=str(tmp_path), checksum=True) def test_not_found(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ForestDamage(str(tmp_path)) def test_plot(self, dataset: ForestDamage) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: ForestDamage) -> None: x = dataset[0].copy() - x["prediction_boxes"] = x["boxes"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction_boxes'] = x['boxes'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_gbif.py b/tests/datasets/test_gbif.py index bf6923a6bc2..35426d18b03 100644 --- a/tests/datasets/test_gbif.py +++ b/tests/datasets/test_gbif.py @@ -16,9 +16,9 @@ class TestGBIF: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> GBIF: - root = os.path.join("tests", "data", "gbif") + root = os.path.join('tests', 'data', 'gbif') return GBIF(root) def test_getitem(self, dataset: GBIF) -> None: @@ -37,12 +37,12 @@ def test_or(self, dataset: GBIF) -> None: assert isinstance(ds, UnionDataset) def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): GBIF(str(tmp_path)) def test_invalid_query(self, dataset: GBIF) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index c98f010c599..3424d9fda3b 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -46,12 +46,12 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) bounds = BoundingBox(*hit.bounds) 
- return {"index": bounds} + return {'index': bounds} class CustomVectorDataset(VectorDataset): - filename_glob = "*.geojson" - date_format = "%Y" + filename_glob = '*.geojson' + date_format = '%Y' filename_regex = r""" ^vector_(?P\d{4})\.geojson """ @@ -64,25 +64,25 @@ class CustomSentinelDataset(Sentinel2): class CustomNonGeoDataset(NonGeoDataset): def __getitem__(self, index: int) -> dict[str, int]: - return {"index": index} + return {'index': index} def __len__(self) -> int: return 2 class TestGeoDataset: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> GeoDataset: return CustomGeoDataset() def test_getitem(self, dataset: GeoDataset) -> None: query = BoundingBox(0, 1, 2, 3, 4, 5) - assert dataset[query] == {"index": query} + assert dataset[query] == {'index': query} def test_len(self, dataset: GeoDataset) -> None: assert len(dataset) == 1 - @pytest.mark.parametrize("crs", [CRS.from_epsg(4087), CRS.from_epsg(32631)]) + @pytest.mark.parametrize('crs', [CRS.from_epsg(4087), CRS.from_epsg(32631)]) def test_crs(self, dataset: GeoDataset, crs: CRS) -> None: dataset.crs = crs @@ -136,9 +136,9 @@ def test_or_four(self) -> None: def test_str(self, dataset: GeoDataset) -> None: out = str(dataset) - assert "type: GeoDataset" in out - assert "bbox: BoundingBox" in out - assert "size: 1" in out + assert 'type: GeoDataset' in out + assert 'bbox: BoundingBox' in out + assert 'size: 1' in out def test_picklable(self, dataset: GeoDataset) -> None: x = pickle.dumps(dataset) @@ -155,47 +155,47 @@ def test_abstract(self) -> None: def test_and_nongeo(self, dataset: GeoDataset) -> None: ds2 = CustomNonGeoDataset() with pytest.raises( - ValueError, match="IntersectionDataset only supports GeoDatasets" + ValueError, match='IntersectionDataset only supports GeoDatasets' ): dataset & ds2 # type: ignore[operator] def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None: - paths = [str(tmp_path), str(tmp_path / "non_existing_file.tif")] - with pytest.warns(UserWarning, match="Path was ignored."): + paths = [str(tmp_path), str(tmp_path / 'non_existing_file.tif')] + with pytest.warns(UserWarning, match='Path was ignored.'): assert len(CustomGeoDataset(paths=paths).files) == 0 def test_files_property_for_virtual_files(self) -> None: # Tests only a subset of schemes and combinations. paths = [ - "file://directory/file.tif", - "zip://archive.zip!folder/file.tif", - "az://azure_bucket/prefix/file.tif", - "/vsiaz/azure_bucket/prefix/file.tif", - "zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif", - "/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif", + 'file://directory/file.tif', + 'zip://archive.zip!folder/file.tif', + 'az://azure_bucket/prefix/file.tif', + '/vsiaz/azure_bucket/prefix/file.tif', + 'zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif', + '/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif', ] assert len(CustomGeoDataset(paths=paths).files) == len(paths) def test_files_property_ordered(self) -> None: """Ensure that the list of files is ordered.""" - paths = ["file://file3.tif", "file://file1.tif", "file://file2.tif"] + paths = ['file://file3.tif', 'file://file1.tif', 'file://file2.tif'] assert CustomGeoDataset(paths=paths).files == sorted(paths) def test_files_property_deterministic(self) -> None: """Ensure that the list of files is consistent regardless of their original order. 
""" - paths1 = ["file://file3.tif", "file://file1.tif", "file://file2.tif"] - paths2 = ["file://file2.tif", "file://file3.tif", "file://file1.tif"] + paths1 = ['file://file3.tif', 'file://file1.tif', 'file://file2.tif'] + paths2 = ['file://file2.tif', 'file://file3.tif', 'file://file1.tif'] assert ( CustomGeoDataset(paths=paths1).files == CustomGeoDataset(paths=paths2).files ) class TestRasterDataset: - @pytest.fixture(params=zip([["R", "G", "B"], None], [True, False])) + @pytest.fixture(params=zip([['R', 'G', 'B'], None], [True, False])) def naip(self, request: SubRequest) -> NAIP: - root = os.path.join("tests", "data", "naip") + root = os.path.join('tests', 'data', 'naip') bands = request.param[0] crs = CRS.from_epsg(4087) transforms = nn.Identity() @@ -205,45 +205,45 @@ def naip(self, request: SubRequest) -> NAIP: @pytest.fixture( params=zip( [ - ["B04", "B03", "B02"], - ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"], + ['B04', 'B03', 'B02'], + ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B11'], ], [True, False], ) ) def sentinel(self, request: SubRequest) -> Sentinel2: - root = os.path.join("tests", "data", "sentinel2") + root = os.path.join('tests', 'data', 'sentinel2') bands = request.param[0] transforms = nn.Identity() cache = request.param[1] return Sentinel2(root, bands=bands, transforms=transforms, cache=cache) @pytest.mark.parametrize( - "paths", + 'paths', [ # Single directory - os.path.join("tests", "data", "naip"), + os.path.join('tests', 'data', 'naip'), # Multiple directories [ - os.path.join("tests", "data", "naip"), - os.path.join("tests", "data", "naip"), + os.path.join('tests', 'data', 'naip'), + os.path.join('tests', 'data', 'naip'), ], # Single file - os.path.join("tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"), + os.path.join('tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'), # Multiple files ( os.path.join( - "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif" + 'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif' ), os.path.join( - "tests", "data", "naip", "m_3807511_ne_18_060_20190605.tif" + 'tests', 'data', 'naip', 'm_3807511_ne_18_060_20190605.tif' ), ), # Combination { - os.path.join("tests", "data", "naip"), + os.path.join('tests', 'data', 'naip'), os.path.join( - "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif" + 'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif' ), }, ], @@ -254,45 +254,45 @@ def test_files(self, paths: str | Iterable[str]) -> None: def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) - assert len(naip.bands) == x["image"].shape[0] + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) + assert len(naip.bands) == x['image'].shape[0] def test_getitem_separate_files(self, sentinel: Sentinel2) -> None: x = sentinel[sentinel.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) - assert len(sentinel.bands) == x["image"].shape[0] + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) + assert len(sentinel.bands) == x['image'].shape[0] - @pytest.mark.parametrize("dtype", ["uint16", "uint32"]) + @pytest.mark.parametrize('dtype', ['uint16', 'uint32']) def test_getitem_uint_dtype(self, dtype: str) -> None: - root = os.path.join("tests", "data", "raster", dtype) + root = os.path.join('tests', 'data', 
'raster', dtype) ds = RasterDataset(root) x = ds[ds.bounds] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].dtype == torch.float32 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].dtype == torch.float32 def test_invalid_query(self, sentinel: Sentinel2) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds: .*" + IndexError, match='query: .* not found in index with bounds: .*' ): sentinel[query] def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): RasterDataset(str(tmp_path)) def test_no_all_bands(self) -> None: - root = os.path.join("tests", "data", "sentinel2") - bands = ["B04", "B03", "B02"] + root = os.path.join('tests', 'data', 'sentinel2') + bands = ['B04', 'B03', 'B02'] transforms = nn.Identity() cache = True msg = ( - "CustomSentinelDataset is missing an `all_bands` attribute," - " so `bands` cannot be specified." + 'CustomSentinelDataset is missing an `all_bands` attribute,' + ' so `bands` cannot be specified.' ) with pytest.raises(AssertionError, match=msg): @@ -300,27 +300,27 @@ def test_no_all_bands(self) -> None: class TestVectorDataset: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> CustomVectorDataset: - root = os.path.join("tests", "data", "vector") + root = os.path.join('tests', 'data', 'vector') transforms = nn.Identity() return CustomVectorDataset(root, res=0.1, transforms=transforms) - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def multilabel(self) -> CustomVectorDataset: - root = os.path.join("tests", "data", "vector") + root = os.path.join('tests', 'data', 'vector') transforms = nn.Identity() return CustomVectorDataset( - root, res=0.1, transforms=transforms, label_name="label_id" + root, res=0.1, transforms=transforms, label_name='label_id' ) def test_getitem(self, dataset: CustomVectorDataset) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) assert torch.equal( - x["mask"].unique(), # type: ignore[no-untyped-call] + x['mask'].unique(), # type: ignore[no-untyped-call] torch.tensor([0, 1], dtype=torch.uint8), ) @@ -331,37 +331,37 @@ def test_time_index(self, dataset: CustomVectorDataset) -> None: def test_getitem_multilabel(self, multilabel: CustomVectorDataset) -> None: x = multilabel[multilabel.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) assert torch.equal( - x["mask"].unique(), # type: ignore[no-untyped-call] + x['mask'].unique(), # type: ignore[no-untyped-call] torch.tensor([0, 1, 2, 3], dtype=torch.uint8), ) def test_empty_shapes(self, dataset: CustomVectorDataset) -> None: query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, sys.maxsize) x = dataset[query] - assert torch.equal(x["mask"], torch.zeros(8, 8, dtype=torch.uint8)) + assert torch.equal(x['mask'], torch.zeros(8, 8, dtype=torch.uint8)) def test_invalid_query(self, dataset: CustomVectorDataset) -> None: query = BoundingBox(3, 3, 3, 3, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in 
index with bounds:' ): dataset[query] def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): VectorDataset(str(tmp_path)) class TestNonGeoDataset: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> NonGeoDataset: return CustomNonGeoDataset() def test_getitem(self, dataset: NonGeoDataset) -> None: - assert dataset[0] == {"index": 0} + assert dataset[0] == {'index': 0} def test_len(self, dataset: NonGeoDataset) -> None: assert len(dataset) == 2 @@ -391,8 +391,8 @@ def test_add_four(self) -> None: assert len(dataset) == 8 def test_str(self, dataset: NonGeoDataset) -> None: - assert "type: NonGeoDataset" in str(dataset) - assert "size: 2" in str(dataset) + assert 'type: NonGeoDataset' in str(dataset) + assert 'size: 2' in str(dataset) def test_abstract(self) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): @@ -400,22 +400,22 @@ def test_abstract(self) -> None: class TestNonGeoClassificationDataset: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self, root: str) -> NonGeoClassificationDataset: transforms = nn.Identity() return NonGeoClassificationDataset(root, transforms=transforms) - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def root(self) -> str: - root = os.path.join("tests", "data", "nongeoclassification") + root = os.path.join('tests', 'data', 'nongeoclassification') return root def test_getitem(self, dataset: NonGeoClassificationDataset) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['image'].shape[0] == 3 def test_len(self, dataset: NonGeoClassificationDataset) -> None: assert len(dataset) == 2 @@ -445,216 +445,216 @@ def test_add_four(self, root: str) -> None: assert len(dataset) == 8 def test_str(self, dataset: NonGeoClassificationDataset) -> None: - assert "type: NonGeoDataset" in str(dataset) - assert "size: 2" in str(dataset) + assert 'type: NonGeoDataset' in str(dataset) + assert 'size: 2' in str(dataset) class TestIntersectionDataset: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> IntersectionDataset: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4326')) transforms = nn.Identity() return IntersectionDataset(ds1, ds2, transforms=transforms) def test_getitem(self, dataset: IntersectionDataset) -> None: query = dataset.bounds sample = dataset[query] - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_len(self, dataset: IntersectionDataset) -> None: assert len(dataset) == 1 def test_str(self, dataset: IntersectionDataset) -> None: out = str(dataset) - assert "type: IntersectionDataset" in out - assert "bbox: BoundingBox" in out - assert "size: 1" in out + assert 'type: IntersectionDataset' in out + assert 'bbox: BoundingBox' in out + assert 'size: 1' in out def test_nongeo_dataset(self) -> None: ds1 = CustomNonGeoDataset() ds2 = 
CustomNonGeoDataset() with pytest.raises( - ValueError, match="IntersectionDataset only supports GeoDatasets" + ValueError, match='IntersectionDataset only supports GeoDatasets' ): IntersectionDataset(ds1, ds2) # type: ignore[arg-type] def test_different_crs_12(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4326')) ds = IntersectionDataset(ds1, ds2) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds) == 1 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_crs_12_3(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) - ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4326')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_32631')) ds = (ds1 & ds2) & ds3 sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_crs_1_23(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) - ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4326')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_32631')) ds = ds1 & (ds2 & ds3) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_res_12(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4087')) ds = IntersectionDataset(ds1, ds2) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds) == 1 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_res_12_3(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) - ds3 = 
RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4087')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_8_epsg_4087')) ds = (ds1 & ds2) & ds3 sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_res_1_23(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) - ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4087')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_8_epsg_4087')) ds = ds1 & (ds2 & ds3) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_no_overlap(self) -> None: ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5)) ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11)) - msg = "Datasets have no spatiotemporal intersection" + msg = 'Datasets have no spatiotemporal intersection' with pytest.raises(RuntimeError, match=msg): IntersectionDataset(ds1, ds2) def test_invalid_query(self, dataset: IntersectionDataset) -> None: query = BoundingBox(-1, -1, -1, -1, -1, -1) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] class TestUnionDataset: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> UnionDataset: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4326')) transforms = nn.Identity() return UnionDataset(ds1, ds2, transforms=transforms) def test_getitem(self, dataset: UnionDataset) -> None: query = dataset.bounds sample = dataset[query] - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_len(self, dataset: UnionDataset) -> None: assert len(dataset) == 2 def test_str(self, dataset: UnionDataset) -> None: out = str(dataset) - assert "type: UnionDataset" in out - assert "bbox: BoundingBox" in out - assert "size: 2" in out + assert 'type: UnionDataset' in out + assert 'bbox: BoundingBox' in out + assert 'size: 2' in out def test_different_crs_12(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 
'res_2_epsg_4326')) ds = UnionDataset(ds1, ds2) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds.res == 2 assert len(ds1) == len(ds2) == 1 assert len(ds) == 2 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_crs_12_3(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) - ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4326')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_32631')) ds = (ds1 | ds2) | ds3 sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == 1 assert len(ds) == 3 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_crs_1_23(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) - ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4326')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_32631')) ds = ds1 | (ds2 | ds3) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == 1 assert len(ds) == 3 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_res_12(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4087')) ds = UnionDataset(ds1, ds2) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds.res == 2 assert len(ds1) == len(ds2) == 1 assert len(ds) == 2 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_res_12_3(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) - ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4087')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_8_epsg_4087')) ds = (ds1 | ds2) | ds3 sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == 1 assert len(ds) == 3 - assert isinstance(sample["image"], 
torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_different_res_1_23(self) -> None: - ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) - ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) - ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds1 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_2_epsg_4087')) + ds2 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_4_epsg_4087')) + ds3 = RasterDataset(os.path.join('tests', 'data', 'raster', 'res_8_epsg_4087')) ds = ds1 | (ds2 | ds3) sample = ds[ds.bounds] assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) assert ds1.res == ds2.res == ds3.res == ds.res == 2 assert len(ds1) == len(ds2) == len(ds3) == 1 assert len(ds) == 3 - assert isinstance(sample["image"], torch.Tensor) + assert isinstance(sample['image'], torch.Tensor) def test_nongeo_dataset(self) -> None: ds1 = CustomNonGeoDataset() ds2 = CustomNonGeoDataset() ds3 = CustomGeoDataset() - msg = "UnionDataset only supports GeoDatasets" + msg = 'UnionDataset only supports GeoDatasets' with pytest.raises(ValueError, match=msg): UnionDataset(ds1, ds2) # type: ignore[arg-type] with pytest.raises(ValueError, match=msg): @@ -665,6 +665,6 @@ def test_nongeo_dataset(self) -> None: def test_invalid_query(self, dataset: UnionDataset) -> None: query = BoundingBox(-1, -1, -1, -1, -1, -1) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py index e39619d8313..9c0358fb08b 100644 --- a/tests/datasets/test_gid15.py +++ b/tests/datasets/test_gid15.py @@ -21,15 +21,15 @@ def download_url(url: str, root: str, *args: str) -> None: class TestGID15: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> GID15: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - md5 = "3d5b1373ef9a3084ec493b9b2056fe07" - monkeypatch.setattr(GID15, "md5", md5) - url = os.path.join("tests", "data", "gid15", "gid-15.zip") - monkeypatch.setattr(GID15, "url", url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + md5 = '3d5b1373ef9a3084ec493b9b2056fe07' + monkeypatch.setattr(GID15, 'md5', md5) + url = os.path.join('tests', 'data', 'gid15', 'gid-15.zip') + monkeypatch.setattr(GID15, 'url', url) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -38,14 +38,14 @@ def dataset( def test_getitem(self, dataset: GID15) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].shape[0] == 3 - if dataset.split != "test": - assert isinstance(x["mask"], torch.Tensor) - assert x["image"].shape[-2:] == x["mask"].shape[-2:] + if dataset.split != 'test': + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].shape[-2:] == x['mask'].shape[-2:] else: - assert "mask" not in x + assert 'mask' not in x def test_len(self, dataset: GID15) -> None: assert len(dataset) == 2 @@ -55,22 +55,22 @@ def test_already_downloaded(self, dataset: GID15) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - GID15(split="foo") + GID15(split='foo') def 
test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): GID15(str(tmp_path)) def test_plot(self, dataset: GID15) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() - if dataset.split != "test": + if dataset.split != 'test': sample = dataset[0] - sample["prediction"] = torch.clone(sample["mask"]) - dataset.plot(sample, suptitle="Prediction") + sample['prediction'] = torch.clone(sample['mask']) + dataset.plot(sample, suptitle='Prediction') else: sample = dataset[0] - sample["prediction"] = torch.ones((1, 1)) + sample['prediction'] = torch.ones((1, 1)) dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_globbiomass.py b/tests/datasets/test_globbiomass.py index 5bffc3ff0a4..94309f677a6 100644 --- a/tests/datasets/test_globbiomass.py +++ b/tests/datasets/test_globbiomass.py @@ -25,18 +25,18 @@ class TestGlobBiomass: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> GlobBiomass: shutil.copy( - os.path.join("tests", "data", "globbiomass", "N00E020_agb.zip"), tmp_path + os.path.join('tests', 'data', 'globbiomass', 'N00E020_agb.zip'), tmp_path ) shutil.copy( - os.path.join("tests", "data", "globbiomass", "N00E020_gsv.zip"), tmp_path + os.path.join('tests', 'data', 'globbiomass', 'N00E020_gsv.zip'), tmp_path ) md5s = { - "N00E020_agb.zip": "22e11817ede672a2a76b8a5588bc4bf4", - "N00E020_gsv.zip": "e79bf051ac5d659cb21c566c53ce7b98", + 'N00E020_agb.zip': '22e11817ede672a2a76b8a5588bc4bf4', + 'N00E020_gsv.zip': 'e79bf051ac5d659cb21c566c53ce7b98', } - monkeypatch.setattr(GlobBiomass, "md5s", md5s) + monkeypatch.setattr(GlobBiomass, 'md5s', md5s) root = str(tmp_path) transforms = nn.Identity() return GlobBiomass(root, transforms=transforms, checksum=True) @@ -44,20 +44,20 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> GlobBiomass: def test_getitem(self, dataset: GlobBiomass) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_already_extracted(self, dataset: GlobBiomass) -> None: GlobBiomass(dataset.paths) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): GlobBiomass(str(tmp_path), checksum=True) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "N00E020_agb.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'N00E020_agb.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): GlobBiomass(str(tmp_path), checksum=True) def test_and(self, dataset: GlobBiomass) -> None: @@ -71,19 +71,19 @@ def test_or(self, dataset: GlobBiomass) -> None: def test_plot(self, dataset: GlobBiomass) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: GlobBiomass) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, 
suptitle='Prediction') plt.close() def test_invalid_query(self, dataset: GlobBiomass) -> None: query = BoundingBox(100, 100, 100, 100, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index a6192ab012e..e907213c29c 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -18,7 +18,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, IDTReeS -pytest.importorskip("laspy", minversion="2") +pytest.importorskip('laspy', minversion='2') def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -26,31 +26,31 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestIDTReeS: - @pytest.fixture(params=zip(["train", "test", "test"], ["task1", "task1", "task2"])) + @pytest.fixture(params=zip(['train', 'test', 'test'], ['task1', 'task1', 'task2'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> IDTReeS: - monkeypatch.setattr(torchgeo.datasets.idtrees, "download_url", download_url) - data_dir = os.path.join("tests", "data", "idtrees") + monkeypatch.setattr(torchgeo.datasets.idtrees, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'idtrees') metadata = { - "train": { - "url": os.path.join(data_dir, "IDTREES_competition_train_v2.zip"), - "md5": "5ddfa76240b4bb6b4a7861d1d31c299c", - "filename": "IDTREES_competition_train_v2.zip", + 'train': { + 'url': os.path.join(data_dir, 'IDTREES_competition_train_v2.zip'), + 'md5': '5ddfa76240b4bb6b4a7861d1d31c299c', + 'filename': 'IDTREES_competition_train_v2.zip', }, - "test": { - "url": os.path.join(data_dir, "IDTREES_competition_test_v2.zip"), - "md5": "b108931c84a70f2a38a8234290131c9b", - "filename": "IDTREES_competition_test_v2.zip", + 'test': { + 'url': os.path.join(data_dir, 'IDTREES_competition_test_v2.zip'), + 'md5': 'b108931c84a70f2a38a8234290131c9b', + 'filename': 'IDTREES_competition_test_v2.zip', }, } split, task = request.param - monkeypatch.setattr(IDTReeS, "metadata", metadata) + monkeypatch.setattr(IDTReeS, 'metadata', metadata) root = str(tmp_path) transforms = nn.Identity() return IDTReeS(root, split, task, transforms, download=True, checksum=True) - @pytest.fixture(params=["laspy", "pyvista"]) + @pytest.fixture(params=['laspy', 'pyvista']) def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str: import_orig = builtins.__import__ package = str(request.param) @@ -60,29 +60,29 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) return package def test_getitem(self, dataset: IDTReeS) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["chm"], torch.Tensor) - assert isinstance(x["hsi"], torch.Tensor) - assert isinstance(x["las"], torch.Tensor) - assert x["image"].shape == (3, 200, 200) - assert x["chm"].shape == (1, 200, 200) - assert x["hsi"].shape == (369, 200, 200) - assert x["las"].ndim == 2 - assert x["las"].shape[0] == 3 - - if "label" in x: - assert isinstance(x["label"], torch.Tensor) - if "boxes" in x: - assert isinstance(x["boxes"], torch.Tensor) - if x["boxes"].ndim != 1: - assert 
x["boxes"].ndim == 2 - assert x["boxes"].shape[-1] == 4 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['chm'], torch.Tensor) + assert isinstance(x['hsi'], torch.Tensor) + assert isinstance(x['las'], torch.Tensor) + assert x['image'].shape == (3, 200, 200) + assert x['chm'].shape == (1, 200, 200) + assert x['hsi'].shape == (369, 200, 200) + assert x['las'].ndim == 2 + assert x['las'].shape[0] == 3 + + if 'label' in x: + assert isinstance(x['label'], torch.Tensor) + if 'boxes' in x: + assert isinstance(x['boxes'], torch.Tensor) + if x['boxes'].ndim != 1: + assert x['boxes'].ndim == 2 + assert x['boxes'].shape[-1] == 4 def test_len(self, dataset: IDTReeS) -> None: assert len(dataset) == 3 @@ -91,11 +91,11 @@ def test_already_downloaded(self, dataset: IDTReeS) -> None: IDTReeS(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): IDTReeS(str(tmp_path)) def test_not_extracted(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "idtrees", "*.zip") + pathname = os.path.join('tests', 'data', 'idtrees', '*.zip') root = str(tmp_path) for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) @@ -106,39 +106,39 @@ def test_mock_missing_module( ) -> None: package = mock_missing_module - if package == "laspy": + if package == 'laspy': with pytest.raises( ImportError, - match=f"{package} is not installed and is required to use this dataset", + match=f'{package} is not installed and is required to use this dataset', ): IDTReeS(dataset.root, dataset.split, dataset.task) - elif package == "pyvista": + elif package == 'pyvista': with pytest.raises( ImportError, - match=f"{package} is not installed and is required to plot point cloud", + match=f'{package} is not installed and is required to plot point cloud', ): dataset.plot_las(0) def test_plot(self, dataset: IDTReeS) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - if "boxes" in x: - x["prediction_boxes"] = x["boxes"] + if 'boxes' in x: + x['prediction_boxes'] = x['boxes'] dataset.plot(x, show_titles=True) plt.close() - if "label" in x: - x["prediction_label"] = x["label"] + if 'label' in x: + x['prediction_label'] = x['label'] dataset.plot(x, show_titles=False) plt.close() def test_plot_las(self, dataset: IDTReeS) -> None: - pyvista = pytest.importorskip("pyvista", minversion="0.34.2") + pyvista = pytest.importorskip('pyvista', minversion='0.34.2') pyvista.OFF_SCREEN = True # Test point cloud without colors point_cloud = dataset.plot_las(index=0) - pyvista.plot(point_cloud, scalars=point_cloud.points, cpos="yz", cmap="viridis") + pyvista.plot(point_cloud, scalars=point_cloud.points, cpos='yz', cmap='viridis') diff --git a/tests/datasets/test_inaturalist.py b/tests/datasets/test_inaturalist.py index 49c87d83f77..0f9a5424875 100644 --- a/tests/datasets/test_inaturalist.py +++ b/tests/datasets/test_inaturalist.py @@ -16,9 +16,9 @@ class TestINaturalist: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> INaturalist: - root = os.path.join("tests", "data", "inaturalist") + root = os.path.join('tests', 'data', 'inaturalist') return INaturalist(root) def test_getitem(self, dataset: INaturalist) -> None: @@ -37,12 +37,12 @@ def test_or(self, dataset: INaturalist) -> None: assert isinstance(ds, 
UnionDataset) def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): INaturalist(str(tmp_path)) def test_invalid_query(self, dataset: INaturalist) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_inria.py b/tests/datasets/test_inria.py index 71739a0ec8b..21bcb1a900d 100644 --- a/tests/datasets/test_inria.py +++ b/tests/datasets/test_inria.py @@ -15,13 +15,13 @@ class TestInriaAerialImageLabeling: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch ) -> InriaAerialImageLabeling: - root = os.path.join("tests", "data", "inria") - test_md5 = "3ecbe95eb84aea064e455c4321546be1" - monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5) + root = os.path.join('tests', 'data', 'inria') + test_md5 = '3ecbe95eb84aea064e455c4321546be1' + monkeypatch.setattr(InriaAerialImageLabeling, 'md5', test_md5) transforms = nn.Identity() return InriaAerialImageLabeling( root, split=request.param, transforms=transforms, checksum=True @@ -30,39 +30,39 @@ def dataset( def test_getitem(self, dataset: InriaAerialImageLabeling) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - if dataset.split == "train": - assert isinstance(x["mask"], torch.Tensor) - assert x["mask"].ndim == 2 - assert x["image"].shape[0] == 3 - assert x["image"].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + if dataset.split == 'train': + assert isinstance(x['mask'], torch.Tensor) + assert x['mask'].ndim == 2 + assert x['image'].shape[0] == 3 + assert x['image'].ndim == 3 def test_len(self, dataset: InriaAerialImageLabeling) -> None: - if dataset.split == "train": + if dataset.split == 'train': assert len(dataset) == 2 - elif dataset.split == "val": + elif dataset.split == 'val': assert len(dataset) == 5 - elif dataset.split == "test": + elif dataset.split == 'test': assert len(dataset) == 7 def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None: InriaAerialImageLabeling(root=dataset.root) def test_not_downloaded(self, tmp_path: str) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): InriaAerialImageLabeling(str(tmp_path)) def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None: - InriaAerialImageLabeling.md5 = "randommd5hash123" + InriaAerialImageLabeling.md5 = 'randommd5hash123' shutil.rmtree(os.path.join(dataset.root, dataset.directory)) - with pytest.raises(RuntimeError, match="Dataset corrupted"): + with pytest.raises(RuntimeError, match='Dataset corrupted'): InriaAerialImageLabeling(root=dataset.root, checksum=True) def test_plot(self, dataset: InriaAerialImageLabeling) -> None: x = dataset[0].copy() - if dataset.split == "train": - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + if dataset.split == 'train': + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() diff --git a/tests/datasets/test_l7irish.py b/tests/datasets/test_l7irish.py index 1f2801e5afb..5ae714b7645 100644 --- a/tests/datasets/test_l7irish.py +++ 
b/tests/datasets/test_l7irish.py @@ -31,15 +31,15 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestL7Irish: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L7Irish: - monkeypatch.setattr(torchgeo.datasets.l7irish, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.l7irish, 'download_url', download_url) md5s = { - "austral": "0485d6045f6b508068ef8daf9e5a5326", - "boreal": "5798f32545d7166564c4c4429357b840", + 'austral': '0485d6045f6b508068ef8daf9e5a5326', + 'boreal': '5798f32545d7166564c4c4429357b840', } - url = os.path.join("tests", "data", "l7irish", "{}.tar.gz") - monkeypatch.setattr(L7Irish, "url", url) - monkeypatch.setattr(L7Irish, "md5s", md5s) + url = os.path.join('tests', 'data', 'l7irish', '{}.tar.gz') + monkeypatch.setattr(L7Irish, 'url', url) + monkeypatch.setattr(L7Irish, 'md5s', md5s) root = str(tmp_path) transforms = nn.Identity() return L7Irish(root, transforms=transforms, download=True, checksum=True) @@ -47,9 +47,9 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L7Irish: def test_getitem(self, dataset: L7Irish) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: L7Irish) -> None: ds = dataset & dataset @@ -61,41 +61,41 @@ def test_or(self, dataset: L7Irish) -> None: def test_plot(self, dataset: L7Irish) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_already_extracted(self, dataset: L7Irish) -> None: L7Irish(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "l7irish", "*.tar.gz") + pathname = os.path.join('tests', 'data', 'l7irish', '*.tar.gz') root = str(tmp_path) for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) L7Irish(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): L7Irish(str(tmp_path)) def test_plot_prediction(self, dataset: L7Irish) -> None: x = dataset[dataset.bounds] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_invalid_query(self, dataset: L7Irish) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_rgb_bands_absent_plot(self, dataset: L7Irish) -> None: with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - ds = L7Irish(dataset.paths, bands=["B10", "B20", "B50"]) + ds = L7Irish(dataset.paths, bands=['B10', 'B20', 'B50']) x = ds[ds.bounds] - ds.plot(x, suptitle="Test") + ds.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_l8biome.py b/tests/datasets/test_l8biome.py index 7ae0cf14352..96c3209ca0a 100644 --- a/tests/datasets/test_l8biome.py +++ b/tests/datasets/test_l8biome.py @@ -31,15 +31,15 @@ def 
download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestL8Biome: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L8Biome: - monkeypatch.setattr(torchgeo.datasets.l8biome, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.l8biome, 'download_url', download_url) md5s = { - "barren": "29c9910adbc89677389f210226fb163d", - "forest": "b7dbb82fb2c22cbb03389d8828d73713", + 'barren': '29c9910adbc89677389f210226fb163d', + 'forest': 'b7dbb82fb2c22cbb03389d8828d73713', } - url = os.path.join("tests", "data", "l8biome", "{}.tar.gz") - monkeypatch.setattr(L8Biome, "url", url) - monkeypatch.setattr(L8Biome, "md5s", md5s) + url = os.path.join('tests', 'data', 'l8biome', '{}.tar.gz') + monkeypatch.setattr(L8Biome, 'url', url) + monkeypatch.setattr(L8Biome, 'md5s', md5s) root = str(tmp_path) transforms = nn.Identity() return L8Biome(root, transforms=transforms, download=True, checksum=True) @@ -47,9 +47,9 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L8Biome: def test_getitem(self, dataset: L8Biome) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: L8Biome) -> None: ds = dataset & dataset @@ -61,41 +61,41 @@ def test_or(self, dataset: L8Biome) -> None: def test_plot(self, dataset: L8Biome) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_already_extracted(self, dataset: L8Biome) -> None: L8Biome(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "l8biome", "*.tar.gz") + pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz') root = str(tmp_path) for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) L8Biome(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): L8Biome(str(tmp_path)) def test_plot_prediction(self, dataset: L8Biome) -> None: x = dataset[dataset.bounds] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_invalid_query(self, dataset: L8Biome) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_rgb_bands_absent_plot(self, dataset: L8Biome) -> None: with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - ds = L8Biome(dataset.paths, bands=["B1", "B2", "B5"]) + ds = L8Biome(dataset.paths, bands=['B1', 'B2', 'B5']) x = ds[ds.bounds] - ds.plot(x, suptitle="Test") + ds.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index e5bb366e5a2..7222ff78bbc 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -29,11 +29,11 @@ def download_url(url: str, root: str, *args: str, **kwargs: 
str) -> None: class TestLandCoverAIGeo: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LandCoverAIGeo: - monkeypatch.setattr(torchgeo.datasets.landcoverai, "download_url", download_url) - md5 = "ff8998857cc8511f644d3f7d0f3688d0" - monkeypatch.setattr(LandCoverAIGeo, "md5", md5) - url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip") - monkeypatch.setattr(LandCoverAIGeo, "url", url) + monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url) + md5 = 'ff8998857cc8511f644d3f7d0f3688d0' + monkeypatch.setattr(LandCoverAIGeo, 'md5', md5) + url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') + monkeypatch.setattr(LandCoverAIGeo, 'url', url) root = str(tmp_path) transforms = nn.Identity() return LandCoverAIGeo(root, transforms=transforms, download=True, checksum=True) @@ -41,53 +41,53 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LandCoverAIGeo: def test_getitem(self, dataset: LandCoverAIGeo) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_already_extracted(self, dataset: LandCoverAIGeo) -> None: LandCoverAIGeo(dataset.root, download=True) def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip") + url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') root = str(tmp_path) shutil.copy(url, root) LandCoverAIGeo(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): LandCoverAIGeo(str(tmp_path)) def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_plot(self, dataset: LandCoverAIGeo) -> None: x = dataset[dataset.bounds].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2) + x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2) dataset.plot(x) plt.close() class TestLandCoverAI: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LandCoverAI: - pytest.importorskip("cv2", minversion="4.4.0") - monkeypatch.setattr(torchgeo.datasets.landcoverai, "download_url", download_url) - md5 = "ff8998857cc8511f644d3f7d0f3688d0" - monkeypatch.setattr(LandCoverAI, "md5", md5) - url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip") - monkeypatch.setattr(LandCoverAI, "url", url) - sha256 = "ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b" - monkeypatch.setattr(LandCoverAI, "sha256", sha256) + pytest.importorskip('cv2', minversion='4.4.0') + monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url) + md5 = 'ff8998857cc8511f644d3f7d0f3688d0' + monkeypatch.setattr(LandCoverAI, 'md5', md5) + url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') + 
monkeypatch.setattr(LandCoverAI, 'url', url) + sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' + monkeypatch.setattr(LandCoverAI, 'sha256', sha256) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -96,8 +96,8 @@ def dataset( def test_getitem(self, dataset: LandCoverAI) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_len(self, dataset: LandCoverAI) -> None: assert len(dataset) == 2 @@ -111,28 +111,28 @@ def test_already_extracted(self, dataset: LandCoverAI) -> None: LandCoverAI(root=dataset.root, download=True) def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - pytest.importorskip("cv2", minversion="4.4.0") - sha256 = "ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b" - monkeypatch.setattr(LandCoverAI, "sha256", sha256) - url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip") + pytest.importorskip('cv2', minversion='4.4.0') + sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' + monkeypatch.setattr(LandCoverAI, 'sha256', sha256) + url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') root = str(tmp_path) shutil.copy(url, root) LandCoverAI(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): LandCoverAI(str(tmp_path)) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - LandCoverAI(split="foo") + LandCoverAI(split='foo') def test_plot(self, dataset: LandCoverAI) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"].clone() + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index f058b9a4db0..1d106dd8cd8 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -25,12 +25,12 @@ class TestLandsat8: @pytest.fixture( params=[ - ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"], - ["SR_B4", "SR_B3", "SR_B2", "SR_QA_AEROSOL"], + ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'], + ['SR_B4', 'SR_B3', 'SR_B2', 'SR_QA_AEROSOL'], ] ) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Landsat8: - root = os.path.join("tests", "data", "landsat8") + root = os.path.join('tests', 'data', 'landsat8') bands = request.param transforms = nn.Identity() return Landsat8(root, bands=bands, transforms=transforms) @@ -41,8 +41,8 @@ def test_separate_files(self, dataset: Landsat8) -> None: def test_getitem(self, dataset: Landsat8) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) def test_and(self, dataset: Landsat8) -> None: ds = dataset & dataset @@ -54,25 +54,25 @@ def test_or(self, dataset: Landsat8) -> None: def test_plot(self, dataset: Landsat8) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_wrong_bands(self, dataset: Landsat8) -> None: - 
bands = ("SR_B1",) + bands = ('SR_B1',) ds = Landsat8(dataset.paths, bands=bands) x = dataset[dataset.bounds] with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): ds.plot(x) def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Landsat8(str(tmp_path)) def test_invalid_query(self, dataset: Landsat8) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py index 2cb37c56a31..cbec555746c 100644 --- a/tests/datasets/test_levircd.py +++ b/tests/datasets/test_levircd.py @@ -21,30 +21,30 @@ def download_url(url: str, root: str, *args: str) -> None: class TestLEVIRCD: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LEVIRCD: - directory = os.path.join("tests", "data", "levircd", "levircd") + directory = os.path.join('tests', 'data', 'levircd', 'levircd') splits = { - "train": { - "url": os.path.join(directory, "train.zip"), - "filename": "train.zip", - "md5": "7c2e24b3072095519f1be7eb01fae4ff", + 'train': { + 'url': os.path.join(directory, 'train.zip'), + 'filename': 'train.zip', + 'md5': '7c2e24b3072095519f1be7eb01fae4ff', }, - "val": { - "url": os.path.join(directory, "val.zip"), - "filename": "val.zip", - "md5": "5c320223ba88b6fc8ff9d1feebc3b84e", + 'val': { + 'url': os.path.join(directory, 'val.zip'), + 'filename': 'val.zip', + 'md5': '5c320223ba88b6fc8ff9d1feebc3b84e', }, - "test": { - "url": os.path.join(directory, "test.zip"), - "filename": "test.zip", - "md5": "021db72d4486726d6a0702563a617b32", + 'test': { + 'url': os.path.join(directory, 'test.zip'), + 'filename': 'test.zip', + 'md5': '021db72d4486726d6a0702563a617b32', }, } - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - monkeypatch.setattr(LEVIRCD, "splits", splits) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + monkeypatch.setattr(LEVIRCD, 'splits', splits) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -53,11 +53,11 @@ def dataset( def test_getitem(self, dataset: LEVIRCD) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image1"], torch.Tensor) - assert isinstance(x["image2"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert x["image1"].shape[0] == 3 - assert x["image2"].shape[0] == 3 + assert isinstance(x['image1'], torch.Tensor) + assert isinstance(x['image2'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert x['image1'].shape[0] == 3 + assert x['image2'].shape[0] == 3 def test_len(self, dataset: LEVIRCD) -> None: assert len(dataset) == 2 @@ -67,32 +67,32 @@ def test_already_downloaded(self, dataset: LEVIRCD) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - LEVIRCD(split="foo") + LEVIRCD(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): LEVIRCD(str(tmp_path)) def 
test_plot(self, dataset: LEVIRCD) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["mask"].clone() - dataset.plot(sample, suptitle="Prediction") + sample['prediction'] = sample['mask'].clone() + dataset.plot(sample, suptitle='Prediction') plt.close() class TestLEVIRCDPlus: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LEVIRCDPlus: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - md5 = "0ccca34310bfe7096dadfbf05b0d180f" - monkeypatch.setattr(LEVIRCDPlus, "md5", md5) - url = os.path.join("tests", "data", "levircd", "levircdplus", "LEVIR-CD+.zip") - monkeypatch.setattr(LEVIRCDPlus, "url", url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + md5 = '0ccca34310bfe7096dadfbf05b0d180f' + monkeypatch.setattr(LEVIRCDPlus, 'md5', md5) + url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip') + monkeypatch.setattr(LEVIRCDPlus, 'url', url) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -101,11 +101,11 @@ def dataset( def test_getitem(self, dataset: LEVIRCDPlus) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image1"], torch.Tensor) - assert isinstance(x["image2"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert x["image1"].shape[0] == 3 - assert x["image2"].shape[0] == 3 + assert isinstance(x['image1'], torch.Tensor) + assert isinstance(x['image2'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert x['image1'].shape[0] == 3 + assert x['image2'].shape[0] == 3 def test_len(self, dataset: LEVIRCDPlus) -> None: assert len(dataset) == 2 @@ -115,17 +115,17 @@ def test_already_downloaded(self, dataset: LEVIRCDPlus) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - LEVIRCDPlus(split="foo") + LEVIRCDPlus(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): LEVIRCDPlus(str(tmp_path)) def test_plot(self, dataset: LEVIRCDPlus) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["mask"].clone() - dataset.plot(sample, suptitle="Prediction") + sample['prediction'] = sample['mask'].clone() + dataset.plot(sample, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_loveda.py b/tests/datasets/test_loveda.py index a368d711034..be36bec2f1e 100644 --- a/tests/datasets/test_loveda.py +++ b/tests/datasets/test_loveda.py @@ -21,32 +21,32 @@ def download_url(url: str, root: str, *args: str) -> None: class TestLoveDA: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LoveDA: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - md5 = "3d5b1373ef9a3084ec493b9b2056fe07" + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + md5 = '3d5b1373ef9a3084ec493b9b2056fe07' info_dict = { - "train": { - "url": os.path.join("tests", "data", "loveda", "Train.zip"), - "filename": "Train.zip", - "md5": md5, + 'train': { + 'url': 
os.path.join('tests', 'data', 'loveda', 'Train.zip'), + 'filename': 'Train.zip', + 'md5': md5, }, - "val": { - "url": os.path.join("tests", "data", "loveda", "Val.zip"), - "filename": "Val.zip", - "md5": md5, + 'val': { + 'url': os.path.join('tests', 'data', 'loveda', 'Val.zip'), + 'filename': 'Val.zip', + 'md5': md5, }, - "test": { - "url": os.path.join("tests", "data", "loveda", "Test.zip"), - "filename": "Test.zip", - "md5": md5, + 'test': { + 'url': os.path.join('tests', 'data', 'loveda', 'Test.zip'), + 'filename': 'Test.zip', + 'md5': md5, }, } - monkeypatch.setattr(LoveDA, "info_dict", info_dict) + monkeypatch.setattr(LoveDA, 'info_dict', info_dict) root = str(tmp_path) split = request.param @@ -58,14 +58,14 @@ def dataset( def test_getitem(self, dataset: LoveDA) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].shape[0] == 3 - if dataset.split != "test": - assert isinstance(x["mask"], torch.Tensor) - assert x["image"].shape[-2:] == x["mask"].shape[-2:] + if dataset.split != 'test': + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].shape[-2:] == x['mask'].shape[-2:] else: - assert "mask" not in x + assert 'mask' not in x def test_len(self, dataset: LoveDA) -> None: assert len(dataset) == 2 @@ -76,16 +76,16 @@ def test_already_downloaded(self, dataset: LoveDA) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - LoveDA(split="foo") + LoveDA(split='foo') def test_invalid_scene(self) -> None: with pytest.raises(AssertionError): - LoveDA(scene=["garden"]) + LoveDA(scene=['garden']) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): LoveDA(str(tmp_path)) def test_plot(self, dataset: LoveDA) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() diff --git a/tests/datasets/test_mapinwild.py b/tests/datasets/test_mapinwild.py index 90aa35b6aa2..aff7d200099 100644 --- a/tests/datasets/test_mapinwild.py +++ b/tests/datasets/test_mapinwild.py @@ -23,50 +23,50 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestMapInWild: - @pytest.fixture(params=["train", "validation", "test"]) + @pytest.fixture(params=['train', 'validation', 'test']) def dataset( self, tmp_path: Path, monkeypatch: MonkeyPatch, request: SubRequest ) -> MapInWild: - monkeypatch.setattr(torchgeo.datasets.mapinwild, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.mapinwild, 'download_url', download_url) md5s = { - "ESA_WC.zip": "3a1e696353d238c50996958855da02fc", - "VIIRS.zip": "e8b0e230edb1183c02092357af83bd52", - "mask.zip": "15245bb6368d27dbb4bd16310f4604fa", - "s1_part1.zip": "e660da4175518af993b63644e44a9d03", - "s1_part2.zip": "620cf0a7d598a2893bc7642ad7ee6087", - "s2_autumn_part1.zip": "624b6cf0191c5e0bc0d51f92b568e676", - "s2_autumn_part2.zip": "f848c62b8de36f06f12fb6b1b065c7b6", - "s2_spring_part1.zip": "3296f3a7da7e485708dd16be91deb111", - "s2_spring_part2.zip": "d27e94387a59f0558fe142a791682861", - "s2_summer_part1.zip": "41d783706c3c1e4238556a772d3232fb", - "s2_summer_part2.zip": "3495c87b67a771cfac5153d1958daa0c", - "s2_temporal_subset_part1.zip": "06fa463888cb033011a06cf69f82273e", - "s2_temporal_subset_part2.zip": "93e5383adeeea27f00051ecf110fcef8", - 
"s2_winter_part1.zip": "617abe1c6ad8d38725aa27c9dcc38ceb", - "s2_winter_part2.zip": "4e40d7bb0eec4ddea0b7b00314239a49", - "split_IDs.csv": "ca22c3d30d0b62e001ed0c327c147127", + 'ESA_WC.zip': '3a1e696353d238c50996958855da02fc', + 'VIIRS.zip': 'e8b0e230edb1183c02092357af83bd52', + 'mask.zip': '15245bb6368d27dbb4bd16310f4604fa', + 's1_part1.zip': 'e660da4175518af993b63644e44a9d03', + 's1_part2.zip': '620cf0a7d598a2893bc7642ad7ee6087', + 's2_autumn_part1.zip': '624b6cf0191c5e0bc0d51f92b568e676', + 's2_autumn_part2.zip': 'f848c62b8de36f06f12fb6b1b065c7b6', + 's2_spring_part1.zip': '3296f3a7da7e485708dd16be91deb111', + 's2_spring_part2.zip': 'd27e94387a59f0558fe142a791682861', + 's2_summer_part1.zip': '41d783706c3c1e4238556a772d3232fb', + 's2_summer_part2.zip': '3495c87b67a771cfac5153d1958daa0c', + 's2_temporal_subset_part1.zip': '06fa463888cb033011a06cf69f82273e', + 's2_temporal_subset_part2.zip': '93e5383adeeea27f00051ecf110fcef8', + 's2_winter_part1.zip': '617abe1c6ad8d38725aa27c9dcc38ceb', + 's2_winter_part2.zip': '4e40d7bb0eec4ddea0b7b00314239a49', + 'split_IDs.csv': 'ca22c3d30d0b62e001ed0c327c147127', } - monkeypatch.setattr(MapInWild, "md5s", md5s) + monkeypatch.setattr(MapInWild, 'md5s', md5s) - urls = os.path.join("tests", "data", "mapinwild") - monkeypatch.setattr(MapInWild, "url", urls) + urls = os.path.join('tests', 'data', 'mapinwild') + monkeypatch.setattr(MapInWild, 'url', urls) root = str(tmp_path) split = request.param transforms = nn.Identity() modality = [ - "mask", - "viirs", - "esa_wc", - "s2_winter", - "s1", - "s2_summer", - "s2_spring", - "s2_autumn", - "s2_temporal_subset", + 'mask', + 'viirs', + 'esa_wc', + 's2_winter', + 's1', + 's2_summer', + 's2_spring', + 's2_autumn', + 's2_temporal_subset', ] return MapInWild( root, @@ -80,9 +80,9 @@ def dataset( def test_getitem(self, dataset: MapInWild) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert x["image"].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].ndim == 3 def test_len(self, dataset: MapInWild) -> None: assert len(dataset) == 1 @@ -94,14 +94,14 @@ def test_add(self, dataset: MapInWild) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - MapInWild(split="foo") + MapInWild(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): MapInWild(root=str(tmp_path)) def test_downloaded_not_extracted(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "mapinwild", "*", "*") + pathname = os.path.join('tests', 'data', 'mapinwild', '*', '*') pathname_glob = glob.glob(pathname) root = str(tmp_path) for zipfile in pathname_glob: @@ -109,18 +109,18 @@ def test_downloaded_not_extracted(self, tmp_path: Path) -> None: MapInWild(root, download=False) def test_corrupted(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "mapinwild", "**", "*.zip") + pathname = os.path.join('tests', 'data', 'mapinwild', '**', '*.zip') pathname_glob = glob.glob(pathname, recursive=True) root = str(tmp_path) for zipfile in pathname_glob: shutil.copy(zipfile, root) splitfile = os.path.join( - "tests", "data", "mapinwild", "split_IDs", "split_IDs.csv" + 'tests', 'data', 'mapinwild', 'split_IDs', 'split_IDs.csv' ) shutil.copy(splitfile, root) - with 
open(os.path.join(tmp_path, "mask.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'mask.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): MapInWild(root=str(tmp_path), download=True, checksum=True) def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None: @@ -128,10 +128,10 @@ def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None: def test_plot(self, dataset: MapInWild) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"].clone() + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py index 1e94fd003d0..349006ce248 100644 --- a/tests/datasets/test_millionaid.py +++ b/tests/datasets/test_millionaid.py @@ -16,10 +16,10 @@ class TestMillionAID: @pytest.fixture( - scope="class", params=zip(["train", "test"], ["multi-class", "multi-label"]) + scope='class', params=zip(['train', 'test'], ['multi-class', 'multi-label']) ) def dataset(self, request: SubRequest) -> MillionAID: - root = os.path.join("tests", "data", "millionaid") + root = os.path.join('tests', 'data', 'millionaid') split, task = request.param transforms = nn.Identity() return MillionAID( @@ -29,36 +29,36 @@ def dataset(self, request: SubRequest) -> MillionAID: def test_getitem(self, dataset: MillionAID) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["image"].shape[0] == 3 - assert x["image"].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['image'].shape[0] == 3 + assert x['image'].ndim == 3 def test_len(self, dataset: MillionAID) -> None: assert len(dataset) == 2 def test_not_found(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): MillionAID(str(tmp_path)) def test_not_extracted(self, tmp_path: Path) -> None: - url = os.path.join("tests", "data", "millionaid", "train.zip") + url = os.path.join('tests', 'data', 'millionaid', 'train.zip') shutil.copy(url, tmp_path) MillionAID(str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "train.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'train.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): MillionAID(str(tmp_path), checksum=True) def test_plot(self, dataset: MillionAID) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: MillionAID) -> None: x = dataset[0].copy() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py index fe257ae2b78..1526f3df5d2 100644 --- a/tests/datasets/test_naip.py +++ b/tests/datasets/test_naip.py @@ -22,15 +22,15 @@ class TestNAIP: @pytest.fixture def dataset(self) -> NAIP: - root = os.path.join("tests", "data", "naip") + root = 
os.path.join('tests', 'data', 'naip') transforms = nn.Identity() return NAIP(root, transforms=transforms) def test_getitem(self, dataset: NAIP) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) def test_and(self, dataset: NAIP) -> None: ds = dataset & dataset @@ -43,16 +43,16 @@ def test_or(self, dataset: NAIP) -> None: def test_plot(self, dataset: NAIP) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): NAIP(str(tmp_path)) def test_invalid_query(self, dataset: NAIP) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index f475234ffe6..588cd89174a 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -17,7 +17,7 @@ class Collection: def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join("tests", "data", "nasa_marine_debris", "*.tar.gz") + glob_path = os.path.join('tests', 'data', 'nasa_marine_debris', '*.tar.gz') for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -30,8 +30,8 @@ class Collection_corrupted: def download(self, output_dir: str, **kwargs: str) -> None: filenames = NASAMarineDebris.filenames for filename in filenames: - with open(os.path.join(output_dir, filename), "w") as f: - f.write("bad") + with open(os.path.join(output_dir, filename), 'w') as f: + f.write('bad') def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted: @@ -41,10 +41,10 @@ def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted: class TestNASAMarineDebris: @pytest.fixture() def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) - md5s = ["6f4f0d2313323950e45bf3fc0c09b5de", "540cf1cf4fd2c13b609d0355abe955d7"] - monkeypatch.setattr(NASAMarineDebris, "md5s", md5s) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) + md5s = ['6f4f0d2313323950e45bf3fc0c09b5de', '540cf1cf4fd2c13b609d0355abe955d7'] + monkeypatch.setattr(NASAMarineDebris, 'md5s', md5s) root = str(tmp_path) transforms = nn.Identity() return NASAMarineDebris(root, transforms, download=True, checksum=True) @@ -52,10 +52,10 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris: def test_getitem(self, dataset: NASAMarineDebris) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["boxes"], torch.Tensor) - assert x["image"].shape[0] == 3 - assert x["boxes"].shape[-1] == 4 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['boxes'], torch.Tensor) + assert x['image'].shape[0] == 3 + assert x['boxes'].shape[-1] == 4 def test_len(self, dataset: 
NASAMarineDebris) -> None: assert len(dataset) == 4 @@ -76,29 +76,29 @@ def test_already_downloaded_not_extracted( def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None: filenames = NASAMarineDebris.filenames for filename in filenames: - with open(os.path.join(tmp_path, filename), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset checksum mismatch."): + with open(os.path.join(tmp_path, filename), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'): NASAMarineDebris(root=str(tmp_path), download=False, checksum=True) def test_corrupted_new_download( self, tmp_path: Path, monkeypatch: MonkeyPatch ) -> None: - with pytest.raises(RuntimeError, match="Dataset checksum mismatch."): - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_corrupted) + with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'): + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_corrupted) NASAMarineDebris(root=str(tmp_path), download=True, checksum=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): NASAMarineDebris(str(tmp_path)) def test_plot(self, dataset: NASAMarineDebris) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction_boxes"] = x["boxes"].clone() + x['prediction_boxes'] = x['boxes'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_nccm.py b/tests/datasets/test_nccm.py index 0d922d9d3d5..a2b1304cb6e 100644 --- a/tests/datasets/test_nccm.py +++ b/tests/datasets/test_nccm.py @@ -24,19 +24,19 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestNCCM: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM: - monkeypatch.setattr(torchgeo.datasets.nccm, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.nccm, 'download_url', download_url) md5s = { - 2017: "ae5c390d0ffb8970d544b8a09142759f", - 2018: "0d453bdb8ea5b7318c33e62513760580", - 2019: "d4ab7ab00bb57623eafb6b27747e5639", + 2017: 'ae5c390d0ffb8970d544b8a09142759f', + 2018: '0d453bdb8ea5b7318c33e62513760580', + 2019: 'd4ab7ab00bb57623eafb6b27747e5639', } - monkeypatch.setattr(NCCM, "md5s", md5s) + monkeypatch.setattr(NCCM, 'md5s', md5s) urls = { - 2017: os.path.join("tests", "data", "nccm", "CDL2017_clip.tif"), - 2018: os.path.join("tests", "data", "nccm", "CDL2018_clip1.tif"), - 2019: os.path.join("tests", "data", "nccm", "CDL2019_clip.tif"), + 2017: os.path.join('tests', 'data', 'nccm', 'CDL2017_clip.tif'), + 2018: os.path.join('tests', 'data', 'nccm', 'CDL2018_clip1.tif'), + 2019: os.path.join('tests', 'data', 'nccm', 'CDL2019_clip.tif'), } - monkeypatch.setattr(NCCM, "urls", urls) + monkeypatch.setattr(NCCM, 'urls', urls) transforms = nn.Identity() root = str(tmp_path) return NCCM(root, transforms=transforms, download=True, checksum=True) @@ -44,8 +44,8 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM: def test_getitem(self, dataset: NCCM) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert 
isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: NCCM) -> None: ds = dataset & dataset @@ -64,23 +64,23 @@ def test_already_downloaded(self, dataset: NCCM) -> None: def test_plot(self, dataset: NCCM) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: NCCM) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): NCCM(str(tmp_path)) def test_invalid_query(self, dataset: NCCM) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_nlcd.py b/tests/datasets/test_nlcd.py index 67dde52648d..8dd40ce7e83 100644 --- a/tests/datasets/test_nlcd.py +++ b/tests/datasets/test_nlcd.py @@ -29,19 +29,19 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestNLCD: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NLCD: - monkeypatch.setattr(torchgeo.datasets.nlcd, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.nlcd, 'download_url', download_url) md5s = { - 2011: "99546a3b89a0dddbe4e28e661c79984e", - 2019: "a4008746f15720b8908ddd357a75fded", + 2011: '99546a3b89a0dddbe4e28e661c79984e', + 2019: 'a4008746f15720b8908ddd357a75fded', } - monkeypatch.setattr(NLCD, "md5s", md5s) + monkeypatch.setattr(NLCD, 'md5s', md5s) url = os.path.join( - "tests", "data", "nlcd", "nlcd_{}_land_cover_l48_20210604.zip" + 'tests', 'data', 'nlcd', 'nlcd_{}_land_cover_l48_20210604.zip' ) - monkeypatch.setattr(NLCD, "url", url) - monkeypatch.setattr(plt, "show", lambda *args: None) + monkeypatch.setattr(NLCD, 'url', url) + monkeypatch.setattr(plt, 'show', lambda *args: None) root = str(tmp_path) transforms = nn.Identity() return NLCD( @@ -55,15 +55,15 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NLCD: def test_getitem(self, dataset: NLCD) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_classes(self) -> None: - root = os.path.join("tests", "data", "nlcd") + root = os.path.join('tests', 'data', 'nlcd') classes = list(NLCD.cmap.keys())[:5] ds = NLCD(root, years=[2019], classes=classes) sample = ds[ds.bounds] - mask = sample["mask"] + mask = sample['mask'] assert mask.max() < len(classes) def test_and(self, dataset: NLCD) -> None: @@ -79,7 +79,7 @@ def test_already_extracted(self, dataset: NLCD) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( - "tests", "data", "nlcd", "nlcd_2019_land_cover_l48_20210604.zip" + 'tests', 'data', 'nlcd', 'nlcd_2019_land_cover_l48_20210604.zip' ) root = str(tmp_path) shutil.copy(pathname, root) @@ -88,7 +88,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_invalid_year(self, tmp_path: Path) -> None: with pytest.raises( AssertionError, - match="NLCD 
data product only exists for the following years:", + match='NLCD data product only exists for the following years:', ): NLCD(str(tmp_path), years=[1996]) @@ -102,23 +102,23 @@ def test_invalid_classes(self) -> None: def test_plot(self, dataset: NLCD) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: NLCD) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): NLCD(str(tmp_path)) def test_invalid_query(self, dataset: NLCD) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_openbuildings.py b/tests/datasets/test_openbuildings.py index 65244962553..f591013636f 100644 --- a/tests/datasets/test_openbuildings.py +++ b/tests/datasets/test_openbuildings.py @@ -28,15 +28,15 @@ class TestOpenBuildings: def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings: root = str(tmp_path) shutil.copy( - os.path.join("tests", "data", "openbuildings", "tiles.geojson"), root + os.path.join('tests', 'data', 'openbuildings', 'tiles.geojson'), root ) shutil.copy( - os.path.join("tests", "data", "openbuildings", "000_buildings.csv.gz"), root + os.path.join('tests', 'data', 'openbuildings', '000_buildings.csv.gz'), root ) - md5s = {"000_buildings.csv.gz": "20aeeec9d45a0ce4d772a26e0bcbc25f"} + md5s = {'000_buildings.csv.gz': '20aeeec9d45a0ce4d772a26e0bcbc25f'} - monkeypatch.setattr(OpenBuildings, "md5s", md5s) + monkeypatch.setattr(OpenBuildings, 'md5s', md5s) transforms = nn.Identity() return OpenBuildings(root, transforms=transforms) @@ -44,42 +44,42 @@ def test_no_shapes_to_rasterize( self, dataset: OpenBuildings, tmp_path: Path ) -> None: # empty csv buildings file - path = os.path.join(tmp_path, "000_buildings.csv.gz") + path = os.path.join(tmp_path, '000_buildings.csv.gz') df = pd.read_csv(path) df = pd.DataFrame(columns=df.columns) - df.to_csv(path, compression="gzip") + df.to_csv(path, compression='gzip') x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_not_download(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): OpenBuildings(str(tmp_path)) def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, '000_buildings.csv.gz'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): OpenBuildings(dataset.paths, checksum=True) def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None: # change meta data to another 'title_url' so that there is no match 
found - with open(os.path.join(tmp_path, "tiles.geojson")) as f: + with open(os.path.join(tmp_path, 'tiles.geojson')) as f: content = json.load(f) - content["features"][0]["properties"]["tile_url"] = "mismatch.csv.gz" + content['features'][0]['properties']['tile_url'] = 'mismatch.csv.gz' - with open(os.path.join(tmp_path, "tiles.geojson"), "w") as f: + with open(os.path.join(tmp_path, 'tiles.geojson'), 'w') as f: json.dump(content, f) - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): OpenBuildings(dataset.paths) def test_getitem(self, dataset: OpenBuildings) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: OpenBuildings) -> None: ds = dataset & dataset @@ -92,17 +92,17 @@ def test_or(self, dataset: OpenBuildings) -> None: def test_invalid_query(self, dataset: OpenBuildings) -> None: query = BoundingBox(100, 100, 100, 100, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_plot(self, dataset: OpenBuildings) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="test") + dataset.plot(x, suptitle='test') plt.close() def test_plot_prediction(self, dataset: OpenBuildings) -> None: x = dataset[dataset.bounds] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index 4efc6c91a45..cd1c80a443b 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -23,44 +23,44 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestOSCD: - @pytest.fixture(params=zip([OSCD.all_bands, OSCD.rgb_bands], ["train", "test"])) + @pytest.fixture(params=zip([OSCD.all_bands, OSCD.rgb_bands], ['train', 'test'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> OSCD: - monkeypatch.setattr(torchgeo.datasets.oscd, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.oscd, 'download_url', download_url) md5s = { - "Onera Satellite Change Detection dataset - Images.zip": ( - "fb4e3f54c3a31fd3f21f98cad4ddfb74" + 'Onera Satellite Change Detection dataset - Images.zip': ( + 'fb4e3f54c3a31fd3f21f98cad4ddfb74' ), - "Onera Satellite Change Detection dataset - Train Labels.zip": ( - "ca526434a60e9abdf97d528dc29e9f13" + 'Onera Satellite Change Detection dataset - Train Labels.zip': ( + 'ca526434a60e9abdf97d528dc29e9f13' ), - "Onera Satellite Change Detection dataset - Test Labels.zip": ( - "ca0ba73ba66d06fa4903e269ef12eb50" + 'Onera Satellite Change Detection dataset - Test Labels.zip': ( + 'ca0ba73ba66d06fa4903e269ef12eb50' ), } - monkeypatch.setattr(OSCD, "md5s", md5s) + monkeypatch.setattr(OSCD, 'md5s', md5s) urls = { - "Onera Satellite Change Detection dataset - Images.zip": os.path.join( - "tests", - "data", - "oscd", - "Onera Satellite Change Detection dataset - Images.zip", + 'Onera Satellite Change Detection dataset - Images.zip': os.path.join( + 'tests', + 'data', + 'oscd', + 'Onera Satellite Change Detection dataset - Images.zip', ), - "Onera Satellite Change Detection dataset - Train Labels.zip": os.path.join( 
- "tests", - "data", - "oscd", - "Onera Satellite Change Detection dataset - Train Labels.zip", + 'Onera Satellite Change Detection dataset - Train Labels.zip': os.path.join( + 'tests', + 'data', + 'oscd', + 'Onera Satellite Change Detection dataset - Train Labels.zip', ), - "Onera Satellite Change Detection dataset - Test Labels.zip": os.path.join( - "tests", - "data", - "oscd", - "Onera Satellite Change Detection dataset - Test Labels.zip", + 'Onera Satellite Change Detection dataset - Test Labels.zip': os.path.join( + 'tests', + 'data', + 'oscd', + 'Onera Satellite Change Detection dataset - Test Labels.zip', ), } - monkeypatch.setattr(OSCD, "urls", urls) + monkeypatch.setattr(OSCD, 'urls', urls) bands, split = request.param root = str(tmp_path) @@ -72,22 +72,22 @@ def dataset( def test_getitem(self, dataset: OSCD) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image1"], torch.Tensor) - assert x["image1"].ndim == 3 - assert isinstance(x["image2"], torch.Tensor) - assert x["image2"].ndim == 3 - assert isinstance(x["mask"], torch.Tensor) - assert x["mask"].ndim == 2 + assert isinstance(x['image1'], torch.Tensor) + assert x['image1'].ndim == 3 + assert isinstance(x['image2'], torch.Tensor) + assert x['image2'].ndim == 3 + assert isinstance(x['mask'], torch.Tensor) + assert x['mask'].ndim == 2 if dataset.bands == OSCD.rgb_bands: - assert x["image1"].shape[0] == 3 - assert x["image2"].shape[0] == 3 + assert x['image1'].shape[0] == 3 + assert x['image2'].shape[0] == 3 else: - assert x["image1"].shape[0] == 13 - assert x["image2"].shape[0] == 13 + assert x['image1'].shape[0] == 13 + assert x['image2'].shape[0] == 13 def test_len(self, dataset: OSCD) -> None: - if dataset.split == "train": + if dataset.split == 'train': assert len(dataset) == 2 else: assert len(dataset) == 1 @@ -100,24 +100,24 @@ def test_already_extracted(self, dataset: OSCD) -> None: OSCD(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "oscd", "*Onera*.zip") + pathname = os.path.join('tests', 'data', 'oscd', '*Onera*.zip') root = str(tmp_path) for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) OSCD(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): OSCD(str(tmp_path)) def test_plot(self, dataset: OSCD) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() def test_failed_plot(self, dataset: OSCD) -> None: - single_band_dataset = OSCD(root=dataset.root, bands=("B01",)) + single_band_dataset = OSCD(root=dataset.root, bands=('B01',)) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): x = single_band_dataset[0].copy() - single_band_dataset.plot(x, suptitle="Test") + single_band_dataset.plot(x, suptitle='Test') diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py index d0284688cf4..62ff5f913e6 100644 --- a/tests/datasets/test_pastis.py +++ b/tests/datasets/test_pastis.py @@ -24,24 +24,24 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestPASTIS: @pytest.fixture( params=[ - {"folds": (1, 2), "bands": "s2", "mode": "semantic"}, - {"folds": (1, 2), "bands": "s1a", "mode": "semantic"}, - {"folds": (1, 2), "bands": "s1d", 
"mode": "instance"}, + {'folds': (1, 2), 'bands': 's2', 'mode': 'semantic'}, + {'folds': (1, 2), 'bands': 's1a', 'mode': 'semantic'}, + {'folds': (1, 2), 'bands': 's1d', 'mode': 'instance'}, ] ) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> PASTIS: - monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.pastis, 'download_url', download_url) - md5 = "135a29fb8221241dde14f31579c07f45" - monkeypatch.setattr(PASTIS, "md5", md5) - url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip") - monkeypatch.setattr(PASTIS, "url", url) + md5 = '135a29fb8221241dde14f31579c07f45' + monkeypatch.setattr(PASTIS, 'md5', md5) + url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip') + monkeypatch.setattr(PASTIS, 'url', url) root = str(tmp_path) - folds = request.param["folds"] - bands = request.param["bands"] - mode = request.param["mode"] + folds = request.param['folds'] + bands = request.param['bands'] + mode = request.param['mode'] transforms = nn.Identity() return PASTIS( root, folds, bands, mode, transforms, download=True, checksum=True @@ -50,17 +50,17 @@ def dataset( def test_getitem_semantic(self, dataset: PASTIS) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_getitem_instance(self, dataset: PASTIS) -> None: - dataset.mode = "instance" + dataset.mode = 'instance' x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert isinstance(x["boxes"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) def test_len(self, dataset: PASTIS) -> None: assert len(dataset) == 2 @@ -74,19 +74,19 @@ def test_already_extracted(self, dataset: PASTIS) -> None: PASTIS(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip") + url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip') root = str(tmp_path) shutil.copy(url, root) PASTIS(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): PASTIS(str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "PASTIS-R.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'PASTIS-R.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): PASTIS(root=str(tmp_path), checksum=True) def test_invalid_fold(self) -> None: @@ -95,16 +95,16 @@ def test_invalid_fold(self) -> None: def test_invalid_mode(self) -> None: with pytest.raises(AssertionError): - PASTIS(mode="invalid") + PASTIS(mode='invalid') def test_plot(self, dataset: PASTIS) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"].clone() - if dataset.mode == "instance": - x["prediction_labels"] = 
x["label"].clone() + x['prediction'] = x['mask'].clone() + if dataset.mode == 'instance': + x['prediction_labels'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_patternnet.py b/tests/datasets/test_patternnet.py index efab8bd7b31..915d7388bad 100644 --- a/tests/datasets/test_patternnet.py +++ b/tests/datasets/test_patternnet.py @@ -20,13 +20,13 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestPatternNet: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> PatternNet: - monkeypatch.setattr(torchgeo.datasets.patternnet, "download_url", download_url) - md5 = "5649754c78219a2c19074ff93666cc61" - monkeypatch.setattr(PatternNet, "md5", md5) - url = os.path.join("tests", "data", "patternnet", "PatternNet.zip") - monkeypatch.setattr(PatternNet, "url", url) + monkeypatch.setattr(torchgeo.datasets.patternnet, 'download_url', download_url) + md5 = '5649754c78219a2c19074ff93666cc61' + monkeypatch.setattr(PatternNet, 'md5', md5) + url = os.path.join('tests', 'data', 'patternnet', 'PatternNet.zip') + monkeypatch.setattr(PatternNet, 'url', url) root = str(tmp_path) transforms = nn.Identity() return PatternNet(root, transforms, download=True, checksum=True) @@ -34,9 +34,9 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> PatternNet: def test_getitem(self, dataset: PatternNet) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['image'].shape[0] == 3 def test_len(self, dataset: PatternNet) -> None: assert len(dataset) == 2 @@ -52,14 +52,14 @@ def test_already_downloaded_not_extracted( PatternNet(root=str(tmp_path), download=False) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): PatternNet(str(tmp_path)) def test_plot(self, dataset: PatternNet) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["label"].clone() - dataset.plot(sample, suptitle="Prediction") + sample['prediction'] = sample['label'].clone() + dataset.plot(sample, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_potsdam.py b/tests/datasets/test_potsdam.py index b803b15ea95..4529d937690 100644 --- a/tests/datasets/test_potsdam.py +++ b/tests/datasets/test_potsdam.py @@ -16,16 +16,16 @@ class TestPotsdam2D: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Potsdam2D: - md5s = ["e47175da529c5844052c7d483b483a30", "0cb795003a01154a72db7efaabbc76ae"] + md5s = ['e47175da529c5844052c7d483b483a30', '0cb795003a01154a72db7efaabbc76ae'] splits = { - "train": ["top_potsdam_2_10", "top_potsdam_2_11"], - "test": ["top_potsdam_5_15", "top_potsdam_6_15"], + 'train': ['top_potsdam_2_10', 'top_potsdam_2_11'], + 'test': ['top_potsdam_5_15', 'top_potsdam_6_15'], } - monkeypatch.setattr(Potsdam2D, "md5s", md5s) - monkeypatch.setattr(Potsdam2D, "splits", splits) - root = os.path.join("tests", "data", "potsdam") + monkeypatch.setattr(Potsdam2D, 'md5s', md5s) + 
monkeypatch.setattr(Potsdam2D, 'splits', splits) + root = os.path.join('tests', 'data', 'potsdam') split = request.param transforms = nn.Identity() return Potsdam2D(root, split, transforms, checksum=True) @@ -33,42 +33,42 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Potsdam2D: def test_getitem(self, dataset: Potsdam2D) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_len(self, dataset: Potsdam2D) -> None: assert len(dataset) == 2 def test_extract(self, tmp_path: Path) -> None: - root = os.path.join("tests", "data", "potsdam") - for filename in ["4_Ortho_RGBIR.zip", "5_Labels_all.zip"]: + root = os.path.join('tests', 'data', 'potsdam') + for filename in ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']: shutil.copyfile( os.path.join(root, filename), os.path.join(str(tmp_path), filename) ) Potsdam2D(root=str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "4_Ortho_RGBIR.zip"), "w") as f: - f.write("bad") - with open(os.path.join(tmp_path, "5_Labels_all.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, '4_Ortho_RGBIR.zip'), 'w') as f: + f.write('bad') + with open(os.path.join(tmp_path, '5_Labels_all.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): Potsdam2D(root=str(tmp_path), checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - Potsdam2D(split="foo") + Potsdam2D(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Potsdam2D(str(tmp_path)) def test_plot(self, dataset: Potsdam2D) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"].clone() + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_prisma.py b/tests/datasets/test_prisma.py index d7e884d185f..89ab52c7275 100644 --- a/tests/datasets/test_prisma.py +++ b/tests/datasets/test_prisma.py @@ -22,7 +22,7 @@ class TestPRISMA: @pytest.fixture def dataset(self) -> PRISMA: - paths = os.path.join("tests", "data", "prisma") + paths = os.path.join('tests', 'data', 'prisma') transforms = nn.Identity() return PRISMA(paths, transforms=transforms) @@ -32,8 +32,8 @@ def test_len(self, dataset: PRISMA) -> None: def test_getitem(self, dataset: PRISMA) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) def test_and(self, dataset: PRISMA) -> None: ds = dataset & dataset @@ -45,16 +45,16 @@ def test_or(self, dataset: PRISMA) -> None: def test_plot(self, dataset: PRISMA) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): PRISMA(str(tmp_path)) def 
test_invalid_query(self, dataset: PRISMA) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py index b2f3a16eaef..092e7cf2f1f 100644 --- a/tests/datasets/test_reforestree.py +++ b/tests/datasets/test_reforestree.py @@ -22,15 +22,15 @@ def download_url(url: str, root: str, *args: str) -> None: class TestReforesTree: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - data_dir = os.path.join("tests", "data", "reforestree") + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'reforestree') - url = os.path.join(data_dir, "reforesTree.zip") + url = os.path.join(data_dir, 'reforesTree.zip') - md5 = "387e04dbbb0aa803f72bd6d774409648" + md5 = '387e04dbbb0aa803f72bd6d774409648' - monkeypatch.setattr(ReforesTree, "url", url) - monkeypatch.setattr(ReforesTree, "md5", md5) + monkeypatch.setattr(ReforesTree, 'url', url) + monkeypatch.setattr(ReforesTree, 'md5', md5) root = str(tmp_path) transforms = nn.Identity() return ReforesTree( @@ -43,39 +43,39 @@ def test_already_downloaded(self, dataset: ReforesTree) -> None: def test_getitem(self, dataset: ReforesTree) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert isinstance(x["boxes"], torch.Tensor) - assert isinstance(x["agb"], torch.Tensor) - assert x["image"].shape[0] == 3 - assert x["image"].ndim == 3 - assert len(x["boxes"]) == 2 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['agb'], torch.Tensor) + assert x['image'].shape[0] == 3 + assert x['image'].ndim == 3 + assert len(x['boxes']) == 2 def test_len(self, dataset: ReforesTree) -> None: assert len(dataset) == 2 def test_not_extracted(self, tmp_path: Path) -> None: - url = os.path.join("tests", "data", "reforestree", "reforesTree.zip") + url = os.path.join('tests', 'data', 'reforestree', 'reforesTree.zip') shutil.copy(url, tmp_path) ReforesTree(root=str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: - with open(os.path.join(tmp_path, "reforesTree.zip"), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, 'reforesTree.zip'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): ReforesTree(root=str(tmp_path), checksum=True) def test_not_found(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ReforesTree(str(tmp_path)) def test_plot(self, dataset: ReforesTree) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: ReforesTree) -> None: x = dataset[0].copy() - x["prediction_boxes"] = x["boxes"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction_boxes'] = x['boxes'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_resisc45.py 
b/tests/datasets/test_resisc45.py index 099885deac8..d84a9709602 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -21,35 +21,35 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestRESISC45: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> RESISC45: - pytest.importorskip("rarfile", minversion="4") + pytest.importorskip('rarfile', minversion='4') - monkeypatch.setattr(torchgeo.datasets.resisc45, "download_url", download_url) - md5 = "5895dea3757ba88707d52f5521c444d3" - monkeypatch.setattr(RESISC45, "md5", md5) - url = os.path.join("tests", "data", "resisc45", "NWPU-RESISC45.rar") - monkeypatch.setattr(RESISC45, "url", url) + monkeypatch.setattr(torchgeo.datasets.resisc45, 'download_url', download_url) + md5 = '5895dea3757ba88707d52f5521c444d3' + monkeypatch.setattr(RESISC45, 'md5', md5) + url = os.path.join('tests', 'data', 'resisc45', 'NWPU-RESISC45.rar') + monkeypatch.setattr(RESISC45, 'url', url) monkeypatch.setattr( RESISC45, - "split_urls", + 'split_urls', { - "train": os.path.join( - "tests", "data", "resisc45", "resisc45-train.txt" + 'train': os.path.join( + 'tests', 'data', 'resisc45', 'resisc45-train.txt' ), - "val": os.path.join("tests", "data", "resisc45", "resisc45-val.txt"), - "test": os.path.join("tests", "data", "resisc45", "resisc45-test.txt"), + 'val': os.path.join('tests', 'data', 'resisc45', 'resisc45-val.txt'), + 'test': os.path.join('tests', 'data', 'resisc45', 'resisc45-test.txt'), }, ) monkeypatch.setattr( RESISC45, - "split_md5s", + 'split_md5s', { - "train": "7760b1960c9a3ff46fb985810815e14d", - "val": "7760b1960c9a3ff46fb985810815e14d", - "test": "7760b1960c9a3ff46fb985810815e14d", + 'train': '7760b1960c9a3ff46fb985810815e14d', + 'val': '7760b1960c9a3ff46fb985810815e14d', + 'test': '7760b1960c9a3ff46fb985810815e14d', }, ) root = str(tmp_path) @@ -60,9 +60,9 @@ def dataset( def test_getitem(self, dataset: RESISC45) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert x['image'].shape[0] == 3 def test_len(self, dataset: RESISC45) -> None: assert len(dataset) == 9 @@ -78,15 +78,15 @@ def test_already_downloaded_not_extracted( RESISC45(root=str(tmp_path), download=False) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): RESISC45(str(tmp_path)) def test_plot(self, dataset: RESISC45) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py index 494180b7e4e..6f83b12a93d 100644 --- a/tests/datasets/test_rwanda_field_boundary.py +++ b/tests/datasets/test_rwanda_field_boundary.py @@ -23,7 +23,7 @@ class Collection: def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join("tests", "data", "rwanda_field_boundary", "*.tar.gz") + glob_path = os.path.join('tests', 
'data', 'rwanda_field_boundary', '*.tar.gz') for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -33,22 +33,22 @@ def fetch(dataset_id: str, **kwargs: str) -> Collection: class TestRwandaFieldBoundary: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> RwandaFieldBoundary: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) monkeypatch.setattr( - RwandaFieldBoundary, "number_of_patches_per_split", {"train": 5, "test": 5} + RwandaFieldBoundary, 'number_of_patches_per_split', {'train': 5, 'test': 5} ) monkeypatch.setattr( RwandaFieldBoundary, - "md5s", + 'md5s', { - "train_images": "af9395e2e49deefebb35fa65fa378ba3", - "test_images": "d104bb82323a39e7c3b3b7dd0156f550", - "train_labels": "6cceaf16a141cf73179253a783e7d51b", + 'train_images': 'af9395e2e49deefebb35fa65fa378ba3', + 'test_images': 'd104bb82323a39e7c3b3b7dd0156f550', + 'train_labels': '6cceaf16a141cf73179253a783e7d51b', }, ) @@ -56,17 +56,17 @@ def dataset( split = request.param transforms = nn.Identity() return RwandaFieldBoundary( - root, split, transforms=transforms, api_key="", download=True, checksum=True + root, split, transforms=transforms, api_key='', download=True, checksum=True ) def test_getitem(self, dataset: RwandaFieldBoundary) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - if dataset.split == "train": - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + if dataset.split == 'train': + assert isinstance(x['mask'], torch.Tensor) else: - assert "mask" not in x + assert 'mask' not in x def test_len(self, dataset: RwandaFieldBoundary) -> None: assert len(dataset) == 5 @@ -79,11 +79,11 @@ def test_add(self, dataset: RwandaFieldBoundary) -> None: def test_needs_extraction(self, tmp_path: Path) -> None: root = str(tmp_path) for fn in [ - "nasa_rwanda_field_boundary_competition_source_train.tar.gz", - "nasa_rwanda_field_boundary_competition_source_test.tar.gz", - "nasa_rwanda_field_boundary_competition_labels_train.tar.gz", + 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', + 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', + 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', ]: - url = os.path.join("tests", "data", "rwanda_field_boundary", fn) + url = os.path.join('tests', 'data', 'rwanda_field_boundary', fn) shutil.copy(url, root) RwandaFieldBoundary(root, checksum=False) @@ -91,56 +91,56 @@ def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None: RwandaFieldBoundary(root=dataset.root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): RwandaFieldBoundary(str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: for fn in [ - "nasa_rwanda_field_boundary_competition_source_train.tar.gz", - "nasa_rwanda_field_boundary_competition_source_test.tar.gz", - "nasa_rwanda_field_boundary_competition_labels_train.tar.gz", + 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', + 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', + 
'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', ]: - with open(os.path.join(tmp_path, fn), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, fn), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): RwandaFieldBoundary(root=str(tmp_path), checksum=True) def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) monkeypatch.setattr( RwandaFieldBoundary, - "md5s", - {"train_images": "bad", "test_images": "bad", "train_labels": "bad"}, + 'md5s', + {'train_images': 'bad', 'test_images': 'bad', 'train_labels': 'bad'}, ) root = str(tmp_path) - with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): - RwandaFieldBoundary(root, "train", api_key="", download=True, checksum=True) + with pytest.raises(RuntimeError, match='Dataset not found or corrupted.'): + RwandaFieldBoundary(root, 'train', api_key='', download=True, checksum=True) def test_no_api_key(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Must provide an API key to download"): + with pytest.raises(RuntimeError, match='Must provide an API key to download'): RwandaFieldBoundary(str(tmp_path), api_key=None, download=True) def test_invalid_bands(self) -> None: - with pytest.raises(ValueError, match="is an invalid band name."): - RwandaFieldBoundary(bands=("foo", "bar")) + with pytest.raises(ValueError, match='is an invalid band name.'): + RwandaFieldBoundary(bands=('foo', 'bar')) def test_plot(self, dataset: RwandaFieldBoundary) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - if dataset.split == "train": - x["prediction"] = x["mask"].clone() + if dataset.split == 'train': + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() def test_failed_plot(self, dataset: RwandaFieldBoundary) -> None: - single_band_dataset = RwandaFieldBoundary(root=dataset.root, bands=("B01",)) + single_band_dataset = RwandaFieldBoundary(root=dataset.root, bands=('B01',)) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): x = single_band_dataset[0].copy() - single_band_dataset.plot(x, suptitle="Test") + single_band_dataset.plot(x, suptitle='Test') diff --git a/tests/datasets/test_seasonet.py b/tests/datasets/test_seasonet.py index 1c325993f2f..9178dcb1217 100644 --- a/tests/datasets/test_seasonet.py +++ b/tests/datasets/test_seasonet.py @@ -28,9 +28,9 @@ def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> No class TestSeasoNet: @pytest.fixture( params=zip( - ["train", "val", "test"], - [{"Spring"}, {"Summer", "Fall", "Winter", "Snow"}, SeasoNet.all_seasons], - [SeasoNet.all_bands, ["10m_IR", "10m_RGB", "60m"], ["10m_RGB"]], + ['train', 'val', 'test'], + [{'Spring'}, {'Summer', 'Fall', 'Winter', 'Snow'}, SeasoNet.all_seasons], + [SeasoNet.all_bands, ['10m_IR', '10m_RGB', '60m'], ['10m_RGB']], [[1], [2], [1, 2]], [1, 3, 5], ) @@ -38,62 +38,62 @@ class TestSeasoNet: def dataset( self, monkeypatch: MonkeyPatch, 
tmp_path: Path, request: SubRequest ) -> SeasoNet: - monkeypatch.setattr(torchgeo.datasets.seasonet, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.seasonet, 'download_url', download_url) monkeypatch.setitem( - SeasoNet.metadata[0], "md5", "836a0896eba0e3005208f3fd180e429d" + SeasoNet.metadata[0], 'md5', '836a0896eba0e3005208f3fd180e429d' ) monkeypatch.setitem( - SeasoNet.metadata[1], "md5", "405656c8c19d822620bbb9f92e687337" + SeasoNet.metadata[1], 'md5', '405656c8c19d822620bbb9f92e687337' ) monkeypatch.setitem( - SeasoNet.metadata[2], "md5", "dc0dda18de019a9f50a794b8b4060a3b" + SeasoNet.metadata[2], 'md5', 'dc0dda18de019a9f50a794b8b4060a3b' ) monkeypatch.setitem( - SeasoNet.metadata[3], "md5", "a70abca62e78eb1591555809dc81d91d" + SeasoNet.metadata[3], 'md5', 'a70abca62e78eb1591555809dc81d91d' ) monkeypatch.setitem( - SeasoNet.metadata[4], "md5", "67651cc9095207e07ea4db1a71f0ebc2" + SeasoNet.metadata[4], 'md5', '67651cc9095207e07ea4db1a71f0ebc2' ) monkeypatch.setitem( - SeasoNet.metadata[5], "md5", "576324ba1c32a7e9ba858f1c2577cf2a" + SeasoNet.metadata[5], 'md5', '576324ba1c32a7e9ba858f1c2577cf2a' ) monkeypatch.setitem( - SeasoNet.metadata[6], "md5", "48ff6e9e01fdd92379e5712e4f336ea8" + SeasoNet.metadata[6], 'md5', '48ff6e9e01fdd92379e5712e4f336ea8' ) monkeypatch.setitem( SeasoNet.metadata[0], - "url", - os.path.join("tests", "data", "seasonet", "spring.zip"), + 'url', + os.path.join('tests', 'data', 'seasonet', 'spring.zip'), ) monkeypatch.setitem( SeasoNet.metadata[1], - "url", - os.path.join("tests", "data", "seasonet", "summer.zip"), + 'url', + os.path.join('tests', 'data', 'seasonet', 'summer.zip'), ) monkeypatch.setitem( SeasoNet.metadata[2], - "url", - os.path.join("tests", "data", "seasonet", "fall.zip"), + 'url', + os.path.join('tests', 'data', 'seasonet', 'fall.zip'), ) monkeypatch.setitem( SeasoNet.metadata[3], - "url", - os.path.join("tests", "data", "seasonet", "winter.zip"), + 'url', + os.path.join('tests', 'data', 'seasonet', 'winter.zip'), ) monkeypatch.setitem( SeasoNet.metadata[4], - "url", - os.path.join("tests", "data", "seasonet", "snow.zip"), + 'url', + os.path.join('tests', 'data', 'seasonet', 'snow.zip'), ) monkeypatch.setitem( SeasoNet.metadata[5], - "url", - os.path.join("tests", "data", "seasonet", "splits.zip"), + 'url', + os.path.join('tests', 'data', 'seasonet', 'splits.zip'), ) monkeypatch.setitem( SeasoNet.metadata[6], - "url", - os.path.join("tests", "data", "seasonet", "meta.csv"), + 'url', + os.path.join('tests', 'data', 'seasonet', 'meta.csv'), ) root = str(tmp_path) split, seasons, bands, grids, concat_seasons = request.param @@ -113,14 +113,14 @@ def dataset( def test_getitem(self, dataset: SeasoNet) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert x["image"].shape == (dataset.concat_seasons * dataset.channels, 120, 120) - assert x["mask"].shape == (120, 120) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].shape == (dataset.concat_seasons * dataset.channels, 120, 120) + assert x['mask'].shape == (120, 120) def test_len(self, dataset: SeasoNet, request: SubRequest) -> None: - num_seasons = len(request.node.callspec.params["dataset"][1]) - num_grids = len(request.node.callspec.params["dataset"][3]) + num_seasons = len(request.node.callspec.params['dataset'][1]) + num_grids = len(request.node.callspec.params['dataset'][3]) if dataset.concat_seasons == 1: assert 
len(dataset) == num_grids * num_seasons else: @@ -129,8 +129,8 @@ def test_len(self, dataset: SeasoNet, request: SubRequest) -> None: def test_add(self, dataset: SeasoNet, request: SubRequest) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - num_seasons = len(request.node.callspec.params["dataset"][1]) - num_grids = len(request.node.callspec.params["dataset"][3]) + num_seasons = len(request.node.callspec.params['dataset'][1]) + num_grids = len(request.node.callspec.params['dataset'][3]) if dataset.concat_seasons == 1: assert len(ds) == num_grids * num_seasons * 2 else: @@ -140,14 +140,14 @@ def test_already_extracted(self, dataset: SeasoNet) -> None: SeasoNet(root=dataset.root) def test_already_downloaded(self, tmp_path: Path) -> None: - paths = os.path.join("tests", "data", "seasonet", "*.*") + paths = os.path.join('tests', 'data', 'seasonet', '*.*') root = str(tmp_path) for path in glob.iglob(paths): shutil.copy(path, root) SeasoNet(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SeasoNet(str(tmp_path), download=False) def test_out_of_bounds(self, dataset: SeasoNet) -> None: @@ -156,11 +156,11 @@ def test_out_of_bounds(self, dataset: SeasoNet) -> None: def test_invalid_seasons(self) -> None: with pytest.raises(AssertionError): - SeasoNet(seasons=("Salt", "Pepper")) + SeasoNet(seasons=('Salt', 'Pepper')) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - SeasoNet(bands=["30s_TOMARS", "9in_NAILS"]) + SeasoNet(bands=['30s_TOMARS', '9in_NAILS']) def test_invalid_grids(self) -> None: with pytest.raises(AssertionError): @@ -168,29 +168,29 @@ def test_invalid_grids(self) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - SeasoNet(split="banana") + SeasoNet(split='banana') def test_invalid_concat(self) -> None: with pytest.raises(AssertionError): - SeasoNet(seasons={"Spring", "Winter", "Snow"}, concat_seasons=4) + SeasoNet(seasons={'Spring', 'Winter', 'Snow'}, concat_seasons=4) def test_plot(self, dataset: SeasoNet) -> None: x = dataset[0] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() dataset.plot(x, show_legend=False) plt.close() - x["prediction"] = x["mask"].clone() + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() def test_plot_no_rgb(self) -> None: with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - root = os.path.join("tests", "data", "seasonet") - dataset = SeasoNet(root, bands=["10m_IR"]) + root = os.path.join('tests', 'data', 'seasonet') + dataset = SeasoNet(root, bands=['10m_IR']) x = dataset[0] dataset.plot(x) diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py index e61a23aaf76..ed273a8810c 100644 --- a/tests/datasets/test_seco.py +++ b/tests/datasets/test_seco.py @@ -29,7 +29,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestSeasonalContrastS2: @pytest.fixture( params=zip( - ["100k", "1m"], + ['100k', '1m'], [1, 2], [SeasonalContrastS2.rgb_bands, SeasonalContrastS2.all_bands], ) @@ -37,24 +37,24 @@ class TestSeasonalContrastS2: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SeasonalContrastS2: - 
monkeypatch.setattr(torchgeo.datasets.seco, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.seco, 'download_url', download_url) monkeypatch.setitem( - SeasonalContrastS2.metadata["100k"], - "url", - os.path.join("tests", "data", "seco", "seco_100k.zip"), + SeasonalContrastS2.metadata['100k'], + 'url', + os.path.join('tests', 'data', 'seco', 'seco_100k.zip'), ) monkeypatch.setitem( - SeasonalContrastS2.metadata["100k"], - "md5", - "6f527567f066562af2c78093114599f9", + SeasonalContrastS2.metadata['100k'], + 'md5', + '6f527567f066562af2c78093114599f9', ) monkeypatch.setitem( - SeasonalContrastS2.metadata["1m"], - "url", - os.path.join("tests", "data", "seco", "seco_1m.zip"), + SeasonalContrastS2.metadata['1m'], + 'url', + os.path.join('tests', 'data', 'seco', 'seco_1m.zip'), ) monkeypatch.setitem( - SeasonalContrastS2.metadata["1m"], "md5", "3bb3fcf90f5de7d5781ce0cb85fd20af" + SeasonalContrastS2.metadata['1m'], 'md5', '3bb3fcf90f5de7d5781ce0cb85fd20af' ) root = str(tmp_path) version, seasons, bands = request.param @@ -66,11 +66,11 @@ def dataset( def test_getitem(self, dataset: SeasonalContrastS2) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].size(0) == dataset.seasons * len(dataset.bands) + assert isinstance(x['image'], torch.Tensor) + assert x['image'].size(0) == dataset.seasons * len(dataset.bands) def test_len(self, dataset: SeasonalContrastS2) -> None: - if dataset.version == "100k": + if dataset.version == '100k': assert len(dataset) == 10**5 // 5 else: assert len(dataset) == 10**6 // 5 @@ -78,7 +78,7 @@ def test_len(self, dataset: SeasonalContrastS2) -> None: def test_add(self, dataset: SeasonalContrastS2) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - if dataset.version == "100k": + if dataset.version == '100k': assert len(ds) == 2 * 10**5 // 5 else: assert len(ds) == 2 * 10**6 // 5 @@ -87,7 +87,7 @@ def test_already_extracted(self, dataset: SeasonalContrastS2) -> None: SeasonalContrastS2(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "seco", "*.zip") + pathname = os.path.join('tests', 'data', 'seco', '*.zip') root = str(tmp_path) for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) @@ -95,32 +95,32 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_invalid_version(self) -> None: with pytest.raises(AssertionError): - SeasonalContrastS2(version="foo") + SeasonalContrastS2(version='foo') def test_invalid_band(self) -> None: with pytest.raises(AssertionError): - SeasonalContrastS2(bands=["A1steaksauce"]) + SeasonalContrastS2(bands=['A1steaksauce']) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SeasonalContrastS2(str(tmp_path)) def test_plot(self, dataset: SeasonalContrastS2) -> None: x = dataset[0] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() with pytest.raises(ValueError, match="doesn't support plotting"): - x["prediction"] = torch.tensor(1) + x['prediction'] = torch.tensor(1) dataset.plot(x) def test_no_rgb_plot(self) -> None: with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - root = 
os.path.join("tests", "data", "seco") - dataset = SeasonalContrastS2(root, bands=["B1"]) + root = os.path.join('tests', 'data', 'seco') + dataset = SeasonalContrastS2(root, bands=['B1']) x = dataset[0] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py index 26ae1434125..5732eaf18dd 100644 --- a/tests/datasets/test_sen12ms.py +++ b/tests/datasets/test_sen12ms.py @@ -16,27 +16,27 @@ class TestSEN12MS: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> SEN12MS: md5s = [ - "b7d9e183a460979e997b443517a78ded", - "7131dbb098c832fff84c2b8a0c8f1126", - "b1057fea6ced6d648e5b16efeac352ad", - "2da32111fcfb80939aea7b18c2250fa8", - "c688ad6475660dbdbc36f66a1dd07da7", - "2ecd0dce2a21372513955c604b07e24f", - "dbc84c03edf77a68f789a6f7d2ea66a9", - "3e42a7dc4bb1ecd8c588930bf49b5c2b", - "c29053cb8cf5d75e333b1b51d37f62fe", - "5b6880526bc6da488154092741392042", - "d1b51c39b1013f2779fecf1f362f6c28", - "078def1e13ce4e88632d65f5c73a6259", - "02d5128ac1fc2bf8762091b4f319762d", - "02d5128ac1fc2bf8762091b4f319762d", + 'b7d9e183a460979e997b443517a78ded', + '7131dbb098c832fff84c2b8a0c8f1126', + 'b1057fea6ced6d648e5b16efeac352ad', + '2da32111fcfb80939aea7b18c2250fa8', + 'c688ad6475660dbdbc36f66a1dd07da7', + '2ecd0dce2a21372513955c604b07e24f', + 'dbc84c03edf77a68f789a6f7d2ea66a9', + '3e42a7dc4bb1ecd8c588930bf49b5c2b', + 'c29053cb8cf5d75e333b1b51d37f62fe', + '5b6880526bc6da488154092741392042', + 'd1b51c39b1013f2779fecf1f362f6c28', + '078def1e13ce4e88632d65f5c73a6259', + '02d5128ac1fc2bf8762091b4f319762d', + '02d5128ac1fc2bf8762091b4f319762d', ] - monkeypatch.setattr(SEN12MS, "md5s", md5s) - root = os.path.join("tests", "data", "sen12ms") + monkeypatch.setattr(SEN12MS, 'md5s', md5s) + root = os.path.join('tests', 'data', 'sen12ms') split = request.param transforms = nn.Identity() return SEN12MS(root, split, transforms=transforms, checksum=True) @@ -44,9 +44,9 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> SEN12MS: def test_getitem(self, dataset: SEN12MS) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert x["image"].shape[0] == 15 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].shape[0] == 15 def test_len(self, dataset: SEN12MS) -> None: assert len(dataset) == 8 @@ -62,43 +62,43 @@ def test_out_of_bounds(self, dataset: SEN12MS) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - SEN12MS(split="foo") + SEN12MS(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SEN12MS(str(tmp_path), checksum=True) - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SEN12MS(str(tmp_path), checksum=False) def test_check_integrity_light(self) -> None: - root = os.path.join("tests", "data", "sen12ms") + root = os.path.join('tests', 'data', 'sen12ms') ds = SEN12MS(root, checksum=False) assert isinstance(ds, SEN12MS) def test_band_subsets(self) -> None: - root = os.path.join("tests", "data", "sen12ms") + root = os.path.join('tests', 'data', 'sen12ms') for bands in 
SEN12MS.BAND_SETS.values(): ds = SEN12MS(root, bands=bands, checksum=False) - x = ds[0]["image"] + x = ds[0]['image'] assert x.shape[0] == len(bands) def test_invalid_bands(self) -> None: with pytest.raises(ValueError): - SEN12MS(bands=("OK", "BK")) + SEN12MS(bands=('OK', 'BK')) def test_plot(self, dataset: SEN12MS) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["mask"].clone() - dataset.plot(sample, suptitle="prediction") + sample['prediction'] = sample['mask'].clone() + dataset.plot(sample, suptitle='prediction') plt.close() def test_plot_rgb(self, dataset: SEN12MS) -> None: - dataset = SEN12MS(root=dataset.root, bands=("B03",)) + dataset = SEN12MS(root=dataset.root, bands=('B03',)) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - dataset.plot(dataset[0], suptitle="Single Band") + dataset.plot(dataset[0], suptitle='Single Band') diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py index b64b3a0cd04..d7fb5956769 100644 --- a/tests/datasets/test_sentinel.py +++ b/tests/datasets/test_sentinel.py @@ -26,19 +26,19 @@ class TestSentinel1: @pytest.fixture( params=[ # Only horizontal or vertical receive - ["HH"], - ["HV"], - ["VV"], - ["VH"], + ['HH'], + ['HV'], + ['VV'], + ['VH'], # Both horizontal and vertical receive - ["HH", "HV"], - ["HV", "HH"], - ["VV", "VH"], - ["VH", "VV"], + ['HH', 'HV'], + ['HV', 'HH'], + ['VV', 'VH'], + ['VH', 'VV'], ] ) def dataset(self, request: SubRequest) -> Sentinel1: - root = os.path.join("tests", "data", "sentinel1") + root = os.path.join('tests', 'data', 'sentinel1') bands = request.param transforms = nn.Identity() return Sentinel1(root, bands=bands, transforms=transforms) @@ -49,8 +49,8 @@ def test_separate_files(self, dataset: Sentinel1) -> None: def test_getitem(self, dataset: Sentinel1) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) def test_and(self, dataset: Sentinel1) -> None: ds = dataset & dataset @@ -62,29 +62,29 @@ def test_or(self, dataset: Sentinel1) -> None: def test_plot(self, dataset: Sentinel2) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Sentinel1(str(tmp_path)) def test_empty_bands(self) -> None: with pytest.raises(AssertionError, match="'bands' cannot be an empty list"): Sentinel1(bands=[]) - @pytest.mark.parametrize("bands", [["HH", "HH"], ["HH", "HV", "HH"]]) + @pytest.mark.parametrize('bands', [['HH', 'HH'], ['HH', 'HV', 'HH']]) def test_duplicate_bands(self, bands: list[str]) -> None: with pytest.raises(AssertionError, match="'bands' contains duplicate bands"): Sentinel1(bands=bands) - @pytest.mark.parametrize("bands", [["HH_HV"], ["HH", "HV", "HH_HV"]]) + @pytest.mark.parametrize('bands', [['HH_HV'], ['HH', 'HV', 'HH_HV']]) def test_invalid_bands(self, bands: list[str]) -> None: with pytest.raises(AssertionError, match="invalid band 'HH_HV'"): Sentinel1(bands=bands) @pytest.mark.parametrize( - "bands", [["HH", "VV"], ["HH", "VH"], 
["VV", "HV"], ["HH", "HV", "VV", "VH"]] + 'bands', [['HH', 'VV'], ['HH', 'VH'], ['VV', 'HV'], ['HH', 'HV', 'VV', 'VH']] ) def test_dual_transmit(self, bands: list[str]) -> None: with pytest.raises(AssertionError, match="'bands' cannot contain both "): @@ -93,7 +93,7 @@ def test_dual_transmit(self, bands: list[str]) -> None: def test_invalid_query(self, dataset: Sentinel1) -> None: query = BoundingBox(-1, -1, -1, -1, -1, -1) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] @@ -101,9 +101,9 @@ def test_invalid_query(self, dataset: Sentinel1) -> None: class TestSentinel2: @pytest.fixture def dataset(self) -> Sentinel2: - root = os.path.join("tests", "data", "sentinel2") + root = os.path.join('tests', 'data', 'sentinel2') res = 10 - bands = ["B02", "B03", "B04", "B08"] + bands = ['B02', 'B03', 'B04', 'B08'] transforms = nn.Identity() return Sentinel2(root, res=res, bands=bands, transforms=transforms) @@ -113,8 +113,8 @@ def test_separate_files(self, dataset: Sentinel2) -> None: def test_getitem(self, dataset: Sentinel2) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) def test_and(self, dataset: Sentinel2) -> None: ds = dataset & dataset @@ -125,26 +125,26 @@ def test_or(self, dataset: Sentinel2) -> None: assert isinstance(ds, UnionDataset) def test_no_data(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Sentinel2(str(tmp_path)) def test_plot(self, dataset: Sentinel2) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_wrong_bands(self, dataset: Sentinel2) -> None: - bands = ["B02"] + bands = ['B02'] ds = Sentinel2(dataset.paths, res=dataset.res, bands=bands) x = dataset[dataset.bounds] with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): ds.plot(x) def test_invalid_query(self, dataset: Sentinel2) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py index 392c3255eda..01907d22cea 100644 --- a/tests/datasets/test_skippd.py +++ b/tests/datasets/test_skippd.py @@ -18,7 +18,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import SKIPPD, DatasetNotFoundError -pytest.importorskip("h5py", minversion="3") +pytest.importorskip('h5py', minversion='3') def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -26,22 +26,22 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestSKIPPD: - @pytest.fixture(params=product(["nowcast", "forecast"], ["trainval", "test"])) + @pytest.fixture(params=product(['nowcast', 'forecast'], ['trainval', 'test'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SKIPPD: task, split = request.param - monkeypatch.setattr(torchgeo.datasets.skippd, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.skippd, 
'download_url', download_url) md5 = { - "nowcast": "6f5e54906927278b189f9281a2f54f39", - "forecast": "f3b5d7d5c28ba238144fa1e726c46969", + 'nowcast': '6f5e54906927278b189f9281a2f54f39', + 'forecast': 'f3b5d7d5c28ba238144fa1e726c46969', } - monkeypatch.setattr(SKIPPD, "md5", md5) - url = os.path.join("tests", "data", "skippd", "{}") - monkeypatch.setattr(SKIPPD, "url", url) - monkeypatch.setattr(plt, "show", lambda *args: None) + monkeypatch.setattr(SKIPPD, 'md5', md5) + url = os.path.join('tests', 'data', 'skippd', '{}') + monkeypatch.setattr(SKIPPD, 'url', url) + monkeypatch.setattr(plt, 'show', lambda *args: None) root = str(tmp_path) transforms = nn.Identity() return SKIPPD( @@ -58,64 +58,64 @@ def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == "h5py": + if name == 'h5py': raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) def test_mock_missing_module( self, dataset: SKIPPD, tmp_path: Path, mock_missing_module: None ) -> None: with pytest.raises( ImportError, - match="h5py is not installed and is required to use this dataset", + match='h5py is not installed and is required to use this dataset', ): SKIPPD(dataset.root, download=True, checksum=True) def test_already_extracted(self, dataset: SKIPPD) -> None: SKIPPD(root=dataset.root, download=True) - @pytest.mark.parametrize("task", ["nowcast", "forecast"]) + @pytest.mark.parametrize('task', ['nowcast', 'forecast']) def test_already_downloaded(self, tmp_path: Path, task: str) -> None: pathname = os.path.join( - "tests", "data", "skippd", f"2017_2019_images_pv_processed_{task}.zip" + 'tests', 'data', 'skippd', f'2017_2019_images_pv_processed_{task}.zip' ) root = str(tmp_path) shutil.copy(pathname, root) SKIPPD(root=root, task=task) - @pytest.mark.parametrize("index", [0, 1, 2]) + @pytest.mark.parametrize('index', [0, 1, 2]) def test_getitem(self, dataset: SKIPPD, index: int) -> None: x = dataset[index] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert isinstance(x["date"], str) - if dataset.task == "nowcast": - assert x["image"].shape == (3, 64, 64) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert isinstance(x['date'], str) + if dataset.task == 'nowcast': + assert x['image'].shape == (3, 64, 64) else: - assert x["image"].shape == (48, 64, 64) + assert x['image'].shape == (48, 64, 64) def test_len(self, dataset: SKIPPD) -> None: assert len(dataset) == 3 def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - SKIPPD(split="foo") + SKIPPD(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SKIPPD(str(tmp_path)) def test_plot(self, dataset: SKIPPD) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - if dataset.task == "nowcast": - sample["prediction"] = sample["label"] + if dataset.task == 'nowcast': + sample['prediction'] = sample['label'] else: - sample["prediction"] = sample["label"][-1] + sample['prediction'] = sample['label'][-1] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_so2sat.py 
b/tests/datasets/test_so2sat.py index c947dd74dd3..6daa4057bed 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -15,22 +15,22 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, So2Sat -pytest.importorskip("h5py", minversion="3") +pytest.importorskip('h5py', minversion='3') class TestSo2Sat: - @pytest.fixture(params=["train", "validation", "test"]) + @pytest.fixture(params=['train', 'validation', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> So2Sat: md5s_by_version = { - "2": { - "train": "56e6fa0edb25b065124a3113372f76e5", - "validation": "940c95a737bd2fcdcc46c9a52b31424d", - "test": "e97a6746aadc731a1854097f32ab1755", + '2': { + 'train': '56e6fa0edb25b065124a3113372f76e5', + 'validation': '940c95a737bd2fcdcc46c9a52b31424d', + 'test': 'e97a6746aadc731a1854097f32ab1755', } } - monkeypatch.setattr(So2Sat, "md5s_by_version", md5s_by_version) - root = os.path.join("tests", "data", "so2sat") + monkeypatch.setattr(So2Sat, 'md5s_by_version', md5s_by_version) + root = os.path.join('tests', 'data', 'so2sat') split = request.param transforms = nn.Identity() return So2Sat(root=root, split=split, transforms=transforms, checksum=True) @@ -40,17 +40,17 @@ def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == "h5py": + if name == 'h5py': raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) def test_getitem(self, dataset: So2Sat) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) def test_len(self, dataset: So2Sat) -> None: assert len(dataset) == 2 @@ -63,38 +63,38 @@ def test_out_of_bounds(self, dataset: So2Sat) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - So2Sat(split="foo") + So2Sat(split='foo') def test_invalid_bands(self) -> None: with pytest.raises(ValueError): - So2Sat(bands=("OK", "BK")) + So2Sat(bands=('OK', 'BK')) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): So2Sat(str(tmp_path)) def test_plot(self, dataset: So2Sat) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() def test_plot_rgb(self, dataset: So2Sat) -> None: - dataset = So2Sat(root=dataset.root, bands=("S2_B03",)) + dataset = So2Sat(root=dataset.root, bands=('S2_B03',)) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - dataset.plot(dataset[0], suptitle="Single Band") + dataset.plot(dataset[0], suptitle='Single Band') def test_mock_missing_module( self, dataset: So2Sat, mock_missing_module: None ) -> None: with pytest.raises( ImportError, - match="h5py is not installed and is required to use this dataset", + match='h5py is not installed and is required to use this dataset', ): 
So2Sat(dataset.root) diff --git a/tests/datasets/test_south_africa_crop_type.py b/tests/datasets/test_south_africa_crop_type.py index 7422395f0b5..a6413cb9bad 100644 --- a/tests/datasets/test_south_africa_crop_type.py +++ b/tests/datasets/test_south_africa_crop_type.py @@ -23,16 +23,16 @@ class TestSouthAfricaCropType: @pytest.fixture def dataset(self) -> SouthAfricaCropType: - path = os.path.join("tests", "data", "south_africa_crop_type") + path = os.path.join('tests', 'data', 'south_africa_crop_type') transforms = nn.Identity() return SouthAfricaCropType(paths=path, transforms=transforms) def test_getitem(self, dataset: SouthAfricaCropType) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: SouthAfricaCropType) -> None: ds = dataset & dataset @@ -46,32 +46,32 @@ def test_already_downloaded(self, dataset: SouthAfricaCropType) -> None: SouthAfricaCropType(paths=dataset.paths) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SouthAfricaCropType(str(tmp_path)) def test_plot(self, dataset: SouthAfricaCropType) -> None: x = dataset[dataset.bounds] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: SouthAfricaCropType) -> None: x = dataset[dataset.bounds] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_invalid_query(self, dataset: SouthAfricaCropType) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] def test_rgb_bands_absent_plot(self, dataset: SouthAfricaCropType) -> None: with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - ds = SouthAfricaCropType(dataset.paths, bands=["B01", "B02", "B05"]) + ds = SouthAfricaCropType(dataset.paths, bands=['B01', 'B02', 'B05']) x = ds[ds.bounds] - ds.plot(x, suptitle="Test") + ds.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index 11dcc2b5ff9..06fb7018d53 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -29,18 +29,18 @@ class TestSouthAmericaSoybean: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean: monkeypatch.setattr( - torchgeo.datasets.south_america_soybean, "download_url", download_url + torchgeo.datasets.south_america_soybean, 'download_url', download_url ) transforms = nn.Identity() url = os.path.join( - "tests", - "data", - "south_america_soybean", - "SouthAmericaSoybean", - "South_America_Soybean_{}.tif", + 'tests', + 'data', + 'south_america_soybean', + 'SouthAmericaSoybean', + 'South_America_Soybean_{}.tif', ) - monkeypatch.setattr(SouthAmericaSoybean, "url", url) + monkeypatch.setattr(SouthAmericaSoybean, 'url', url) root = 
str(tmp_path) return SouthAmericaSoybean( paths=root, @@ -53,8 +53,8 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybe def test_getitem(self, dataset: SouthAmericaSoybean) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) - assert isinstance(x["crs"], CRS) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['crs'], CRS) + assert isinstance(x['mask'], torch.Tensor) def test_and(self, dataset: SouthAmericaSoybean) -> None: ds = dataset & dataset @@ -69,11 +69,11 @@ def test_already_extracted(self, dataset: SouthAmericaSoybean) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( - "tests", - "data", - "south_america_soybean", - "SouthAmericaSoybean", - "South_America_Soybean_2002.tif", + 'tests', + 'data', + 'south_america_soybean', + 'SouthAmericaSoybean', + 'South_America_Soybean_2002.tif', ) root = str(tmp_path) shutil.copy(pathname, root) @@ -82,23 +82,23 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_plot(self, dataset: SouthAmericaSoybean) -> None: query = dataset.bounds x = dataset[query] - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() def test_plot_prediction(self, dataset: SouthAmericaSoybean) -> None: query = dataset.bounds x = dataset[query] - x["prediction"] = x["mask"].clone() - dataset.plot(x, suptitle="Prediction") + x['prediction'] = x['mask'].clone() + dataset.plot(x, suptitle='Prediction') plt.close() def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SouthAmericaSoybean(str(tmp_path)) def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( - IndexError, match="query: .* not found in index with bounds:" + IndexError, match='query: .* not found in index with bounds:' ): dataset[query] diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index 046b83cfba1..2676af497fd 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -24,8 +24,8 @@ SpaceNet7, ) -TEST_DATA_DIR = "tests/data/spacenet" -radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") +TEST_DATA_DIR = 'tests/data/spacenet' +radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') class Collection: @@ -33,7 +33,7 @@ def __init__(self, collection_id: str) -> None: self.collection_id = collection_id def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join(TEST_DATA_DIR, "*.tar.gz") + glob_path = os.path.join(TEST_DATA_DIR, '*.tar.gz') for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -43,7 +43,7 @@ def __init__(self, dataset_id: str) -> None: self.dataset_id = dataset_id def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join(TEST_DATA_DIR, "spacenet*") + glob_path = os.path.join(TEST_DATA_DIR, 'spacenet*') for directory in glob.iglob(glob_path): dataset_name = os.path.basename(directory) output_dir = os.path.join(output_dir, dataset_name) @@ -59,31 +59,31 @@ def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset: class TestSpaceNet1: - @pytest.fixture(params=["rgb", "8band"]) + @pytest.fixture(params=['rgb', '8band']) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet1: - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", 
fetch_collection) - test_md5 = {"sn1_AOI_1_RIO": "127a523561987110f008e8c9815ce807"} + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) + test_md5 = {'sn1_AOI_1_RIO': '127a523561987110f008e8c9815ce807'} # Refer https://github.com/python/mypy/issues/1032 - monkeypatch.setattr(SpaceNet1, "collection_md5_dict", test_md5) + monkeypatch.setattr(SpaceNet1, 'collection_md5_dict', test_md5) root = str(tmp_path) transforms = nn.Identity() return SpaceNet1( - root, image=request.param, transforms=transforms, download=True, api_key="" + root, image=request.param, transforms=transforms, download=True, api_key='' ) def test_getitem(self, dataset: SpaceNet1) -> None: x = dataset[0] dataset[1] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "rgb": - assert x["image"].shape[0] == 3 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + if dataset.image == 'rgb': + assert x['image'].shape[0] == 3 else: - assert x["image"].shape[0] == 8 + assert x['image'].shape[0] == 8 def test_len(self, dataset: SpaceNet1) -> None: assert len(dataset) == 3 @@ -92,54 +92,54 @@ def test_already_downloaded(self, dataset: SpaceNet1) -> None: SpaceNet1(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SpaceNet1(str(tmp_path)) def test_plot(self, dataset: SpaceNet1) -> None: x = dataset[0].copy() - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() class TestSpaceNet2: - @pytest.fixture(params=["PAN", "MS", "PS-MS", "PS-RGB"]) + @pytest.fixture(params=['PAN', 'MS', 'PS-MS', 'PS-RGB']) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet2: - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) test_md5 = { - "sn2_AOI_2_Vegas": "131048686ba21a45853c05f227f40b7f", - "sn2_AOI_3_Paris": "62242fd198ee32b59f0178cf656e1513", - "sn2_AOI_4_Shanghai": "563b0817ecedd8ff3b3e4cb2991bf3fb", - "sn2_AOI_5_Khartoum": "e4185a2e9a12cf7b3d0cd1db6b3e0f06", + 'sn2_AOI_2_Vegas': '131048686ba21a45853c05f227f40b7f', + 'sn2_AOI_3_Paris': '62242fd198ee32b59f0178cf656e1513', + 'sn2_AOI_4_Shanghai': '563b0817ecedd8ff3b3e4cb2991bf3fb', + 'sn2_AOI_5_Khartoum': 'e4185a2e9a12cf7b3d0cd1db6b3e0f06', } - monkeypatch.setattr(SpaceNet2, "collection_md5_dict", test_md5) + monkeypatch.setattr(SpaceNet2, 'collection_md5_dict', test_md5) root = str(tmp_path) transforms = nn.Identity() return SpaceNet2( root, image=request.param, - collections=["sn2_AOI_2_Vegas", "sn2_AOI_5_Khartoum"], + collections=['sn2_AOI_2_Vegas', 'sn2_AOI_5_Khartoum'], transforms=transforms, download=True, - api_key="", + api_key='', ) def test_getitem(self, dataset: SpaceNet2) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "PS-RGB": - assert x["image"].shape[0] == 3 - elif dataset.image in ["MS", "PS-MS"]: - assert x["image"].shape[0] == 8 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + if dataset.image == 'PS-RGB': + assert 
x['image'].shape[0] == 3 + elif dataset.image in ['MS', 'PS-MS']: + assert x['image'].shape[0] == 8 else: - assert x["image"].shape[0] == 1 + assert x['image'].shape[0] == 1 def test_len(self, dataset: SpaceNet2) -> None: assert len(dataset) == 4 @@ -148,45 +148,45 @@ def test_already_downloaded(self, dataset: SpaceNet2) -> None: SpaceNet2(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SpaceNet2(str(tmp_path)) def test_collection_checksum(self, dataset: SpaceNet2) -> None: - dataset.collection_md5_dict["sn2_AOI_2_Vegas"] = "randommd5hash123" - with pytest.raises(RuntimeError, match="Collection sn2_AOI_2_Vegas corrupted"): + dataset.collection_md5_dict['sn2_AOI_2_Vegas'] = 'randommd5hash123' + with pytest.raises(RuntimeError, match='Collection sn2_AOI_2_Vegas corrupted'): SpaceNet2(root=dataset.root, download=True, checksum=True) def test_plot(self, dataset: SpaceNet2) -> None: x = dataset[0].copy() - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() class TestSpaceNet3: - @pytest.fixture(params=zip(["PAN", "MS"], [False, True])) + @pytest.fixture(params=zip(['PAN', 'MS'], [False, True])) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet3: - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) test_md5 = { - "sn3_AOI_3_Paris": "93452c68da11dd6b57dc83dba43c2c9d", - "sn3_AOI_5_Khartoum": "7c9d96810198bf101cbaf54f7a5e8b3b", + 'sn3_AOI_3_Paris': '93452c68da11dd6b57dc83dba43c2c9d', + 'sn3_AOI_5_Khartoum': '7c9d96810198bf101cbaf54f7a5e8b3b', } - monkeypatch.setattr(SpaceNet3, "collection_md5_dict", test_md5) + monkeypatch.setattr(SpaceNet3, 'collection_md5_dict', test_md5) root = str(tmp_path) transforms = nn.Identity() return SpaceNet3( root, image=request.param[0], speed_mask=request.param[1], - collections=["sn3_AOI_3_Paris", "sn3_AOI_5_Khartoum"], + collections=['sn3_AOI_3_Paris', 'sn3_AOI_5_Khartoum'], transforms=transforms, download=True, - api_key="", + api_key='', ) def test_getitem(self, dataset: SpaceNet3) -> None: @@ -194,12 +194,12 @@ def test_getitem(self, dataset: SpaceNet3) -> None: samples = [dataset[i] for i in range(len(dataset))] x = samples[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "MS": - assert x["image"].shape[0] == 8 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + if dataset.image == 'MS': + assert x['image'].shape[0] == 8 else: - assert x["image"].shape[0] == 1 + assert x['image'].shape[0] == 1 def test_len(self, dataset: SpaceNet3) -> None: assert len(dataset) == 4 @@ -208,38 +208,38 @@ def test_already_downloaded(self, dataset: SpaceNet3) -> None: SpaceNet3(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SpaceNet3(str(tmp_path)) def test_collection_checksum(self, dataset: SpaceNet3) -> None: - dataset.collection_md5_dict["sn3_AOI_5_Khartoum"] = "randommd5hash123" + 
dataset.collection_md5_dict['sn3_AOI_5_Khartoum'] = 'randommd5hash123' with pytest.raises( - RuntimeError, match="Collection sn3_AOI_5_Khartoum corrupted" + RuntimeError, match='Collection sn3_AOI_5_Khartoum corrupted' ): SpaceNet3(root=dataset.root, download=True, checksum=True) def test_plot(self, dataset: SpaceNet3) -> None: x = dataset[0].copy() - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - dataset.plot({"image": x["image"]}) + dataset.plot({'image': x['image']}) plt.close() class TestSpaceNet4: - @pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"]) + @pytest.fixture(params=['PAN', 'MS', 'PS-RGBNIR']) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet4: - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) - test_md5 = {"sn4_AOI_6_Atlanta": "097a76a2319b7ba34dac1722862fc93b"} + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) + test_md5 = {'sn4_AOI_6_Atlanta': '097a76a2319b7ba34dac1722862fc93b'} - test_angles = ["nadir", "off-nadir", "very-off-nadir"] + test_angles = ['nadir', 'off-nadir', 'very-off-nadir'] - monkeypatch.setattr(SpaceNet4, "collection_md5_dict", test_md5) + monkeypatch.setattr(SpaceNet4, 'collection_md5_dict', test_md5) root = str(tmp_path) transforms = nn.Identity() return SpaceNet4( @@ -248,7 +248,7 @@ def dataset( angles=test_angles, transforms=transforms, download=True, - api_key="", + api_key='', ) def test_getitem(self, dataset: SpaceNet4) -> None: @@ -256,14 +256,14 @@ def test_getitem(self, dataset: SpaceNet4) -> None: # ensure coverage x = dataset[2] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "PS-RGBNIR": - assert x["image"].shape[0] == 4 - elif dataset.image == "MS": - assert x["image"].shape[0] == 8 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + if dataset.image == 'PS-RGBNIR': + assert x['image'].shape[0] == 4 + elif dataset.image == 'MS': + assert x['image'].shape[0] == 8 else: - assert x["image"].shape[0] == 1 + assert x['image'].shape[0] == 1 def test_len(self, dataset: SpaceNet4) -> None: assert len(dataset) == 4 @@ -272,47 +272,47 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None: SpaceNet4(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SpaceNet4(str(tmp_path)) def test_collection_checksum(self, dataset: SpaceNet4) -> None: - dataset.collection_md5_dict["sn4_AOI_6_Atlanta"] = "randommd5hash123" + dataset.collection_md5_dict['sn4_AOI_6_Atlanta'] = 'randommd5hash123' with pytest.raises( - RuntimeError, match="Collection sn4_AOI_6_Atlanta corrupted" + RuntimeError, match='Collection sn4_AOI_6_Atlanta corrupted' ): SpaceNet4(root=dataset.root, download=True, checksum=True) def test_plot(self, dataset: SpaceNet4) -> None: x = dataset[0].copy() - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() class TestSpaceNet5: - @pytest.fixture(params=zip(["PAN", "MS"], [False, True])) + @pytest.fixture(params=zip(['PAN', 'MS'], [False, True])) def dataset( self, request: 
SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet5: - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) test_md5 = { - "sn5_AOI_7_Moscow": "5c511dd31eea739cc1f81ef5962f3d56", - "sn5_AOI_8_Mumbai": "e00452b87bbe87feaef65f373be3978e", + 'sn5_AOI_7_Moscow': '5c511dd31eea739cc1f81ef5962f3d56', + 'sn5_AOI_8_Mumbai': 'e00452b87bbe87feaef65f373be3978e', } - monkeypatch.setattr(SpaceNet5, "collection_md5_dict", test_md5) + monkeypatch.setattr(SpaceNet5, 'collection_md5_dict', test_md5) root = str(tmp_path) transforms = nn.Identity() return SpaceNet5( root, image=request.param[0], speed_mask=request.param[1], - collections=["sn5_AOI_7_Moscow", "sn5_AOI_8_Mumbai"], + collections=['sn5_AOI_7_Moscow', 'sn5_AOI_8_Mumbai'], transforms=transforms, download=True, - api_key="", + api_key='', ) def test_getitem(self, dataset: SpaceNet5) -> None: @@ -320,12 +320,12 @@ def test_getitem(self, dataset: SpaceNet5) -> None: samples = [dataset[i] for i in range(len(dataset))] x = samples[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "MS": - assert x["image"].shape[0] == 8 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + if dataset.image == 'MS': + assert x['image'].shape[0] == 8 else: - assert x["image"].shape[0] == 1 + assert x['image'].shape[0] == 1 def test_len(self, dataset: SpaceNet5) -> None: assert len(dataset) == 5 @@ -334,48 +334,48 @@ def test_already_downloaded(self, dataset: SpaceNet5) -> None: SpaceNet5(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SpaceNet5(str(tmp_path)) def test_collection_checksum(self, dataset: SpaceNet5) -> None: - dataset.collection_md5_dict["sn5_AOI_8_Mumbai"] = "randommd5hash123" - with pytest.raises(RuntimeError, match="Collection sn5_AOI_8_Mumbai corrupted"): + dataset.collection_md5_dict['sn5_AOI_8_Mumbai'] = 'randommd5hash123' + with pytest.raises(RuntimeError, match='Collection sn5_AOI_8_Mumbai corrupted'): SpaceNet5(root=dataset.root, download=True, checksum=True) def test_plot(self, dataset: SpaceNet5) -> None: x = dataset[0].copy() - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - dataset.plot({"image": x["image"]}) + dataset.plot({'image': x['image']}) plt.close() class TestSpaceNet6: - @pytest.fixture(params=["PAN", "RGBNIR", "PS-RGB", "PS-RGBNIR", "SAR-Intensity"]) + @pytest.fixture(params=['PAN', 'RGBNIR', 'PS-RGB', 'PS-RGBNIR', 'SAR-Intensity']) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet6: - monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch_dataset) + monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset) root = str(tmp_path) transforms = nn.Identity() return SpaceNet6( - root, image=request.param, transforms=transforms, download=True, api_key="" + root, image=request.param, transforms=transforms, download=True, api_key='' ) def test_getitem(self, dataset: SpaceNet6) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - if 
dataset.image == "PS-RGB": - assert x["image"].shape[0] == 3 - elif dataset.image in ["RGBNIR", "PS-RGBNIR"]: - assert x["image"].shape[0] == 4 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + if dataset.image == 'PS-RGB': + assert x['image'].shape[0] == 3 + elif dataset.image in ['RGBNIR', 'PS-RGBNIR']: + assert x['image'].shape[0] == 4 else: - assert x["image"].shape[0] == 1 + assert x['image'].shape[0] == 1 def test_len(self, dataset: SpaceNet6) -> None: assert len(dataset) == 2 @@ -385,41 +385,41 @@ def test_already_downloaded(self, dataset: SpaceNet6) -> None: def test_plot(self, dataset: SpaceNet6) -> None: x = dataset[0].copy() - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() class TestSpaceNet7: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet7: - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) test_md5 = { - "sn7_train_source": "197bfa8842a40b09b6837b824a6370e0", - "sn7_train_labels": "625ad8a989a5105bc766a53e53df4d0e", - "sn7_test_source": "461f59eb21bb4f416c867f5037dfceeb", + 'sn7_train_source': '197bfa8842a40b09b6837b824a6370e0', + 'sn7_train_labels': '625ad8a989a5105bc766a53e53df4d0e', + 'sn7_test_source': '461f59eb21bb4f416c867f5037dfceeb', } - monkeypatch.setattr(SpaceNet7, "collection_md5_dict", test_md5) + monkeypatch.setattr(SpaceNet7, 'collection_md5_dict', test_md5) root = str(tmp_path) transforms = nn.Identity() return SpaceNet7( - root, split=request.param, transforms=transforms, download=True, api_key="" + root, split=request.param, transforms=transforms, download=True, api_key='' ) def test_getitem(self, dataset: SpaceNet7) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - if dataset.split == "train": - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + if dataset.split == 'train': + assert isinstance(x['mask'], torch.Tensor) def test_len(self, dataset: SpaceNet7) -> None: - if dataset.split == "train": + if dataset.split == 'train': assert len(dataset) == 2 else: assert len(dataset) == 1 @@ -428,19 +428,19 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None: SpaceNet7(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SpaceNet7(str(tmp_path)) def test_collection_checksum(self, dataset: SpaceNet4) -> None: - dataset.collection_md5_dict["sn7_train_source"] = "randommd5hash123" - with pytest.raises(RuntimeError, match="Collection sn7_train_source corrupted"): + dataset.collection_md5_dict['sn7_train_source'] = 'randommd5hash123' + with pytest.raises(RuntimeError, match='Collection sn7_train_source corrupted'): SpaceNet7(root=dataset.root, download=True, checksum=True) def test_plot(self, dataset: SpaceNet7) -> None: x = dataset[0].copy() - if dataset.split == "train": - x["prediction"] = x["mask"] - dataset.plot(x, suptitle="Test") + if dataset.split == 'train': + x['prediction'] = x['mask'] + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) 
plt.close() diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 11f5838e1e3..2977586ddb4 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -39,7 +39,7 @@ def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool: class CustomGeoDataset(GeoDataset): def __init__( self, - items: list[tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), "")], + items: list[tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), '')], crs: CRS = CRS.from_epsg(3005), res: float = 1, ) -> None: @@ -52,11 +52,11 @@ def __init__( def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) - return {"content": hit.object} + return {'content': hit.object} @pytest.mark.parametrize( - "lengths,expected_lengths", + 'lengths,expected_lengths', [ # List of lengths ([2, 1, 1], [2, 1, 1]), @@ -69,10 +69,10 @@ def test_random_bbox_assignment( ) -> None: ds = CustomGeoDataset( [ - (BoundingBox(0, 1, 0, 1, 0, 0), "a"), - (BoundingBox(1, 2, 0, 1, 0, 0), "b"), - (BoundingBox(2, 3, 0, 1, 0, 0), "c"), - (BoundingBox(3, 4, 0, 1, 0, 0), "d"), + (BoundingBox(0, 1, 0, 1, 0, 0), 'a'), + (BoundingBox(1, 2, 0, 1, 0, 0), 'b'), + (BoundingBox(2, 3, 0, 1, 0, 0), 'c'), + (BoundingBox(3, 4, 0, 1, 0, 0), 'd'), ] ) @@ -94,7 +94,7 @@ def test_random_bbox_assignment( # Test __getitem__ x = train_ds[train_ds.bounds] assert isinstance(x, dict) - assert isinstance(x["content"], str) + assert isinstance(x['content'], str) def test_random_bbox_assignment_invalid_inputs() -> None: @@ -104,7 +104,7 @@ def test_random_bbox_assignment_invalid_inputs() -> None: ): random_bbox_assignment(CustomGeoDataset(), lengths=[2, 2, 1]) with pytest.raises( - ValueError, match="All items in input lengths must be greater than 0." + ValueError, match='All items in input lengths must be greater than 0.' ): random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4]) @@ -112,10 +112,10 @@ def test_random_bbox_assignment_invalid_inputs() -> None: def test_random_bbox_splitting() -> None: ds = CustomGeoDataset( [ - (BoundingBox(0, 1, 0, 1, 0, 0), "a"), - (BoundingBox(1, 2, 0, 1, 0, 0), "b"), - (BoundingBox(2, 3, 0, 1, 0, 0), "c"), - (BoundingBox(3, 4, 0, 1, 0, 0), "d"), + (BoundingBox(0, 1, 0, 1, 0, 0), 'a'), + (BoundingBox(1, 2, 0, 1, 0, 0), 'b'), + (BoundingBox(2, 3, 0, 1, 0, 0), 'c'), + (BoundingBox(3, 4, 0, 1, 0, 0), 'd'), ] ) @@ -145,13 +145,13 @@ def test_random_bbox_splitting() -> None: # Test __get_item__ x = train_ds[train_ds.bounds] assert isinstance(x, dict) - assert isinstance(x["content"], str) + assert isinstance(x['content'], str) # Test invalid input fractions - with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): + with pytest.raises(ValueError, match='Sum of input fractions must equal 1.'): random_bbox_splitting(ds, fractions=[1 / 2, 1 / 3, 1 / 4]) with pytest.raises( - ValueError, match="All items in input fractions must be greater than 0." + ValueError, match='All items in input fractions must be greater than 0.' 
): random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4]) @@ -159,8 +159,8 @@ def test_random_bbox_splitting() -> None: def test_random_grid_cell_assignment() -> None: ds = CustomGeoDataset( [ - (BoundingBox(0, 12, 0, 12, 0, 0), "a"), - (BoundingBox(12, 24, 0, 12, 0, 0), "b"), + (BoundingBox(0, 12, 0, 12, 0, 0), 'a'), + (BoundingBox(12, 24, 0, 12, 0, 0), 'b'), ] ) @@ -185,26 +185,26 @@ def test_random_grid_cell_assignment() -> None: # Test __get_item__ x = train_ds[train_ds.bounds] assert isinstance(x, dict) - assert isinstance(x["content"], str) + assert isinstance(x['content'], str) # Test invalid input fractions - with pytest.raises(ValueError, match="Sum of input fractions must equal 1."): + with pytest.raises(ValueError, match='Sum of input fractions must equal 1.'): random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 3, 1 / 4]) with pytest.raises( - ValueError, match="All items in input fractions must be greater than 0." + ValueError, match='All items in input fractions must be greater than 0.' ): random_grid_cell_assignment(ds, fractions=[1 / 2, 3 / 4, -1 / 4]) - with pytest.raises(ValueError, match="Input grid_size must be greater than 1."): + with pytest.raises(ValueError, match='Input grid_size must be greater than 1.'): random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=1) def test_roi_split() -> None: ds = CustomGeoDataset( [ - (BoundingBox(0, 1, 0, 1, 0, 0), "a"), - (BoundingBox(1, 2, 0, 1, 0, 0), "b"), - (BoundingBox(2, 3, 0, 1, 0, 0), "c"), - (BoundingBox(3, 4, 0, 1, 0, 0), "d"), + (BoundingBox(0, 1, 0, 1, 0, 0), 'a'), + (BoundingBox(1, 2, 0, 1, 0, 0), 'b'), + (BoundingBox(2, 3, 0, 1, 0, 0), 'c'), + (BoundingBox(3, 4, 0, 1, 0, 0), 'd'), ] ) @@ -234,7 +234,7 @@ def test_roi_split() -> None: # Test __get_item__ x = train_ds[train_ds.bounds] assert isinstance(x, dict) - assert isinstance(x["content"], str) + assert isinstance(x['content'], str) # Test invalid input rois with pytest.raises(ValueError, match="ROIs in input rois can't overlap."): @@ -244,7 +244,7 @@ def test_roi_split() -> None: @pytest.mark.parametrize( - "lengths,expected_lengths", + 'lengths,expected_lengths', [ # List of timestamps ([(0, 20), (20, 35), (35, 40)], [2, 2, 1]), @@ -259,10 +259,10 @@ def test_time_series_split( ) -> None: ds = CustomGeoDataset( [ - (BoundingBox(0, 1, 0, 1, 0, 10), "a"), - (BoundingBox(0, 1, 0, 1, 10, 20), "b"), - (BoundingBox(0, 1, 0, 1, 20, 30), "c"), - (BoundingBox(0, 1, 0, 1, 30, 40), "d"), + (BoundingBox(0, 1, 0, 1, 0, 10), 'a'), + (BoundingBox(0, 1, 0, 1, 10, 20), 'b'), + (BoundingBox(0, 1, 0, 1, 20, 30), 'c'), + (BoundingBox(0, 1, 0, 1, 30, 40), 'd'), ] ) @@ -284,13 +284,13 @@ def test_time_series_split( # Test __get_item__ x = train_ds[train_ds.bounds] assert isinstance(x, dict) - assert isinstance(x["content"], str) + assert isinstance(x['content'], str) def test_time_series_split_invalid_input() -> None: with pytest.raises( ValueError, - match="Pairs of timestamps in lengths must have end greater than start.", + match='Pairs of timestamps in lengths must have end greater than start.', ): time_series_split(CustomGeoDataset(), lengths=[(0, 20), (35, 20), (35, 40)]) @@ -318,6 +318,6 @@ def test_time_series_split_invalid_input() -> None: time_series_split(CustomGeoDataset(), lengths=[1 / 2, 1 / 2, 1 / 2]) with pytest.raises( - ValueError, match="All items in input lengths must be greater than 0." + ValueError, match='All items in input lengths must be greater than 0.' 
): time_series_split(CustomGeoDataset(), lengths=[20, 25, -5]) diff --git a/tests/datasets/test_ssl4eo.py b/tests/datasets/test_ssl4eo.py index 68b6df002b4..ad45798946e 100644 --- a/tests/datasets/test_ssl4eo.py +++ b/tests/datasets/test_ssl4eo.py @@ -27,39 +27,39 @@ class TestSSL4EOL: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SSL4EOL: - monkeypatch.setattr(torchgeo.datasets.ssl4eo, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.ssl4eo, 'download_url', download_url) - url = os.path.join("tests", "data", "ssl4eo", "l", "ssl4eo_l_{0}.tar.gz{1}") - monkeypatch.setattr(SSL4EOL, "url", url) + url = os.path.join('tests', 'data', 'ssl4eo', 'l', 'ssl4eo_l_{0}.tar.gz{1}') + monkeypatch.setattr(SSL4EOL, 'url', url) checksums = { - "tm_toa": { - "aa": "010b9d72b476e0e30741c17725f84e5c", - "ab": "39171bd7bca8a56a8cb339a0f88da9d3", - "ac": "3cfc407ce3f4f4d6e3c5fdb457bb87da", + 'tm_toa': { + 'aa': '010b9d72b476e0e30741c17725f84e5c', + 'ab': '39171bd7bca8a56a8cb339a0f88da9d3', + 'ac': '3cfc407ce3f4f4d6e3c5fdb457bb87da', }, - "etm_toa": { - "aa": "87e47278f5a30acd3b696b6daaa4713b", - "ab": "59295e1816e08a5acd3a18ae56b6f32e", - "ac": "f3ff76eb6987501000228ce15684e09f", + 'etm_toa': { + 'aa': '87e47278f5a30acd3b696b6daaa4713b', + 'ab': '59295e1816e08a5acd3a18ae56b6f32e', + 'ac': 'f3ff76eb6987501000228ce15684e09f', }, - "etm_sr": { - "aa": "fd61a4154eafaeb350dbb01a2551a818", - "ab": "0c3117bc7682ba9ffdc6871e6c364b36", - "ac": "93d3385e47de4578878ca5c4fa6a628d", + 'etm_sr': { + 'aa': 'fd61a4154eafaeb350dbb01a2551a818', + 'ab': '0c3117bc7682ba9ffdc6871e6c364b36', + 'ac': '93d3385e47de4578878ca5c4fa6a628d', }, - "oli_tirs_toa": { - "aa": "defb9e91a73b145b2dbe347649bded06", - "ab": "97f7edaa4e288fc14ec7581dccea766f", - "ac": "7472fad9929a0dc96ccf4dc6c804b92f", + 'oli_tirs_toa': { + 'aa': 'defb9e91a73b145b2dbe347649bded06', + 'ab': '97f7edaa4e288fc14ec7581dccea766f', + 'ac': '7472fad9929a0dc96ccf4dc6c804b92f', }, - "oli_sr": { - "aa": "8fd3aa6b581d024299f44457956faa05", - "ab": "7eb4d761ce1afd89cae9c6142ca17882", - "ac": "a3210da9fcc71e3a4efde71c30d78c59", + 'oli_sr': { + 'aa': '8fd3aa6b581d024299f44457956faa05', + 'ab': '7eb4d761ce1afd89cae9c6142ca17882', + 'ac': 'a3210da9fcc71e3a4efde71c30d78c59', }, } - monkeypatch.setattr(SSL4EOL, "checksums", checksums) + monkeypatch.setattr(SSL4EOL, 'checksums', checksums) root = str(tmp_path) split, seasons = request.param @@ -69,10 +69,10 @@ def dataset( def test_getitem(self, dataset: SSL4EOL) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) assert ( - x["image"].size(0) - == dataset.seasons * dataset.metadata[dataset.split]["num_bands"] + x['image'].size(0) + == dataset.seasons * dataset.metadata[dataset.split]['num_bands'] ) def test_len(self, dataset: SSL4EOL) -> None: @@ -87,23 +87,23 @@ def test_already_extracted(self, dataset: SSL4EOL) -> None: SSL4EOL(dataset.root, dataset.split, dataset.seasons) def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "ssl4eo", "l", "*.tar.gz*") + pathname = os.path.join('tests', 'data', 'ssl4eo', 'l', '*.tar.gz*') root = str(tmp_path) for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) SSL4EOL(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): 
SSL4EOL(str(tmp_path)) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - SSL4EOL(split="foo") + SSL4EOL(split='foo') def test_plot(self, dataset: SSL4EOL) -> None: sample = dataset[0] - dataset.plot(sample, suptitle="Test") + dataset.plot(sample, suptitle='Test') plt.close() dataset.plot(sample, show_titles=False) plt.close() @@ -113,16 +113,16 @@ class TestSSL4EOS12: @pytest.fixture(params=zip(SSL4EOS12.metadata.keys(), [1, 2, 4])) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> SSL4EOS12: monkeypatch.setitem( - SSL4EOS12.metadata["s1"], "md5", "a716f353e4c2f0014f2e1f1ad848f82e" + SSL4EOS12.metadata['s1'], 'md5', 'a716f353e4c2f0014f2e1f1ad848f82e' ) monkeypatch.setitem( - SSL4EOS12.metadata["s2c"], "md5", "85eaf474af5642588a97dc5c991cfc15" + SSL4EOS12.metadata['s2c'], 'md5', '85eaf474af5642588a97dc5c991cfc15' ) monkeypatch.setitem( - SSL4EOS12.metadata["s2a"], "md5", "df41a5d1ae6f840bc9a11ee254110369" + SSL4EOS12.metadata['s2a'], 'md5', 'df41a5d1ae6f840bc9a11ee254110369' ) - root = os.path.join("tests", "data", "ssl4eo", "s12") + root = os.path.join('tests', 'data', 'ssl4eo', 's12') split, seasons = request.param transforms = nn.Identity() return SSL4EOS12(root, split, seasons, transforms, checksum=True) @@ -130,8 +130,8 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> SSL4EOS12: def test_getitem(self, dataset: SSL4EOS12) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].size(0) == dataset.seasons * len(dataset.bands) + assert isinstance(x['image'], torch.Tensor) + assert x['image'].size(0) == dataset.seasons * len(dataset.bands) def test_len(self, dataset: SSL4EOS12) -> None: assert len(dataset) == 251079 @@ -143,24 +143,24 @@ def test_add(self, dataset: SSL4EOS12) -> None: def test_extract(self, tmp_path: Path) -> None: for split in SSL4EOS12.metadata: - filename = SSL4EOS12.metadata[split]["filename"] + filename = SSL4EOS12.metadata[split]['filename'] shutil.copyfile( - os.path.join("tests", "data", "ssl4eo", "s12", filename), + os.path.join('tests', 'data', 'ssl4eo', 's12', filename), tmp_path / filename, ) SSL4EOS12(str(tmp_path)) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - SSL4EOS12(split="foo") + SSL4EOS12(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SSL4EOS12(str(tmp_path)) def test_plot(self, dataset: SSL4EOS12) -> None: sample = dataset[0] - dataset.plot(sample, suptitle="Test") + dataset.plot(sample, suptitle='Test') plt.close() dataset.plot(sample, show_titles=False) plt.close() diff --git a/tests/datasets/test_ssl4eo_benchmark.py b/tests/datasets/test_ssl4eo_benchmark.py index 0d5b3f94030..d759b65183e 100644 --- a/tests/datasets/test_ssl4eo_benchmark.py +++ b/tests/datasets/test_ssl4eo_benchmark.py @@ -32,51 +32,51 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestSSL4EOLBenchmark: @pytest.fixture( params=product( - ["tm_toa", "etm_toa", "etm_sr", "oli_tirs_toa", "oli_sr"], - ["cdl", "nlcd"], - ["train", "val", "test"], + ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr'], + ['cdl', 'nlcd'], + ['train', 'val', 'test'], ) ) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SSL4EOLBenchmark: monkeypatch.setattr( - torchgeo.datasets.ssl4eo_benchmark, "download_url", 
download_url + torchgeo.datasets.ssl4eo_benchmark, 'download_url', download_url ) root = str(tmp_path) - url = os.path.join("tests", "data", "ssl4eo_benchmark_landsat", "{}.tar.gz") - monkeypatch.setattr(SSL4EOLBenchmark, "url", url) + url = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '{}.tar.gz') + monkeypatch.setattr(SSL4EOLBenchmark, 'url', url) sensor, product, split = request.param monkeypatch.setattr( - SSL4EOLBenchmark, "split_percentages", [1 / 3, 1 / 3, 1 / 3] + SSL4EOLBenchmark, 'split_percentages', [1 / 3, 1 / 3, 1 / 3] ) img_md5s = { - "tm_toa": "ecfdd3dcbc812c5e7cf272a5cddb33e9", - "etm_sr": "3e598245948eb7d072d5b83c95f22422", - "etm_toa": "e24ff11f6aedb3930380b53cb6f780b6", - "oli_tirs_toa": "490baa1eedd5032277e2a07f45dd8c2b", - "oli_sr": "884f6e28a23a1b7d464eff39abd7667d", + 'tm_toa': 'ecfdd3dcbc812c5e7cf272a5cddb33e9', + 'etm_sr': '3e598245948eb7d072d5b83c95f22422', + 'etm_toa': 'e24ff11f6aedb3930380b53cb6f780b6', + 'oli_tirs_toa': '490baa1eedd5032277e2a07f45dd8c2b', + 'oli_sr': '884f6e28a23a1b7d464eff39abd7667d', } - monkeypatch.setattr(SSL4EOLBenchmark, "img_md5s", img_md5s) + monkeypatch.setattr(SSL4EOLBenchmark, 'img_md5s', img_md5s) mask_md5s = { - "tm": { - "cdl": "43f30648e0f7c8dba78fa729b6db9ffe", - "nlcd": "4272958acb32cc3b83f593684bc3e63c", + 'tm': { + 'cdl': '43f30648e0f7c8dba78fa729b6db9ffe', + 'nlcd': '4272958acb32cc3b83f593684bc3e63c', }, - "etm": { - "cdl": "b215b7e3b65b18a6d52ce9a35c90a16f", - "nlcd": "f823fc69965d7f6215f52bea2141df41", + 'etm': { + 'cdl': 'b215b7e3b65b18a6d52ce9a35c90a16f', + 'nlcd': 'f823fc69965d7f6215f52bea2141df41', }, - "oli": { - "cdl": "aaa956d7aa985e8de2c565858c9ac4e8", - "nlcd": "cc49207df010a4f358fb16a46772e9ae", + 'oli': { + 'cdl': 'aaa956d7aa985e8de2c565858c9ac4e8', + 'nlcd': 'cc49207df010a4f358fb16a46772e9ae', }, } - monkeypatch.setattr(SSL4EOLBenchmark, "mask_md5s", mask_md5s) + monkeypatch.setattr(SSL4EOLBenchmark, 'mask_md5s', mask_md5s) transforms = nn.Identity() return SSL4EOLBenchmark( @@ -92,29 +92,29 @@ def dataset( def test_getitem(self, dataset: SSL4EOLBenchmark) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) - @pytest.mark.parametrize("product,base_class", [("nlcd", NLCD), ("cdl", CDL)]) + @pytest.mark.parametrize('product,base_class', [('nlcd', NLCD), ('cdl', CDL)]) def test_classes(self, product: str, base_class: RasterDataset) -> None: - root = os.path.join("tests", "data", "ssl4eo_benchmark_landsat") + root = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat') classes = list(base_class.cmap.keys())[:5] ds = SSL4EOLBenchmark(root, product=product, classes=classes) sample = ds[0] - mask = sample["mask"] + mask = sample['mask'] assert mask.max() < len(classes) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - SSL4EOLBenchmark(split="foo") + SSL4EOLBenchmark(split='foo') def test_invalid_sensor(self) -> None: with pytest.raises(AssertionError): - SSL4EOLBenchmark(sensor="foo") + SSL4EOLBenchmark(sensor='foo') def test_invalid_product(self) -> None: with pytest.raises(AssertionError): - SSL4EOLBenchmark(product="foo") + SSL4EOLBenchmark(product='foo') def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): @@ -136,22 +136,22 @@ def test_already_extracted(self, dataset: SSL4EOLBenchmark) -> None: ) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = 
os.path.join("tests", "data", "ssl4eo_benchmark_landsat", "*.tar.gz") + pathname = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '*.tar.gz') root = str(tmp_path) for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) SSL4EOLBenchmark(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SSL4EOLBenchmark(str(tmp_path)) def test_plot(self, dataset: SSL4EOLBenchmark) -> None: sample = dataset[0] - dataset.plot(sample, suptitle="Test") + dataset.plot(sample, suptitle='Test') plt.close() dataset.plot(sample, show_titles=False) plt.close() - sample["prediction"] = sample["mask"].clone() + sample['prediction'] = sample['mask'].clone() dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_sustainbench_crop_yield.py b/tests/datasets/test_sustainbench_crop_yield.py index 071f0c81a8f..36e746aaf92 100644 --- a/tests/datasets/test_sustainbench_crop_yield.py +++ b/tests/datasets/test_sustainbench_crop_yield.py @@ -21,22 +21,22 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestSustainBenchCropYield: - @pytest.fixture(params=["train", "dev", "test"]) + @pytest.fixture(params=['train', 'dev', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SustainBenchCropYield: monkeypatch.setattr( - torchgeo.datasets.sustainbench_crop_yield, "download_url", download_url + torchgeo.datasets.sustainbench_crop_yield, 'download_url', download_url ) - md5 = "7a5591794e14dd73d2b747cd2244acbc" - monkeypatch.setattr(SustainBenchCropYield, "md5", md5) - url = os.path.join("tests", "data", "sustainbench_crop_yield", "soybeans.zip") - monkeypatch.setattr(SustainBenchCropYield, "url", url) - monkeypatch.setattr(plt, "show", lambda *args: None) + md5 = '7a5591794e14dd73d2b747cd2244acbc' + monkeypatch.setattr(SustainBenchCropYield, 'md5', md5) + url = os.path.join('tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip') + monkeypatch.setattr(SustainBenchCropYield, 'url', url) + monkeypatch.setattr(plt, 'show', lambda *args: None) root = str(tmp_path) split = request.param - countries = ["argentina", "brazil", "usa"] + countries = ['argentina', 'brazil', 'usa'] transforms = nn.Identity() return SustainBenchCropYield( root, split, countries, transforms, download=True, checksum=True @@ -47,38 +47,38 @@ def test_already_extracted(self, dataset: SustainBenchCropYield) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( - "tests", "data", "sustainbench_crop_yield", "soybeans.zip" + 'tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip' ) root = str(tmp_path) shutil.copy(pathname, root) SustainBenchCropYield(root) - @pytest.mark.parametrize("index", [0, 1, 2]) + @pytest.mark.parametrize('index', [0, 1, 2]) def test_getitem(self, dataset: SustainBenchCropYield, index: int) -> None: x = dataset[index] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) - assert isinstance(x["year"], torch.Tensor) - assert isinstance(x["ndvi"], torch.Tensor) - assert x["image"].shape == (9, 32, 32) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) + assert isinstance(x['year'], torch.Tensor) + assert isinstance(x['ndvi'], torch.Tensor) + assert x['image'].shape == (9, 32, 32) def test_len(self, dataset: SustainBenchCropYield) -> None: assert 
len(dataset) == len(dataset.countries) * 3 def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - SustainBenchCropYield(split="foo") + SustainBenchCropYield(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): SustainBenchCropYield(str(tmp_path)) def test_plot(self, dataset: SustainBenchCropYield) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["label"] + sample['prediction'] = sample['label'] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index 61c76f9cecd..bedeb588c66 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -22,33 +22,33 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestUCMerced: - @pytest.fixture(params=["train", "val", "test"]) + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> UCMerced: - monkeypatch.setattr(torchgeo.datasets.ucmerced, "download_url", download_url) - md5 = "a42ef8779469d196d8f2971ee135f030" - monkeypatch.setattr(UCMerced, "md5", md5) - url = os.path.join("tests", "data", "ucmerced", "UCMerced_LandUse.zip") - monkeypatch.setattr(UCMerced, "url", url) + monkeypatch.setattr(torchgeo.datasets.ucmerced, 'download_url', download_url) + md5 = 'a42ef8779469d196d8f2971ee135f030' + monkeypatch.setattr(UCMerced, 'md5', md5) + url = os.path.join('tests', 'data', 'ucmerced', 'UCMerced_LandUse.zip') + monkeypatch.setattr(UCMerced, 'url', url) monkeypatch.setattr( UCMerced, - "split_urls", + 'split_urls', { - "train": os.path.join( - "tests", "data", "ucmerced", "uc_merced-train.txt" + 'train': os.path.join( + 'tests', 'data', 'ucmerced', 'uc_merced-train.txt' ), - "val": os.path.join("tests", "data", "ucmerced", "uc_merced-val.txt"), - "test": os.path.join("tests", "data", "ucmerced", "uc_merced-test.txt"), + 'val': os.path.join('tests', 'data', 'ucmerced', 'uc_merced-val.txt'), + 'test': os.path.join('tests', 'data', 'ucmerced', 'uc_merced-test.txt'), }, ) monkeypatch.setattr( UCMerced, - "split_md5s", + 'split_md5s', { - "train": "a01fa9f13333bb176fc1bfe26ff4c711", - "val": "a01fa9f13333bb176fc1bfe26ff4c711", - "test": "a01fa9f13333bb176fc1bfe26ff4c711", + 'train': 'a01fa9f13333bb176fc1bfe26ff4c711', + 'val': 'a01fa9f13333bb176fc1bfe26ff4c711', + 'test': 'a01fa9f13333bb176fc1bfe26ff4c711', }, ) root = str(tmp_path) @@ -59,8 +59,8 @@ def dataset( def test_getitem(self, dataset: UCMerced) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) def test_len(self, dataset: UCMerced) -> None: assert len(dataset) == 4 @@ -81,15 +81,15 @@ def test_already_downloaded_not_extracted( UCMerced(root=str(tmp_path), download=False) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): UCMerced(str(tmp_path)) def test_plot(self, dataset: UCMerced) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() 
dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["label"].clone() + x['prediction'] = x['label'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 4c256ad5c25..0566a1f3153 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -24,54 +24,54 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestUSAVars: @pytest.fixture( params=zip( - ["train", "val", "test"], + ['train', 'val', 'test'], [ - ["elevation", "population", "treecover"], - ["elevation", "population"], - ["treecover"], + ['elevation', 'population', 'treecover'], + ['elevation', 'population'], + ['treecover'], ], ) ) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> USAVars: - monkeypatch.setattr(torchgeo.datasets.usavars, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.usavars, 'download_url', download_url) - md5 = "b504580a00bdc27097d5421dec50481b" - monkeypatch.setattr(USAVars, "md5", md5) + md5 = 'b504580a00bdc27097d5421dec50481b' + monkeypatch.setattr(USAVars, 'md5', md5) - data_url = os.path.join("tests", "data", "usavars", "uar.zip") - monkeypatch.setattr(USAVars, "data_url", data_url) + data_url = os.path.join('tests', 'data', 'usavars', 'uar.zip') + monkeypatch.setattr(USAVars, 'data_url', data_url) label_urls = { - "elevation": os.path.join("tests", "data", "usavars", "elevation.csv"), - "population": os.path.join("tests", "data", "usavars", "population.csv"), - "treecover": os.path.join("tests", "data", "usavars", "treecover.csv"), - "income": os.path.join("tests", "data", "usavars", "income.csv"), - "nightlights": os.path.join("tests", "data", "usavars", "nightlights.csv"), - "roads": os.path.join("tests", "data", "usavars", "roads.csv"), - "housing": os.path.join("tests", "data", "usavars", "housing.csv"), + 'elevation': os.path.join('tests', 'data', 'usavars', 'elevation.csv'), + 'population': os.path.join('tests', 'data', 'usavars', 'population.csv'), + 'treecover': os.path.join('tests', 'data', 'usavars', 'treecover.csv'), + 'income': os.path.join('tests', 'data', 'usavars', 'income.csv'), + 'nightlights': os.path.join('tests', 'data', 'usavars', 'nightlights.csv'), + 'roads': os.path.join('tests', 'data', 'usavars', 'roads.csv'), + 'housing': os.path.join('tests', 'data', 'usavars', 'housing.csv'), } - monkeypatch.setattr(USAVars, "label_urls", label_urls) + monkeypatch.setattr(USAVars, 'label_urls', label_urls) split_metadata = { - "train": { - "url": os.path.join("tests", "data", "usavars", "train_split.txt"), - "filename": "train_split.txt", - "md5": "b94f3f6f63110b253779b65bc31d91b5", + 'train': { + 'url': os.path.join('tests', 'data', 'usavars', 'train_split.txt'), + 'filename': 'train_split.txt', + 'md5': 'b94f3f6f63110b253779b65bc31d91b5', }, - "val": { - "url": os.path.join("tests", "data", "usavars", "val_split.txt"), - "filename": "val_split.txt", - "md5": "e39aa54b646c4c45921fcc9765d5a708", + 'val': { + 'url': os.path.join('tests', 'data', 'usavars', 'val_split.txt'), + 'filename': 'val_split.txt', + 'md5': 'e39aa54b646c4c45921fcc9765d5a708', }, - "test": { - "url": os.path.join("tests", "data", "usavars", "test_split.txt"), - "filename": "test_split.txt", - "md5": "4ab0f5549fee944a5690de1bc95ed245", + 'test': { + 'url': os.path.join('tests', 'data', 'usavars', 'test_split.txt'), + 'filename': 'test_split.txt', + 'md5': '4ab0f5549fee944a5690de1bc95ed245', }, } - monkeypatch.setattr(USAVars, 
"split_metadata", split_metadata) + monkeypatch.setattr(USAVars, 'split_metadata', split_metadata) root = str(tmp_path) split, labels = request.param @@ -84,18 +84,18 @@ def dataset( def test_getitem(self, dataset: USAVars) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert x["image"].ndim == 3 + assert isinstance(x['image'], torch.Tensor) + assert x['image'].ndim == 3 assert len(x.keys()) == 4 # image, labels, centroid_lat, centroid_lon - assert x["image"].shape[0] == 4 # R, G, B, Inf - assert len(dataset.labels) == len(x["labels"]) - assert len(x["centroid_lat"]) == 1 - assert len(x["centroid_lon"]) == 1 + assert x['image'].shape[0] == 4 # R, G, B, Inf + assert len(dataset.labels) == len(x['labels']) + assert len(x['centroid_lat']) == 1 + assert len(x['centroid_lon']) == 1 def test_len(self, dataset: USAVars) -> None: - if dataset.split == "train": + if dataset.split == 'train': assert len(dataset) == 3 - elif dataset.split == "val": + elif dataset.split == 'val': assert len(dataset) == 2 else: assert len(dataset) == 1 @@ -108,30 +108,30 @@ def test_already_extracted(self, dataset: USAVars) -> None: USAVars(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join("tests", "data", "usavars", "uar.zip") + pathname = os.path.join('tests', 'data', 'usavars', 'uar.zip') root = str(tmp_path) shutil.copy(pathname, root) csvs = [ - "elevation.csv", - "population.csv", - "treecover.csv", - "income.csv", - "nightlights.csv", - "roads.csv", - "housing.csv", + 'elevation.csv', + 'population.csv', + 'treecover.csv', + 'income.csv', + 'nightlights.csv', + 'roads.csv', + 'housing.csv', ] for csv in csvs: - shutil.copy(os.path.join("tests", "data", "usavars", csv), root) - splits = ["train_split.txt", "val_split.txt", "test_split.txt"] + shutil.copy(os.path.join('tests', 'data', 'usavars', csv), root) + splits = ['train_split.txt', 'val_split.txt', 'test_split.txt'] for split in splits: - shutil.copy(os.path.join("tests", "data", "usavars", split), root) + shutil.copy(os.path.join('tests', 'data', 'usavars', split), root) USAVars(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): USAVars(str(tmp_path)) def test_plot(self, dataset: USAVars) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 018b0d9bae2..092e8864acf 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -42,45 +42,45 @@ class TestDatasetNotFoundError: def test_none(self) -> None: ds: Dataset[Any] = Dataset() - match = "Dataset not found." + match = 'Dataset not found.' with pytest.raises(DatasetNotFoundError, match=match): raise DatasetNotFoundError(ds) def test_root(self) -> None: ds: Dataset[Any] = Dataset() - ds.root = "foo" # type: ignore[attr-defined] + ds.root = 'foo' # type: ignore[attr-defined] match = "Dataset not found in `root='foo'` and cannot be automatically " - match += "downloaded, either specify a different `root` or manually " - match += "download the dataset." + match += 'downloaded, either specify a different `root` or manually ' + match += 'download the dataset.' 
with pytest.raises(DatasetNotFoundError, match=match): raise DatasetNotFoundError(ds) def test_paths(self) -> None: ds: Dataset[Any] = Dataset() - ds.paths = "foo" # type: ignore[attr-defined] + ds.paths = 'foo' # type: ignore[attr-defined] match = "Dataset not found in `paths='foo'` and cannot be automatically " - match += "downloaded, either specify a different `paths` or manually " - match += "download the dataset." + match += 'downloaded, either specify a different `paths` or manually ' + match += 'download the dataset.' with pytest.raises(DatasetNotFoundError, match=match): raise DatasetNotFoundError(ds) def test_root_download(self) -> None: ds: Dataset[Any] = Dataset() - ds.root = "foo" # type: ignore[attr-defined] + ds.root = 'foo' # type: ignore[attr-defined] ds.download = False # type: ignore[attr-defined] match = "Dataset not found in `root='foo'` and `download=False`, either " - match += "specify a different `root` or use `download=True` to automatically " - match += "download the dataset." + match += 'specify a different `root` or use `download=True` to automatically ' + match += 'download the dataset.' with pytest.raises(DatasetNotFoundError, match=match): raise DatasetNotFoundError(ds) def test_paths_download(self) -> None: ds: Dataset[Any] = Dataset() - ds.paths = "foo" # type: ignore[attr-defined] + ds.paths = 'foo' # type: ignore[attr-defined] ds.download = False # type: ignore[attr-defined] match = "Dataset not found in `paths='foo'` and `download=False`, either " - match += "specify a different `paths` or use `download=True` to automatically " - match += "download the dataset." + match += 'specify a different `paths` or use `download=True` to automatically ' + match += 'download the dataset.' with pytest.raises(DatasetNotFoundError, match=match): raise DatasetNotFoundError(ds) @@ -90,17 +90,17 @@ def mock_missing_module(monkeypatch: MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name in ["radiant_mlhub", "rarfile", "zipfile_deflate64"]: + if name in ['radiant_mlhub', 'rarfile', 'zipfile_deflate64']: raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) class MLHubDataset: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( - "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz" + 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' ) for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -109,7 +109,7 @@ def download(self, output_dir: str, **kwargs: str) -> None: class Collection: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( - "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz" + 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' ) for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) @@ -132,50 +132,50 @@ def test_mock_missing_module(mock_missing_module: None) -> None: @pytest.mark.parametrize( - "src", + 'src', [ - os.path.join("cowc_detection", "COWC_Detection_Columbus_CSUAV_AFRL.tbz"), - os.path.join("cowc_detection", "COWC_test_list_detection.txt.bz2"), - os.path.join("vhr10", "NWPU VHR-10 dataset.rar"), - os.path.join("landcoverai", "landcover.ai.v1.zip"), - os.path.join("chesapeake", "BAYWIDE", "Baywide_13Class_20132014.zip"), - os.path.join("sen12ms", "ROIs1158_spring_lc.tar.gz"), + os.path.join('cowc_detection', 
'COWC_Detection_Columbus_CSUAV_AFRL.tbz'), + os.path.join('cowc_detection', 'COWC_test_list_detection.txt.bz2'), + os.path.join('vhr10', 'NWPU VHR-10 dataset.rar'), + os.path.join('landcoverai', 'landcover.ai.v1.zip'), + os.path.join('chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip'), + os.path.join('sen12ms', 'ROIs1158_spring_lc.tar.gz'), ], ) def test_extract_archive(src: str, tmp_path: Path) -> None: - if src.endswith(".rar"): - pytest.importorskip("rarfile", minversion="4") - if src.startswith("chesapeake"): - pytest.importorskip("zipfile_deflate64") - extract_archive(os.path.join("tests", "data", src), str(tmp_path)) + if src.endswith('.rar'): + pytest.importorskip('rarfile', minversion='4') + if src.startswith('chesapeake'): + pytest.importorskip('zipfile_deflate64') + extract_archive(os.path.join('tests', 'data', src), str(tmp_path)) def test_missing_rarfile(mock_missing_module: None) -> None: with pytest.raises( ImportError, - match="rarfile is not installed and is required to extract this dataset", + match='rarfile is not installed and is required to extract this dataset', ): extract_archive( - os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar") + os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar') ) def test_missing_zipfile_deflate64(mock_missing_module: None) -> None: # Should fallback on Python builtin zipfile - extract_archive(os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip")) + extract_archive(os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')) def test_unsupported_scheme() -> None: with pytest.raises( - RuntimeError, match="src file has unknown archival/compression scheme" + RuntimeError, match='src file has unknown archival/compression scheme' ): - extract_archive("foo.bar") + extract_archive('foo.bar') def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) download_and_extract_archive( - os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip"), + os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'), str(tmp_path), ) @@ -183,38 +183,38 @@ def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch) def test_download_radiant_mlhub_dataset( tmp_path: Path, monkeypatch: MonkeyPatch ) -> None: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch_dataset) - download_radiant_mlhub_dataset("", str(tmp_path)) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset) + download_radiant_mlhub_dataset('', str(tmp_path)) def test_download_radiant_mlhub_collection( tmp_path: Path, monkeypatch: MonkeyPatch ) -> None: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) - download_radiant_mlhub_collection("", str(tmp_path)) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) + download_radiant_mlhub_collection('', str(tmp_path)) def test_missing_radiant_mlhub(mock_missing_module: None) -> None: with pytest.raises( ImportError, - match="radiant_mlhub is not installed and is required to download this dataset", + match='radiant_mlhub is not 
installed and is required to download this dataset', ): - download_radiant_mlhub_dataset("", "") + download_radiant_mlhub_dataset('', '') with pytest.raises( ImportError, - match="radiant_mlhub is not installed and is required to download this" - + " collection", + match='radiant_mlhub is not installed and is required to download this' + + ' collection', ): - download_radiant_mlhub_collection("", "") + download_radiant_mlhub_collection('', '') class TestBoundingBox: def test_repr_str(self) -> None: bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4) - expected = "BoundingBox(minx=0, maxx=1, miny=2.0, maxy=3.0, mint=-5, maxt=-4)" + expected = 'BoundingBox(minx=0, maxx=1, miny=2.0, maxy=3.0, mint=-5, maxt=-4)' assert repr(bbox) == expected assert str(bbox) == expected @@ -243,7 +243,7 @@ def test_iter(self) -> None: assert i == 6 @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ # Same box ((0, 1, 0, 1, 0, 1), True), @@ -278,7 +278,7 @@ def test_contains( assert (bbox1 in bbox2) == expected @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ # Same box ((0, 1, 0, 1, 0, 1), (0, 1, 0, 1, 0, 1)), @@ -314,7 +314,7 @@ def test_or( assert (bbox1 | bbox2) == bbox3 @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ # Same box ((0, 1, 0, 1, 0, 1), (0, 1, 0, 1, 0, 1)), @@ -348,7 +348,7 @@ def test_and_intersection( assert (bbox1 & bbox2) == bbox3 @pytest.mark.parametrize( - "test_input", + 'test_input', [ # No overlap (0.5, 1.5, 0.5, 1.5, 2, 3), @@ -364,12 +364,12 @@ def test_and_no_intersection( bbox2 = BoundingBox(*test_input) with pytest.raises( ValueError, - match=re.escape(f"Bounding boxes {bbox1} and {bbox2} do not overlap"), + match=re.escape(f'Bounding boxes {bbox1} and {bbox2} do not overlap'), ): bbox1 & bbox2 @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ # Rectangular prism ((0, 1, 0, 1, 0, 1), 1), @@ -389,7 +389,7 @@ def test_area( assert bbox.area == expected @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ # Rectangular prism ((0, 1, 0, 1, 0, 1), 1), @@ -409,7 +409,7 @@ def test_volume( assert bbox.volume == expected @pytest.mark.parametrize( - "test_input,expected", + 'test_input,expected', [ # Same box ((0, 1, 0, 1, 0, 1), True), @@ -444,7 +444,7 @@ def test_intersects( assert bbox1.intersects(bbox2) == bbox2.intersects(bbox1) == expected @pytest.mark.parametrize( - "proportion,horizontal,expected", + 'proportion,horizontal,expected', [ (0.25, True, ((0, 0.25, 0, 1, 0, 1), (0.25, 1, 0, 1, 0, 1))), (0.25, False, ((0, 1, 0, 0.25, 0, 1), (0, 1, 0.25, 1, 0, 1))), @@ -468,7 +468,7 @@ def test_split( def test_split_error(self) -> None: bbox = BoundingBox(0, 1, 0, 1, 0, 1) with pytest.raises( - ValueError, match="Input proportion must be between 0 and 1." + ValueError, match='Input proportion must be between 0 and 1.' 
): bbox.split(1.5) @@ -498,54 +498,54 @@ def test_invalid_t(self) -> None: @pytest.mark.parametrize( - "date_string,format,min_datetime,max_datetime", + 'date_string,format,min_datetime,max_datetime', [ - ("", "", 0, sys.maxsize), + ('', '', 0, sys.maxsize), ( - "2021", - "%Y", + '2021', + '%Y', datetime(2021, 1, 1, 0, 0, 0, 0).timestamp(), datetime(2021, 12, 31, 23, 59, 59, 999999).timestamp(), ), ( - "2021-09", - "%Y-%m", + '2021-09', + '%Y-%m', datetime(2021, 9, 1, 0, 0, 0, 0).timestamp(), datetime(2021, 9, 30, 23, 59, 59, 999999).timestamp(), ), ( - "Dec 21", - "%b %y", + 'Dec 21', + '%b %y', datetime(2021, 12, 1, 0, 0, 0, 0).timestamp(), datetime(2021, 12, 31, 23, 59, 59, 999999).timestamp(), ), ( - "2021-09-13", - "%Y-%m-%d", + '2021-09-13', + '%Y-%m-%d', datetime(2021, 9, 13, 0, 0, 0, 0).timestamp(), datetime(2021, 9, 13, 23, 59, 59, 999999).timestamp(), ), ( - "2021-09-13 17", - "%Y-%m-%d %H", + '2021-09-13 17', + '%Y-%m-%d %H', datetime(2021, 9, 13, 17, 0, 0, 0).timestamp(), datetime(2021, 9, 13, 17, 59, 59, 999999).timestamp(), ), ( - "2021-09-13 17:21", - "%Y-%m-%d %H:%M", + '2021-09-13 17:21', + '%Y-%m-%d %H:%M', datetime(2021, 9, 13, 17, 21, 0, 0).timestamp(), datetime(2021, 9, 13, 17, 21, 59, 999999).timestamp(), ), ( - "2021-09-13 17:21:53", - "%Y-%m-%d %H:%M:%S", + '2021-09-13 17:21:53', + '%Y-%m-%d %H:%M:%S', datetime(2021, 9, 13, 17, 21, 53, 0).timestamp(), datetime(2021, 9, 13, 17, 21, 53, 999999).timestamp(), ), ( - "2021-09-13 17:21:53:000123", - "%Y-%m-%d %H:%M:%S:%f", + '2021-09-13 17:21:53:000123', + '%Y-%m-%d %H:%M:%S:%f', datetime(2021, 9, 13, 17, 21, 53, 123).timestamp(), datetime(2021, 9, 13, 17, 21, 53, 123).timestamp(), ), @@ -560,81 +560,81 @@ def test_disambiguate_timestamp( class TestCollateFunctionsMatchingKeys: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def samples(self) -> list[dict[str, Any]]: return [ - {"image": torch.tensor([1, 2, 0]), "crs": CRS.from_epsg(2000)}, - {"image": torch.tensor([0, 0, 3]), "crs": CRS.from_epsg(2001)}, + {'image': torch.tensor([1, 2, 0]), 'crs': CRS.from_epsg(2000)}, + {'image': torch.tensor([0, 0, 3]), 'crs': CRS.from_epsg(2001)}, ] def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: sample = stack_samples(samples) - assert sample["image"].size() == torch.Size([2, 3]) - assert torch.allclose(sample["image"], torch.tensor([[1, 2, 0], [0, 0, 3]])) - assert sample["crs"] == [CRS.from_epsg(2000), CRS.from_epsg(2001)] + assert sample['image'].size() == torch.Size([2, 3]) + assert torch.allclose(sample['image'], torch.tensor([[1, 2, 0], [0, 0, 3]])) + assert sample['crs'] == [CRS.from_epsg(2000), CRS.from_epsg(2001)] new_samples = unbind_samples(sample) for i in range(2): - assert torch.allclose(samples[i]["image"], new_samples[i]["image"]) - assert samples[i]["crs"] == new_samples[i]["crs"] + assert torch.allclose(samples[i]['image'], new_samples[i]['image']) + assert samples[i]['crs'] == new_samples[i]['crs'] def test_concat_samples(self, samples: list[dict[str, Any]]) -> None: sample = concat_samples(samples) - assert sample["image"].size() == torch.Size([6]) - assert torch.allclose(sample["image"], torch.tensor([1, 2, 0, 0, 0, 3])) - assert sample["crs"] == CRS.from_epsg(2000) + assert sample['image'].size() == torch.Size([6]) + assert torch.allclose(sample['image'], torch.tensor([1, 2, 0, 0, 0, 3])) + assert sample['crs'] == CRS.from_epsg(2000) def test_merge_samples(self, samples: list[dict[str, Any]]) -> None: sample = merge_samples(samples) - assert sample["image"].size() == 
torch.Size([3]) - assert torch.allclose(sample["image"], torch.tensor([1, 2, 3])) - assert sample["crs"] == CRS.from_epsg(2001) + assert sample['image'].size() == torch.Size([3]) + assert torch.allclose(sample['image'], torch.tensor([1, 2, 3])) + assert sample['crs'] == CRS.from_epsg(2001) class TestCollateFunctionsDifferingKeys: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def samples(self) -> list[dict[str, Any]]: return [ - {"image": torch.tensor([1, 2, 0]), "crs1": CRS.from_epsg(2000)}, - {"mask": torch.tensor([0, 0, 3]), "crs2": CRS.from_epsg(2001)}, + {'image': torch.tensor([1, 2, 0]), 'crs1': CRS.from_epsg(2000)}, + {'mask': torch.tensor([0, 0, 3]), 'crs2': CRS.from_epsg(2001)}, ] def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: sample = stack_samples(samples) - assert sample["image"].size() == torch.Size([1, 3]) - assert sample["mask"].size() == torch.Size([1, 3]) - assert torch.allclose(sample["image"], torch.tensor([[1, 2, 0]])) - assert torch.allclose(sample["mask"], torch.tensor([[0, 0, 3]])) - assert sample["crs1"] == [CRS.from_epsg(2000)] - assert sample["crs2"] == [CRS.from_epsg(2001)] + assert sample['image'].size() == torch.Size([1, 3]) + assert sample['mask'].size() == torch.Size([1, 3]) + assert torch.allclose(sample['image'], torch.tensor([[1, 2, 0]])) + assert torch.allclose(sample['mask'], torch.tensor([[0, 0, 3]])) + assert sample['crs1'] == [CRS.from_epsg(2000)] + assert sample['crs2'] == [CRS.from_epsg(2001)] new_samples = unbind_samples(sample) - assert torch.allclose(samples[0]["image"], new_samples[0]["image"]) - assert samples[0]["crs1"] == new_samples[0]["crs1"] - assert torch.allclose(samples[1]["mask"], new_samples[0]["mask"]) - assert samples[1]["crs2"] == new_samples[0]["crs2"] + assert torch.allclose(samples[0]['image'], new_samples[0]['image']) + assert samples[0]['crs1'] == new_samples[0]['crs1'] + assert torch.allclose(samples[1]['mask'], new_samples[0]['mask']) + assert samples[1]['crs2'] == new_samples[0]['crs2'] def test_concat_samples(self, samples: list[dict[str, Any]]) -> None: sample = concat_samples(samples) - assert sample["image"].size() == torch.Size([3]) - assert sample["mask"].size() == torch.Size([3]) - assert torch.allclose(sample["image"], torch.tensor([1, 2, 0])) - assert torch.allclose(sample["mask"], torch.tensor([0, 0, 3])) - assert sample["crs1"] == CRS.from_epsg(2000) - assert sample["crs2"] == CRS.from_epsg(2001) + assert sample['image'].size() == torch.Size([3]) + assert sample['mask'].size() == torch.Size([3]) + assert torch.allclose(sample['image'], torch.tensor([1, 2, 0])) + assert torch.allclose(sample['mask'], torch.tensor([0, 0, 3])) + assert sample['crs1'] == CRS.from_epsg(2000) + assert sample['crs2'] == CRS.from_epsg(2001) def test_merge_samples(self, samples: list[dict[str, Any]]) -> None: sample = merge_samples(samples) - assert sample["image"].size() == torch.Size([3]) - assert sample["mask"].size() == torch.Size([3]) - assert torch.allclose(sample["image"], torch.tensor([1, 2, 0])) - assert torch.allclose(sample["mask"], torch.tensor([0, 0, 3])) - assert sample["crs1"] == CRS.from_epsg(2000) - assert sample["crs2"] == CRS.from_epsg(2001) + assert sample['image'].size() == torch.Size([3]) + assert sample['mask'].size() == torch.Size([3]) + assert torch.allclose(sample['image'], torch.tensor([1, 2, 0])) + assert torch.allclose(sample['mask'], torch.tensor([0, 0, 3])) + assert sample['crs1'] == CRS.from_epsg(2000) + assert sample['crs2'] == CRS.from_epsg(2001) def 
test_existing_directory(tmp_path: Path) -> None: - subdir = tmp_path / "foo" / "bar" + subdir = tmp_path / 'foo' / 'bar' subdir.mkdir(parents=True) assert subdir.exists() @@ -644,7 +644,7 @@ def test_existing_directory(tmp_path: Path) -> None: def test_nonexisting_directory(tmp_path: Path) -> None: - subdir = tmp_path / "foo" / "bar" + subdir = tmp_path / 'foo' / 'bar' assert not subdir.exists() @@ -653,7 +653,7 @@ def test_nonexisting_directory(tmp_path: Path) -> None: def test_percentile_normalization() -> None: - img: "np.typing.NDArray[np.int_]" = np.array([[1, 2], [98, 100]]) + img: 'np.typing.NDArray[np.int_]' = np.array([[1, 2], [98, 100]]) img = percentile_normalization(img, 2, 98) assert img.min() == 0 @@ -661,11 +661,11 @@ def test_percentile_normalization() -> None: @pytest.mark.parametrize( - "array_dtype", + 'array_dtype', [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32, np.int64], ) -def test_array_to_tensor(array_dtype: "np.typing.DTypeLike") -> None: - array: "np.typing.NDArray[Any]" = np.zeros((2,), dtype=array_dtype) +def test_array_to_tensor(array_dtype: 'np.typing.DTypeLike') -> None: + array: 'np.typing.NDArray[Any]' = np.zeros((2,), dtype=array_dtype) array[0] = np.iinfo(array.dtype).min array[1] = np.iinfo(array.dtype).max tensor = array_to_tensor(array) diff --git a/tests/datasets/test_vaihingen.py b/tests/datasets/test_vaihingen.py index fe34bccea08..e4b36b99edd 100644 --- a/tests/datasets/test_vaihingen.py +++ b/tests/datasets/test_vaihingen.py @@ -16,16 +16,16 @@ class TestVaihingen2D: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Vaihingen2D: - md5s = ["c15fbff78d307e51c73f609c0859afc3", "ec2c0a5149f2371479b38cf8cfbab961"] + md5s = ['c15fbff78d307e51c73f609c0859afc3', 'ec2c0a5149f2371479b38cf8cfbab961'] splits = { - "train": ["top_mosaic_09cm_area1.tif", "top_mosaic_09cm_area11.tif"], - "test": ["top_mosaic_09cm_area6.tif", "top_mosaic_09cm_area24.tif"], + 'train': ['top_mosaic_09cm_area1.tif', 'top_mosaic_09cm_area11.tif'], + 'test': ['top_mosaic_09cm_area6.tif', 'top_mosaic_09cm_area24.tif'], } - monkeypatch.setattr(Vaihingen2D, "md5s", md5s) - monkeypatch.setattr(Vaihingen2D, "splits", splits) - root = os.path.join("tests", "data", "vaihingen") + monkeypatch.setattr(Vaihingen2D, 'md5s', md5s) + monkeypatch.setattr(Vaihingen2D, 'splits', splits) + root = os.path.join('tests', 'data', 'vaihingen') split = request.param transforms = nn.Identity() return Vaihingen2D(root, split, transforms, checksum=True) @@ -33,19 +33,19 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Vaihingen2D: def test_getitem(self, dataset: Vaihingen2D) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert x["image"].ndim == 3 - assert x["mask"].ndim == 2 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].ndim == 3 + assert x['mask'].ndim == 2 def test_len(self, dataset: Vaihingen2D) -> None: assert len(dataset) == 2 def test_extract(self, tmp_path: Path) -> None: - root = os.path.join("tests", "data", "vaihingen") + root = os.path.join('tests', 'data', 'vaihingen') filenames = [ - "ISPRS_semantic_labeling_Vaihingen.zip", - "ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip", + 'ISPRS_semantic_labeling_Vaihingen.zip', + 'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip', ] 
for filename in filenames: shutil.copyfile( @@ -55,29 +55,29 @@ def test_extract(self, tmp_path: Path) -> None: def test_corrupted(self, tmp_path: Path) -> None: filenames = [ - "ISPRS_semantic_labeling_Vaihingen.zip", - "ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip", + 'ISPRS_semantic_labeling_Vaihingen.zip', + 'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip', ] for filename in filenames: - with open(os.path.join(tmp_path, filename), "w") as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + with open(os.path.join(tmp_path, filename), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): Vaihingen2D(root=str(tmp_path), checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - Vaihingen2D(split="foo") + Vaihingen2D(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): Vaihingen2D(str(tmp_path)) def test_plot(self, dataset: Vaihingen2D) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"].clone() + x['prediction'] = x['mask'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 5480acb4ef4..de4c6c2d507 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -18,7 +18,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import VHR10, DatasetNotFoundError -pytest.importorskip("pycocotools") +pytest.importorskip('pycocotools') def download_url(url: str, root: str, *args: str) -> None: @@ -26,21 +26,21 @@ def download_url(url: str, root: str, *args: str) -> None: class TestVHR10: - @pytest.fixture(params=["positive", "negative"]) + @pytest.fixture(params=['positive', 'negative']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> VHR10: - pytest.importorskip("rarfile", minversion="4") - monkeypatch.setattr(torchgeo.datasets.vhr10, "download_url", download_url) - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar") - monkeypatch.setitem(VHR10.image_meta, "url", url) - md5 = "92769845cae6a4e8c74bfa1a0d1d4a80" - monkeypatch.setitem(VHR10.image_meta, "md5", md5) - url = os.path.join("tests", "data", "vhr10", "annotations.json") - monkeypatch.setitem(VHR10.target_meta, "url", url) - md5 = "567c4cd8c12624864ff04865de504c58" - monkeypatch.setitem(VHR10.target_meta, "md5", md5) + pytest.importorskip('rarfile', minversion='4') + monkeypatch.setattr(torchgeo.datasets.vhr10, 'download_url', download_url) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) + url = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar') + monkeypatch.setitem(VHR10.image_meta, 'url', url) + md5 = '92769845cae6a4e8c74bfa1a0d1d4a80' + monkeypatch.setitem(VHR10.image_meta, 'md5', md5) + url = os.path.join('tests', 'data', 'vhr10', 'annotations.json') + monkeypatch.setitem(VHR10.target_meta, 'url', url) + md5 = '567c4cd8c12624864ff04865de504c58' + monkeypatch.setitem(VHR10.target_meta, 'md5', md5) root = str(tmp_path) split = request.param transforms = nn.Identity() @@ -51,35 +51,35 @@ def mock_missing_modules(self, monkeypatch: 
MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name in {"pycocotools.coco", "skimage.measure"}: + if name in {'pycocotools.coco', 'skimage.measure'}: raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) def test_getitem(self, dataset: VHR10) -> None: for i in range(2): x = dataset[i] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - if dataset.split == "positive": - assert isinstance(x["labels"], torch.Tensor) - assert isinstance(x["boxes"], torch.Tensor) - if "masks" in x: - assert isinstance(x["masks"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + if dataset.split == 'positive': + assert isinstance(x['labels'], torch.Tensor) + assert isinstance(x['boxes'], torch.Tensor) + if 'masks' in x: + assert isinstance(x['masks'], torch.Tensor) def test_len(self, dataset: VHR10) -> None: - if dataset.split == "positive": + if dataset.split == 'positive': assert len(dataset) == 5 - elif dataset.split == "negative": + elif dataset.split == 'negative': assert len(dataset) == 150 def test_add(self, dataset: VHR10) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - if dataset.split == "positive": + if dataset.split == 'positive': assert len(ds) == 10 - elif dataset.split == "negative": + elif dataset.split == 'negative': assert len(ds) == 300 def test_already_downloaded(self, dataset: VHR10) -> None: @@ -87,44 +87,44 @@ def test_already_downloaded(self, dataset: VHR10) -> None: def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - VHR10(split="train") + VHR10(split='train') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): VHR10(str(tmp_path)) def test_mock_missing_module( self, dataset: VHR10, mock_missing_modules: None ) -> None: - if dataset.split == "positive": + if dataset.split == 'positive': with pytest.raises( ImportError, - match="pycocotools is not installed and is required to use this datase", + match='pycocotools is not installed and is required to use this datase', ): VHR10(dataset.root, dataset.split) with pytest.raises( ImportError, - match="scikit-image is not installed and is required to plot masks", + match='scikit-image is not installed and is required to plot masks', ): x = dataset[0] dataset.plot(x) def test_plot(self, dataset: VHR10) -> None: - pytest.importorskip("skimage", minversion="0.18") + pytest.importorskip('skimage', minversion='0.18') x = dataset[1].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - if dataset.split == "positive": + if dataset.split == 'positive': scores = [0.7, 0.3, 0.7] for i in range(3): x = dataset[i] - x["prediction_labels"] = x["labels"] - x["prediction_boxes"] = x["boxes"] - x["prediction_scores"] = torch.Tensor([scores[i]]) - if "masks" in x: - x["prediction_masks"] = x["masks"] - dataset.plot(x, show_feats="masks") + x['prediction_labels'] = x['labels'] + x['prediction_boxes'] = x['boxes'] + x['prediction_scores'] = torch.Tensor([scores[i]]) + if 'masks' in x: + x['prediction_masks'] = x['masks'] + dataset.plot(x, show_feats='masks') plt.close() diff --git a/tests/datasets/test_western_usa_live_fuel_moisture.py 
b/tests/datasets/test_western_usa_live_fuel_moisture.py index 3337965228e..e2c9120ae02 100644 --- a/tests/datasets/test_western_usa_live_fuel_moisture.py +++ b/tests/datasets/test_western_usa_live_fuel_moisture.py @@ -16,10 +16,10 @@ class Collection: def download(self, output_dir: str, **kwargs: str) -> None: tarball_path = os.path.join( - "tests", - "data", - "western_usa_live_fuel_moisture", - "su_sar_moisture_content.tar.gz", + 'tests', + 'data', + 'western_usa_live_fuel_moisture', + 'su_sar_moisture_content.tar.gz', ) shutil.copy(tarball_path, output_dir) @@ -33,41 +33,41 @@ class TestWesternUSALiveFuelMoisture: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> WesternUSALiveFuelMoisture: - radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3") - monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) - md5 = "ecbc9269dd27c4efe7aa887960054351" - monkeypatch.setattr(WesternUSALiveFuelMoisture, "md5", md5) + radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') + monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) + md5 = 'ecbc9269dd27c4efe7aa887960054351' + monkeypatch.setattr(WesternUSALiveFuelMoisture, 'md5', md5) root = str(tmp_path) transforms = nn.Identity() return WesternUSALiveFuelMoisture( - root, transforms=transforms, download=True, api_key="", checksum=True + root, transforms=transforms, download=True, api_key='', checksum=True ) - @pytest.mark.parametrize("index", [0, 1, 2]) + @pytest.mark.parametrize('index', [0, 1, 2]) def test_getitem(self, dataset: WesternUSALiveFuelMoisture, index: int) -> None: x = dataset[index] assert isinstance(x, dict) - assert isinstance(x["input"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['input'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) def test_len(self, dataset: WesternUSALiveFuelMoisture) -> None: assert len(dataset) == 3 def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( - "tests", - "data", - "western_usa_live_fuel_moisture", - "su_sar_moisture_content.tar.gz", + 'tests', + 'data', + 'western_usa_live_fuel_moisture', + 'su_sar_moisture_content.tar.gz', ) root = str(tmp_path) shutil.copy(pathname, root) WesternUSALiveFuelMoisture(root) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): WesternUSALiveFuelMoisture(str(tmp_path)) def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None: - with pytest.raises(AssertionError, match="Invalid input variable name."): - WesternUSALiveFuelMoisture(dataset.root, input_features=["foo"]) + with pytest.raises(AssertionError, match='Invalid input variable name.'): + WesternUSALiveFuelMoisture(dataset.root, input_features=['foo']) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index 28292775a46..7689acf5f78 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -16,25 +16,25 @@ class TestXView2: - @pytest.fixture(params=["train", "test"]) + @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: monkeypatch.setattr( XView2, - "metadata", + 'metadata', { - "train": { - "filename": "train_images_labels_targets.tar.gz", - "md5": "373e61d55c1b294aa76b94dbbd81332b", - "directory": "train", + 'train': { + 'filename': 'train_images_labels_targets.tar.gz', + 'md5': 
'373e61d55c1b294aa76b94dbbd81332b', + 'directory': 'train', }, - "test": { - "filename": "test_images_labels_targets.tar.gz", - "md5": "bc6de81c956a3bada38b5b4e246266a1", - "directory": "test", + 'test': { + 'filename': 'test_images_labels_targets.tar.gz', + 'md5': 'bc6de81c956a3bada38b5b4e246266a1', + 'directory': 'test', }, }, ) - root = os.path.join("tests", "data", "xview2") + root = os.path.join('tests', 'data', 'xview2') split = request.param transforms = nn.Identity() return XView2(root, split, transforms, checksum=True) @@ -42,8 +42,8 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: def test_getitem(self, dataset: XView2) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) def test_len(self, dataset: XView2) -> None: assert len(dataset) == 2 @@ -51,44 +51,44 @@ def test_len(self, dataset: XView2) -> None: def test_extract(self, tmp_path: Path) -> None: shutil.copyfile( os.path.join( - "tests", "data", "xview2", "train_images_labels_targets.tar.gz" + 'tests', 'data', 'xview2', 'train_images_labels_targets.tar.gz' ), - os.path.join(tmp_path, "train_images_labels_targets.tar.gz"), + os.path.join(tmp_path, 'train_images_labels_targets.tar.gz'), ) shutil.copyfile( os.path.join( - "tests", "data", "xview2", "test_images_labels_targets.tar.gz" + 'tests', 'data', 'xview2', 'test_images_labels_targets.tar.gz' ), - os.path.join(tmp_path, "test_images_labels_targets.tar.gz"), + os.path.join(tmp_path, 'test_images_labels_targets.tar.gz'), ) XView2(root=str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: with open( - os.path.join(tmp_path, "train_images_labels_targets.tar.gz"), "w" + os.path.join(tmp_path, 'train_images_labels_targets.tar.gz'), 'w' ) as f: - f.write("bad") + f.write('bad') with open( - os.path.join(tmp_path, "test_images_labels_targets.tar.gz"), "w" + os.path.join(tmp_path, 'test_images_labels_targets.tar.gz'), 'w' ) as f: - f.write("bad") - with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + f.write('bad') + with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): XView2(root=str(tmp_path), checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): - XView2(split="foo") + XView2(split='foo') def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): XView2(str(tmp_path)) def test_plot(self, dataset: XView2) -> None: x = dataset[0].copy() - dataset.plot(x, suptitle="Test") + dataset.plot(x, suptitle='Test') plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction"] = x["mask"][0].clone() + x['prediction'] = x['mask'][0].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index 1dd5336c6c6..330866b36f0 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -16,7 +16,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop -pytest.importorskip("h5py", minversion="3") +pytest.importorskip('h5py', minversion='3') def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -26,15 +26,15 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class 
TestZueriCrop: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop: - monkeypatch.setattr(torchgeo.datasets.zuericrop, "download_url", download_url) - data_dir = os.path.join("tests", "data", "zuericrop") + monkeypatch.setattr(torchgeo.datasets.zuericrop, 'download_url', download_url) + data_dir = os.path.join('tests', 'data', 'zuericrop') urls = [ - os.path.join(data_dir, "ZueriCrop.hdf5"), - os.path.join(data_dir, "labels.csv"), + os.path.join(data_dir, 'ZueriCrop.hdf5'), + os.path.join(data_dir, 'labels.csv'), ] - md5s = ["1635231df67f3d25f4f1e62c98e221a4", "5118398c7a5bbc246f5f6bb35d8d529b"] - monkeypatch.setattr(ZueriCrop, "urls", urls) - monkeypatch.setattr(ZueriCrop, "md5s", md5s) + md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b'] + monkeypatch.setattr(ZueriCrop, 'urls', urls) + monkeypatch.setattr(ZueriCrop, 'md5s', md5s) root = str(tmp_path) transforms = nn.Identity() return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True) @@ -44,33 +44,33 @@ def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: import_orig = builtins.__import__ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == "h5py": + if name == 'h5py': raise ImportError() return import_orig(name, *args, **kwargs) - monkeypatch.setattr(builtins, "__import__", mocked_import) + monkeypatch.setattr(builtins, '__import__', mocked_import) def test_getitem(self, dataset: ZueriCrop) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["mask"], torch.Tensor) - assert isinstance(x["boxes"], torch.Tensor) - assert isinstance(x["label"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) # Image tests - assert x["image"].ndim == 4 + assert x['image'].ndim == 4 # Instance masks tests - assert x["mask"].ndim == 3 - assert x["mask"].shape[-2:] == x["image"].shape[-2:] + assert x['mask'].ndim == 3 + assert x['mask'].shape[-2:] == x['image'].shape[-2:] # Bboxes tests - assert x["boxes"].ndim == 2 - assert x["boxes"].shape[1] == 4 + assert x['boxes'].ndim == 2 + assert x['boxes'].shape[1] == 4 # Labels tests - assert x["label"].ndim == 1 + assert x['label'].ndim == 1 def test_len(self, dataset: ZueriCrop) -> None: assert len(dataset) == 2 @@ -79,7 +79,7 @@ def test_already_downloaded(self, dataset: ZueriCrop) -> None: ZueriCrop(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ZueriCrop(str(tmp_path)) def test_mock_missing_module( @@ -87,26 +87,26 @@ def test_mock_missing_module( ) -> None: with pytest.raises( ImportError, - match="h5py is not installed and is required to use this dataset", + match='h5py is not installed and is required to use this dataset', ): ZueriCrop(dataset.root, download=True, checksum=True) def test_invalid_bands(self) -> None: with pytest.raises(ValueError): - ZueriCrop(bands=("OK", "BK")) + ZueriCrop(bands=('OK', 'BK')) def test_plot(self, dataset: ZueriCrop) -> None: - dataset.plot(dataset[0], suptitle="Test") + dataset.plot(dataset[0], suptitle='Test') plt.close() sample = dataset[0] - sample["prediction"] = sample["mask"].clone() - dataset.plot(sample, suptitle="prediction") + sample['prediction'] = 
sample['mask'].clone() + dataset.plot(sample, suptitle='prediction') plt.close() def test_plot_rgb(self, dataset: ZueriCrop) -> None: - dataset = ZueriCrop(root=dataset.root, bands=("B02",)) + dataset = ZueriCrop(root=dataset.root, bands=('B02',)) with pytest.raises( - RGBBandsMissingError, match="Dataset does not contain some of the RGB bands" + RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): - dataset.plot(dataset[0], time_step=0, suptitle="Single Band") + dataset.plot(dataset[0], time_step=0, suptitle='Single Band') diff --git a/tests/models/test_api.py b/tests/models/test_api.py index 18d1f1a8028..c2afd42c0f5 100644 --- a/tests/models/test_api.py +++ b/tests/models/test_api.py @@ -27,13 +27,13 @@ enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights, Swin_V2_B_Weights] -@pytest.mark.parametrize("builder", builders) +@pytest.mark.parametrize('builder', builders) def test_get_model(builder: Callable[..., nn.Module]) -> None: model = get_model(builder.__name__) assert isinstance(model, nn.Module) -@pytest.mark.parametrize("builder", builders) +@pytest.mark.parametrize('builder', builders) def test_get_model_weights(builder: Callable[..., nn.Module]) -> None: weights = get_model_weights(builder) assert isinstance(weights, enum.EnumMeta) @@ -41,7 +41,7 @@ def test_get_model_weights(builder: Callable[..., nn.Module]) -> None: assert isinstance(weights, enum.EnumMeta) -@pytest.mark.parametrize("enum", enums) +@pytest.mark.parametrize('enum', enums) def test_get_weight(enum: WeightsEnum) -> None: for weight in enum: assert weight == get_weight(str(weight)) diff --git a/tests/models/test_changestar.py b/tests/models/test_changestar.py index d9552de7210..b7bf06c75d8 100644 --- a/tests/models/test_changestar.py +++ b/tests/models/test_changestar.py @@ -7,7 +7,7 @@ from torchgeo.models import ChangeMixin, ChangeStar, ChangeStarFarSeg -BACKBONE = ["resnet18", "resnet34", "resnet50", "resnet101"] +BACKBONE = ['resnet18', 'resnet34', 'resnet50', 'resnet101'] IN_CHANNELS = [64, 128] INNNR_CHANNELS = [16, 32, 64] NC = [1, 2, 4] @@ -18,46 +18,46 @@ class TestChangeStar: @torch.no_grad() def test_changestar_farseg_classes(self) -> None: model = ChangeStarFarSeg( - classes=4, backbone="resnet50", backbone_pretrained=False + classes=4, backbone='resnet50', backbone_pretrained=False ) x = torch.randn(2, 2, 3, 128, 128) y = model(x) - assert y["bi_seg_logit"].shape[2] == 4 + assert y['bi_seg_logit'].shape[2] == 4 @torch.no_grad() def test_changestar_farseg_output_size(self) -> None: model = ChangeStarFarSeg( - classes=4, backbone="resnet50", backbone_pretrained=False + classes=4, backbone='resnet50', backbone_pretrained=False ) model.eval() x = torch.randn(2, 2, 3, 128, 128) y = model(x) - assert y["bi_seg_logit"].shape[3] == 128 and y["bi_seg_logit"].shape[4] == 128 - assert y["change_prob"].shape[2] == 128 and y["change_prob"].shape[3] == 128 + assert y['bi_seg_logit'].shape[3] == 128 and y['bi_seg_logit'].shape[4] == 128 + assert y['change_prob'].shape[2] == 128 and y['change_prob'].shape[3] == 128 model.train() y = model(x) - assert y["bi_seg_logit"].shape[3] == 128 and y["bi_seg_logit"].shape[4] == 128 - assert y["bi_change_logit"].shape[3] == 128 - assert y["bi_change_logit"].shape[4] == 128 + assert y['bi_seg_logit'].shape[3] == 128 and y['bi_seg_logit'].shape[4] == 128 + assert y['bi_change_logit'].shape[3] == 128 + assert y['bi_change_logit'].shape[4] == 128 - @pytest.mark.parametrize("backbone", BACKBONE) + @pytest.mark.parametrize('backbone', BACKBONE) def 
test_valid_changestar_farseg_backbone(self, backbone: str) -> None: ChangeStarFarSeg(classes=4, backbone=backbone, backbone_pretrained=False) def test_invalid_changestar_farseg_backbone(self) -> None: - match = "unknown backbone: anynet." + match = 'unknown backbone: anynet.' with pytest.raises(ValueError, match=match): - ChangeStarFarSeg(classes=4, backbone="anynet", backbone_pretrained=False) + ChangeStarFarSeg(classes=4, backbone='anynet', backbone_pretrained=False) @torch.no_grad() - @pytest.mark.parametrize("inc", IN_CHANNELS) - @pytest.mark.parametrize("innerc", INNNR_CHANNELS) - @pytest.mark.parametrize("nc", NC) - @pytest.mark.parametrize("sf", SF) + @pytest.mark.parametrize('inc', IN_CHANNELS) + @pytest.mark.parametrize('innerc', INNNR_CHANNELS) + @pytest.mark.parametrize('nc', NC) + @pytest.mark.parametrize('sf', SF) def test_changemixin_output_size( self, inc: int, innerc: int, nc: int, sf: int ) -> None: @@ -93,8 +93,8 @@ def test_changestar(self) -> None: m.eval() y = m(torch.rand(3, 2, 3, 64, 64)) - assert y["bi_seg_logit"].shape == (3, 2, 2, 64, 64) - assert y["change_prob"].shape == (3, 1, 64, 64) + assert y['bi_seg_logit'].shape == (3, 2, 2, 64, 64) + assert y['change_prob'].shape == (3, 1, 64, 64) @torch.no_grad() def test_changestar_invalid_inference_mode(self) -> None: @@ -110,7 +110,7 @@ def test_changestar_invalid_inference_mode(self) -> None: nn.modules.UpsamplingBilinear2d(scale_factor=2.0), ) - match = "Unknown inference_mode: random" + match = 'Unknown inference_mode: random' with pytest.raises(ValueError, match=match): ChangeStar( dense_feature_extractor, @@ -118,11 +118,11 @@ def test_changestar_invalid_inference_mode(self) -> None: ChangeMixin( in_channels=32 * 2, inner_channels=16, num_convs=4, scale_factor=2.0 ), - inference_mode="random", + inference_mode='random', ) @torch.no_grad() - @pytest.mark.parametrize("inference_mode", ["t1t2", "t2t1", "mean"]) + @pytest.mark.parametrize('inference_mode', ['t1t2', 't2t1', 'mean']) def test_changestar_inference_output_size(self, inference_mode: str) -> None: dense_feature_extractor = nn.modules.Sequential( nn.modules.Conv2d(3, 32, 3, 1, 1), @@ -149,5 +149,5 @@ def test_changestar_inference_output_size(self, inference_mode: str) -> None: x = torch.randn(2, 2, 3, 128, 128) y = m(x) - assert y["bi_seg_logit"].shape == (2, 2, CLASSES, 128, 128) - assert y["change_prob"].shape == (2, 1, 128, 128) + assert y['bi_seg_logit'].shape == (2, 2, CLASSES, 128, 128) + assert y['change_prob'].shape == (2, 1, 128, 128) diff --git a/tests/models/test_dofa.py b/tests/models/test_dofa.py index bd50b8eb739..34f1701d7b1 100644 --- a/tests/models/test_dofa.py +++ b/tests/models/test_dofa.py @@ -29,7 +29,7 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: class TestDOFA: @pytest.mark.parametrize( - "wavelengths", + 'wavelengths', [ # Gaofen [0.443, 0.565, 0.763, 0.765, 0.910], @@ -88,14 +88,14 @@ def weights(self, request: SubRequest) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = dofa_base_patch16_224() torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + 
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_dofa(self) -> None: @@ -107,7 +107,7 @@ def test_dofa_weights(self, mocked_weights: WeightsEnum) -> None: def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = 4 sample = { - "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) } mocked_weights.transforms(sample) @@ -125,14 +125,14 @@ def weights(self, request: SubRequest) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = dofa_large_patch16_224() torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_dofa(self) -> None: @@ -144,7 +144,7 @@ def test_dofa_weights(self, mocked_weights: WeightsEnum) -> None: def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = 4 sample = { - "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) } mocked_weights.transforms(sample) diff --git a/tests/models/test_farseg.py b/tests/models/test_farseg.py index c87a172a378..0ea1cf864fd 100644 --- a/tests/models/test_farseg.py +++ b/tests/models/test_farseg.py @@ -11,12 +11,12 @@ class TestFarSeg: @torch.no_grad() @pytest.mark.parametrize( - "backbone,pretrained", + 'backbone,pretrained', [ - ("resnet18", True), - ("resnet34", False), - ("resnet50", True), - ("resnet101", False), + ('resnet18', True), + ('resnet34', False), + ('resnet50', True), + ('resnet101', False), ], ) def test_valid_backbone(self, backbone: str, pretrained: bool) -> None: @@ -27,6 +27,6 @@ def test_valid_backbone(self, backbone: str, pretrained: bool) -> None: assert y.shape == (2, 4, 128, 128) def test_invalid_backbone(self) -> None: - match = "unknown backbone: anynet." + match = 'unknown backbone: anynet.' 
with pytest.raises(ValueError, match=match): - FarSeg(classes=4, backbone="anynet", backbone_pretrained=False) + FarSeg(classes=4, backbone='anynet', backbone_pretrained=False) diff --git a/tests/models/test_fcn.py b/tests/models/test_fcn.py index cda8c90de6e..b1c13b52559 100644 --- a/tests/models/test_fcn.py +++ b/tests/models/test_fcn.py @@ -14,7 +14,7 @@ def test_in_channels(self) -> None: model(x) model = FCN(in_channels=3, classes=4, num_filters=10) - match = "to have 3 channels, but got 5 channels instead" + match = 'to have 3 channels, but got 5 channels instead' with pytest.raises(RuntimeError, match=match): model(x) diff --git a/tests/models/test_fcsiam.py b/tests/models/test_fcsiam.py index a66964e147c..2724fae73d1 100644 --- a/tests/models/test_fcsiam.py +++ b/tests/models/test_fcsiam.py @@ -13,8 +13,8 @@ class TestFCSiamConc: @torch.no_grad() - @pytest.mark.parametrize("b", BATCH_SIZE) - @pytest.mark.parametrize("c", CHANNELS) + @pytest.mark.parametrize('b', BATCH_SIZE) + @pytest.mark.parametrize('c', CHANNELS) def test_in_channels(self, b: int, c: int) -> None: classes = 2 t, h, w = 2, 64, 64 @@ -24,8 +24,8 @@ def test_in_channels(self, b: int, c: int) -> None: assert y.shape == (b, classes, h, w) @torch.no_grad() - @pytest.mark.parametrize("b", BATCH_SIZE) - @pytest.mark.parametrize("classes", CLASSES) + @pytest.mark.parametrize('b', BATCH_SIZE) + @pytest.mark.parametrize('classes', CLASSES) def test_classes(self, b: int, classes: int) -> None: t, c, h, w = 2, 3, 64, 64 model = FCSiamConc(in_channels=3, classes=classes, encoder_weights=None) @@ -36,8 +36,8 @@ def test_classes(self, b: int, classes: int) -> None: class TestFCSiamDiff: @torch.no_grad() - @pytest.mark.parametrize("b", BATCH_SIZE) - @pytest.mark.parametrize("c", CHANNELS) + @pytest.mark.parametrize('b', BATCH_SIZE) + @pytest.mark.parametrize('c', CHANNELS) def test_in_channels(self, b: int, c: int) -> None: classes = 2 t, h, w = 2, 64, 64 @@ -47,8 +47,8 @@ def test_in_channels(self, b: int, c: int) -> None: assert y.shape == (b, classes, h, w) @torch.no_grad() - @pytest.mark.parametrize("b", BATCH_SIZE) - @pytest.mark.parametrize("classes", CLASSES) + @pytest.mark.parametrize('b', BATCH_SIZE) + @pytest.mark.parametrize('classes', CLASSES) def test_classes(self, b: int, classes: int) -> None: t, c, h, w = 2, 3, 64, 64 model = FCSiamDiff(in_channels=3, classes=classes, encoder_weights=None) diff --git a/tests/models/test_rcf.py b/tests/models/test_rcf.py index f6d8091bc0d..b3a0f4bda83 100644 --- a/tests/models/test_rcf.py +++ b/tests/models/test_rcf.py @@ -12,17 +12,17 @@ class TestRCF: def test_in_channels(self) -> None: - model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian") + model = RCF(in_channels=5, features=4, kernel_size=3, mode='gaussian') x = torch.randn(2, 5, 64, 64) model(x) - model = RCF(in_channels=3, features=4, kernel_size=3, mode="gaussian") - match = "to have 3 channels, but got 5 channels instead" + model = RCF(in_channels=3, features=4, kernel_size=3, mode='gaussian') + match = 'to have 3 channels, but got 5 channels instead' with pytest.raises(RuntimeError, match=match): model(x) def test_num_features(self) -> None: - model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian") + model = RCF(in_channels=5, features=4, kernel_size=3, mode='gaussian') x = torch.randn(2, 5, 64, 64) y = model(x) assert y.shape[1] == 4 @@ -32,27 +32,27 @@ def test_num_features(self) -> None: assert y.shape[0] == 4 def test_untrainable(self) -> None: - model = RCF(in_channels=5, features=4, 
kernel_size=3, mode="gaussian") + model = RCF(in_channels=5, features=4, kernel_size=3, mode='gaussian') assert len(list(model.parameters())) == 0 def test_biases(self) -> None: - model = RCF(features=24, bias=10, mode="gaussian") + model = RCF(features=24, bias=10, mode='gaussian') assert torch.all(model.biases == 10) def test_seed(self) -> None: - weights1 = RCF(seed=1, mode="gaussian").weights - weights2 = RCF(seed=1, mode="gaussian").weights + weights1 = RCF(seed=1, mode='gaussian').weights + weights2 = RCF(seed=1, mode='gaussian').weights assert torch.allclose(weights1, weights2) def test_empirical(self) -> None: - root = os.path.join("tests", "data", "eurosat") - ds = EuroSAT(root=root, bands=EuroSAT.rgb_bands, split="train") + root = os.path.join('tests', 'data', 'eurosat') + ds = EuroSAT(root=root, bands=EuroSAT.rgb_bands, split='train') model = RCF( - in_channels=3, features=4, kernel_size=3, mode="empirical", dataset=ds + in_channels=3, features=4, kernel_size=3, mode='empirical', dataset=ds ) model(torch.randn(2, 3, 8, 8)) def test_empirical_no_dataset(self) -> None: match = "dataset must be provided when mode is 'empirical'" with pytest.raises(ValueError, match=match): - RCF(mode="empirical", dataset=None) + RCF(mode='empirical', dataset=None) diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py index 90318bb4b0c..ea5397e6099 100644 --- a/tests/models/test_resnet.py +++ b/tests/models/test_resnet.py @@ -29,14 +29,14 @@ def weights(self, request: SubRequest) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" - model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + path = tmp_path / f'{weights}.pth' + model = timm.create_model('resnet18', in_chans=weights.meta['in_chans']) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_resnet(self) -> None: @@ -46,9 +46,9 @@ def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: resnet18(weights=mocked_weights) def test_transforms(self, mocked_weights: WeightsEnum) -> None: - c = mocked_weights.meta["in_chans"] + c = mocked_weights.meta['in_chans'] sample = { - "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) } mocked_weights.transforms(sample) @@ -66,14 +66,14 @@ def weights(self, request: SubRequest) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" - model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"]) + path = tmp_path / f'{weights}.pth' + model = timm.create_model('resnet50', in_chans=weights.meta['in_chans']) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', 
str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_resnet(self) -> None: @@ -83,9 +83,9 @@ def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: resnet50(weights=mocked_weights) def test_transforms(self, mocked_weights: WeightsEnum) -> None: - c = mocked_weights.meta["in_chans"] + c = mocked_weights.meta['in_chans'] sample = { - "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) } mocked_weights.transforms(sample) diff --git a/tests/models/test_swin.py b/tests/models/test_swin.py index e781b16b146..489b3642ce7 100644 --- a/tests/models/test_swin.py +++ b/tests/models/test_swin.py @@ -28,14 +28,14 @@ def weights(self, request: SubRequest) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = torchvision.models.swin_v2_b() torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_swin_v2_b(self) -> None: @@ -45,9 +45,9 @@ def test_swin_v2_b_weights(self, mocked_weights: WeightsEnum) -> None: swin_v2_b(weights=mocked_weights) def test_transforms(self, mocked_weights: WeightsEnum) -> None: - c = mocked_weights.meta["in_chans"] + c = mocked_weights.meta['in_chans'] sample = { - "image": torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) } mocked_weights.transforms(sample) diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py index ea4b509ca95..b69e2398996 100644 --- a/tests/models/test_vit.py +++ b/tests/models/test_vit.py @@ -29,16 +29,16 @@ def weights(self, request: SubRequest) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_vit(self) -> None: @@ -48,9 +48,9 @@ def test_vit_weights(self, mocked_weights: WeightsEnum) -> None: vit_small_patch16_224(weights=mocked_weights) def test_transforms(self, mocked_weights: WeightsEnum) -> None: - c = mocked_weights.meta["in_chans"] + c = mocked_weights.meta['in_chans'] sample = { - "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) } mocked_weights.transforms(sample) diff --git 
a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 05b0cf0b3a8..59c8aaa00be 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -33,17 +33,17 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: self.res = res def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: - return {"index": query} + return {'index': query} class TestBatchGeoSampler: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) return ds - @pytest.fixture(scope="function") + @pytest.fixture(scope='function') def sampler(self) -> CustomBatchGeoSampler: return CustomBatchGeoSampler() @@ -55,7 +55,7 @@ def test_len(self, sampler: CustomBatchGeoSampler) -> None: assert len(sampler) == 2 @pytest.mark.slow - @pytest.mark.parametrize("num_workers", [0, 1, 2]) + @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( self, dataset: CustomGeoDataset, @@ -77,7 +77,7 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: class TestRandomBatchGeoSampler: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) @@ -85,7 +85,7 @@ def dataset(self) -> CustomGeoDataset: return ds @pytest.fixture( - scope="function", + scope='function', params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]), ) def sampler( @@ -145,7 +145,7 @@ def test_weighted_sampling(self) -> None: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) @pytest.mark.slow - @pytest.mark.parametrize("num_workers", [0, 1, 2]) + @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( self, dataset: CustomGeoDataset, diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index f43db1ce195..1416368098a 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -40,17 +40,17 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: self.res = res def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: - return {"index": query} + return {'index': query} class TestGeoSampler: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) return ds - @pytest.fixture(scope="function") + @pytest.fixture(scope='function') def sampler(self) -> CustomGeoSampler: return CustomGeoSampler() @@ -65,7 +65,7 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: GeoSampler(dataset) # type: ignore[abstract] @pytest.mark.slow - @pytest.mark.parametrize("num_workers", [0, 1, 2]) + @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( self, dataset: CustomGeoDataset, sampler: CustomGeoSampler, num_workers: int ) -> None: @@ -77,7 +77,7 @@ def test_dataloader( class TestRandomGeoSampler: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) @@ -85,7 +85,7 @@ def dataset(self) -> CustomGeoDataset: return ds @pytest.fixture( - scope="function", + scope='function', params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]), ) def sampler( @@ -140,7 +140,7 @@ def test_weighted_sampling(self) -> None: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) 
@pytest.mark.slow - @pytest.mark.parametrize("num_workers", [0, 1, 2]) + @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( self, dataset: CustomGeoDataset, sampler: RandomGeoSampler, num_workers: int ) -> None: @@ -152,7 +152,7 @@ def test_dataloader( class TestGridGeoSampler: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) @@ -160,7 +160,7 @@ def dataset(self) -> CustomGeoDataset: return ds @pytest.fixture( - scope="function", + scope='function', params=product( [ (8, 1), @@ -244,7 +244,7 @@ def test_float_multiple(self) -> None: assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) @pytest.mark.slow - @pytest.mark.parametrize("num_workers", [0, 1, 2]) + @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( self, dataset: CustomGeoDataset, sampler: GridGeoSampler, num_workers: int ) -> None: @@ -256,14 +256,14 @@ def test_dataloader( class TestPreChippedGeoSampler: - @pytest.fixture(scope="class") + @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 20, 0, 20, 0, 20)) ds.index.insert(1, (0, 30, 0, 30, 0, 30)) return ds - @pytest.fixture(scope="function") + @pytest.fixture(scope='function') def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler: return PreChippedGeoSampler(dataset, shuffle=True) @@ -289,7 +289,7 @@ def test_point_data(self) -> None: continue @pytest.mark.slow - @pytest.mark.parametrize("num_workers", [0, 1, 2]) + @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( self, dataset: CustomGeoDataset, sampler: PreChippedGeoSampler, num_workers: int ) -> None: diff --git a/tests/samplers/test_utils.py b/tests/samplers/test_utils.py index 0d15822bec5..90e888b55cc 100644 --- a/tests/samplers/test_utils.py +++ b/tests/samplers/test_utils.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize( - "size,stride,expected", + 'size,stride,expected', [ # size == bounds (10, 1, 1), diff --git a/tests/test_main.py b/tests/test_main.py index ab811e1cd6a..eee0ea95cc2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -6,4 +6,4 @@ def test_help() -> None: - subprocess.run([sys.executable, "-m", "torchgeo", "--help"], check=True) + subprocess.run([sys.executable, '-m', 'torchgeo', '--help'], check=True) diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index 614babe7102..a3ce098ae7d 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -14,38 +14,38 @@ @pytest.fixture( - scope="package", params=[True, pytest.param(False, marks=pytest.mark.slow)] + scope='package', params=[True, pytest.param(False, marks=pytest.mark.slow)] ) def fast_dev_run(request: SubRequest) -> bool: flag: bool = request.param return flag -@pytest.fixture(scope="package") +@pytest.fixture(scope='package') def model() -> Module: model: Module = torchvision.models.resnet18(weights=None) return model -@pytest.fixture(scope="package") +@pytest.fixture(scope='package') def state_dict(model: Module) -> dict[str, Tensor]: return model.state_dict() -@pytest.fixture(params=["model", "backbone"]) +@pytest.fixture(params=['model', 'backbone']) def checkpoint( state_dict: dict[str, Tensor], request: SubRequest, tmp_path: Path ) -> str: - if request.param == "model": - state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()}) + if request.param == 'model': + state_dict = OrderedDict({'model.' 
+ k: v for k, v in state_dict.items()}) else: state_dict = OrderedDict( - {"model.backbone.model." + k: v for k, v in state_dict.items()} + {'model.backbone.model.' + k: v for k, v in state_dict.items()} ) checkpoint = { - "hyper_parameters": {request.param: "resnet18"}, - "state_dict": state_dict, + 'hyper_parameters': {request.param: 'resnet18'}, + 'state_dict': state_dict, } - path = os.path.join(str(tmp_path), f"model_{request.param}.ckpt") + path = os.path.join(str(tmp_path), f'model_{request.param}.ckpt') torch.save(checkpoint, path) return path diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 235a1681a70..64143759a3c 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -45,42 +45,42 @@ def test_custom_augment_fn(self) -> None: class TestBYOLTask: @pytest.mark.parametrize( - "name", + 'name', [ - "chesapeake_cvpr_prior_byol", - "seco_byol_1", - "seco_byol_2", - "ssl4eo_l_byol_1", - "ssl4eo_l_byol_2", - "ssl4eo_s12_byol_1", - "ssl4eo_s12_byol_2", + 'chesapeake_cvpr_prior_byol', + 'seco_byol_1', + 'seco_byol_2', + 'ssl4eo_l_byol_1', + 'ssl4eo_l_byol_2', + 'ssl4eo_s12_byol_1', + 'ssl4eo_s12_byol_2', ], ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - if name.startswith("seco"): - monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2) + if name.startswith('seco'): + monkeypatch.setattr(SeasonalContrastS2, '__len__', lambda self: 2) - if name.startswith("ssl4eo_s12"): - monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2) + if name.startswith('ssl4eo_s12'): + monkeypatch.setattr(SSL4EOS12, '__len__', lambda self: 2) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) @pytest.fixture def weights(self) -> WeightsEnum: @@ -90,48 +90,48 @@ def weights(self) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: with pytest.warns(UserWarning): - BYOLTask(model="resnet18", in_channels=13, weights=checkpoint) + BYOLTask(model='resnet18', in_channels=13, weights=checkpoint) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: BYOLTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=mocked_weights, - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) def test_weight_str(self, 
mocked_weights: WeightsEnum) -> None: BYOLTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=str(mocked_weights), - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: BYOLTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=weights, - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: BYOLTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=str(weights), - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 02183978995..2dbde66cef1 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -44,12 +44,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PredictClassificationDataModule(EuroSATDataModule): def setup(self, stage: str) -> None: - self.predict_dataset = EuroSAT(split="test", **self.kwargs) + self.predict_dataset = EuroSAT(split='test', **self.kwargs) class PredictMultiLabelClassificationDataModule(BigEarthNetDataModule): def setup(self, stage: str) -> None: - self.predict_dataset = BigEarthNet(split="test", **self.kwargs) + self.predict_dataset = BigEarthNet(split='test', **self.kwargs) def create_model(*args: Any, **kwargs: Any) -> Module: @@ -71,49 +71,49 @@ def plot_missing_bands(*args: Any, **kwargs: Any) -> None: class TestClassificationTask: @pytest.mark.parametrize( - "name", + 'name', [ - "eurosat", - "eurosat100", - "fire_risk", - "resisc45", - "so2sat_all", - "so2sat_s1", - "so2sat_s2", - "so2sat_rgb", - "ucmerced", + 'eurosat', + 'eurosat100', + 'fire_risk', + 'resisc45', + 'so2sat_all', + 'so2sat_s1', + 'so2sat_s2', + 'so2sat_rgb', + 'ucmerced', ], ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name.startswith("so2sat"): - pytest.importorskip("h5py", minversion="3") + if name.startswith('so2sat'): + pytest.importorskip('h5py', minversion='3') - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - monkeypatch.setattr(timm, "create_model", create_model) + monkeypatch.setattr(timm, 'create_model', create_model) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) try: - main(["test"] + args) + main(['test'] + args) except MisconfigurationException: pass try: - main(["predict"] + args) + main(['predict'] + args) except MisconfigurationException: pass @@ -125,73 +125,73 @@ def weights(self) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + 
monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: with pytest.warns(UserWarning): ClassificationTask( - model="resnet18", weights=checkpoint, in_channels=13, num_classes=10 + model='resnet18', weights=checkpoint, in_channels=13, num_classes=10 ) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: with pytest.warns(UserWarning): ClassificationTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=mocked_weights, - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], num_classes=10, ) def test_weight_str(self, mocked_weights: WeightsEnum) -> None: with pytest.warns(UserWarning): ClassificationTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=str(mocked_weights), - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], num_classes=10, ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: ClassificationTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=weights, - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], num_classes=10, ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: ClassificationTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=str(weights), - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], num_classes=10, ) def test_invalid_loss(self) -> None: match = "Loss type 'invalid_loss' is not valid." 
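# Note on the regex above: error-message strings that themselves contain single
# quotes (e.g. "Loss type 'invalid_loss' is not valid.") are left double-quoted
# throughout this patch; converting them would force an escaped form such as the
# hypothetical
#     match = 'Loss type \'invalid_loss\' is not valid.'
# which the single-quote preference deliberately avoids.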
with pytest.raises(ValueError, match=match): - ClassificationTask(model="resnet18", loss="invalid_loss") + ClassificationTask(model='resnet18', loss='invalid_loss') def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(EuroSATDataModule, "plot", plot) + monkeypatch.setattr(EuroSATDataModule, 'plot', plot) datamodule = EuroSATDataModule( - root="tests/data/eurosat", batch_size=1, num_workers=0 + root='tests/data/eurosat', batch_size=1, num_workers=0 ) - model = ClassificationTask(model="resnet18", in_channels=13, num_classes=10) + model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -199,13 +199,13 @@ def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> N trainer.validate(model=model, datamodule=datamodule) def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(EuroSATDataModule, "plot", plot_missing_bands) + monkeypatch.setattr(EuroSATDataModule, 'plot', plot_missing_bands) datamodule = EuroSATDataModule( - root="tests/data/eurosat", batch_size=1, num_workers=0 + root='tests/data/eurosat', batch_size=1, num_workers=0 ) - model = ClassificationTask(model="resnet18", in_channels=13, num_classes=10) + model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -214,11 +214,11 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictClassificationDataModule( - root="tests/data/eurosat", batch_size=1, num_workers=0 + root='tests/data/eurosat', batch_size=1, num_workers=0 ) - model = ClassificationTask(model="resnet18", in_channels=13, num_classes=10) + model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -226,7 +226,7 @@ def test_predict(self, fast_dev_run: bool) -> None: trainer.predict(model=model, datamodule=datamodule) @pytest.mark.parametrize( - "model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"] + 'model_name', ['resnet18', 'efficientnetv2_s', 'vit_base_patch16_384'] ) def test_freeze_backbone(self, model_name: str) -> None: model = ClassificationTask(model=model_name, freeze_backbone=True) @@ -238,53 +238,53 @@ def test_freeze_backbone(self, model_name: str) -> None: class TestMultiLabelClassificationTask: @pytest.mark.parametrize( - "name", ["bigearthnet_all", "bigearthnet_s1", "bigearthnet_s2"] + 'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2'] ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - monkeypatch.setattr(timm, "create_model", create_model) + monkeypatch.setattr(timm, 'create_model', create_model) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - 
main(["fit"] + args) + main(['fit'] + args) try: - main(["test"] + args) + main(['test'] + args) except MisconfigurationException: pass try: - main(["predict"] + args) + main(['predict'] + args) except MisconfigurationException: pass def test_invalid_loss(self) -> None: match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): - MultiLabelClassificationTask(model="resnet18", loss="invalid_loss") + MultiLabelClassificationTask(model='resnet18', loss='invalid_loss') def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(BigEarthNetDataModule, "plot", plot) + monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot) datamodule = BigEarthNetDataModule( - root="tests/data/bigearthnet", batch_size=1, num_workers=0 + root='tests/data/bigearthnet', batch_size=1, num_workers=0 ) model = MultiLabelClassificationTask( - model="resnet18", in_channels=14, num_classes=19, loss="bce" + model='resnet18', in_channels=14, num_classes=19, loss='bce' ) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -292,15 +292,15 @@ def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> N trainer.validate(model=model, datamodule=datamodule) def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(BigEarthNetDataModule, "plot", plot_missing_bands) + monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands) datamodule = BigEarthNetDataModule( - root="tests/data/bigearthnet", batch_size=1, num_workers=0 + root='tests/data/bigearthnet', batch_size=1, num_workers=0 ) model = MultiLabelClassificationTask( - model="resnet18", in_channels=14, num_classes=19, loss="bce" + model='resnet18', in_channels=14, num_classes=19, loss='bce' ) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -309,13 +309,13 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictMultiLabelClassificationDataModule( - root="tests/data/bigearthnet", batch_size=1, num_workers=0 + root='tests/data/bigearthnet', batch_size=1, num_workers=0 ) model = MultiLabelClassificationTask( - model="resnet18", in_channels=14, num_classes=19, loss="bce" + model='resnet18', in_channels=14, num_classes=19, loss='bce' ) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index e4151ac0d29..035bdacc260 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -18,7 +18,7 @@ from torchgeo.trainers import ObjectDetectionTask # MAP metric requires pycocotools to be installed -pytest.importorskip("pycocotools") +pytest.importorskip('pycocotools') class PredictObjectDetectionDataModule(NASAMarineDebrisDataModule): @@ -41,10 +41,10 @@ def forward(self, images: Any, targets: Any = None) -> Any: assert batch_size == len(targets) # use the Linear layer to generate a tensor that has a gradient return { - "loss_classifier": self.fc(torch.rand(1)), - "loss_box_reg": self.fc(torch.rand(1)), - "loss_objectness": self.fc(torch.rand(1)), - "loss_rpn_box_reg": self.fc(torch.rand(1)), + 'loss_classifier': self.fc(torch.rand(1)), + 'loss_box_reg': self.fc(torch.rand(1)), + 'loss_objectness': 
self.fc(torch.rand(1)), + 'loss_rpn_box_reg': self.fc(torch.rand(1)), } else: # eval mode output = [] @@ -54,9 +54,9 @@ def forward(self, images: Any, targets: Any = None) -> Any: boxes[:, 2:] += 1 output.append( { - "boxes": boxes, - "labels": torch.randint(0, 2, (10,)), - "scores": torch.rand(10), + 'boxes': boxes, + 'labels': torch.randint(0, 2, (10,)), + 'scores': torch.rand(10), } ) return output @@ -67,67 +67,67 @@ def plot(*args: Any, **kwargs: Any) -> None: class TestObjectDetectionTask: - @pytest.mark.parametrize("name", ["nasa_marine_debris", "vhr10"]) - @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"]) + @pytest.mark.parametrize('name', ['nasa_marine_debris', 'vhr10']) + @pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet']) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool ) -> None: - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') monkeypatch.setattr( - torchvision.models.detection, "FasterRCNN", ObjectDetectionTestModel + torchvision.models.detection, 'FasterRCNN', ObjectDetectionTestModel ) monkeypatch.setattr( - torchvision.models.detection, "FCOS", ObjectDetectionTestModel + torchvision.models.detection, 'FCOS', ObjectDetectionTestModel ) monkeypatch.setattr( - torchvision.models.detection, "RetinaNet", ObjectDetectionTestModel + torchvision.models.detection, 'RetinaNet', ObjectDetectionTestModel ) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) try: - main(["test"] + args) + main(['test'] + args) except MisconfigurationException: pass try: - main(["predict"] + args) + main(['predict'] + args) except MisconfigurationException: pass def test_invalid_model(self) -> None: match = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=match): - ObjectDetectionTask(model="invalid_model") + ObjectDetectionTask(model='invalid_model') def test_invalid_backbone(self) -> None: match = "Backbone type 'invalid_backbone' is not valid." 
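# ObjectDetectionTestModel above follows the torchvision detection convention: when
# targets are passed it returns a dict of scalar losses (each routed through the
# small Linear layer so a gradient exists), and in eval mode it returns one dict of
# 'boxes', 'labels' and 'scores' tensors per image, which is enough for the trainer
# loop and the pycocotools-backed MAP metric to run without a real detector.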
with pytest.raises(ValueError, match=match): - ObjectDetectionTask(backbone="invalid_backbone") + ObjectDetectionTask(backbone='invalid_backbone') def test_pretrained_backbone(self) -> None: - ObjectDetectionTask(backbone="resnet18", weights=True) + ObjectDetectionTask(backbone='resnet18', weights=True) def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot) + monkeypatch.setattr(NASAMarineDebrisDataModule, 'plot', plot) datamodule = NASAMarineDebrisDataModule( - root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 + root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0 ) - model = ObjectDetectionTask(backbone="resnet18", num_classes=2) + model = ObjectDetectionTask(backbone='resnet18', num_classes=2) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -135,13 +135,13 @@ def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> N trainer.validate(model=model, datamodule=datamodule) def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot_missing_bands) + monkeypatch.setattr(NASAMarineDebrisDataModule, 'plot', plot_missing_bands) datamodule = NASAMarineDebrisDataModule( - root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 + root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0 ) - model = ObjectDetectionTask(backbone="resnet18", num_classes=2) + model = ObjectDetectionTask(backbone='resnet18', num_classes=2) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -150,20 +150,20 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictObjectDetectionDataModule( - root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 + root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0 ) - model = ObjectDetectionTask(backbone="resnet18", num_classes=2) + model = ObjectDetectionTask(backbone='resnet18', num_classes=2) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, ) trainer.predict(model=model, datamodule=datamodule) - @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"]) + @pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet']) def test_freeze_backbone(self, model_name: str) -> None: model = ObjectDetectionTask( - model=model_name, backbone="resnet18", freeze_backbone=True + model=model_name, backbone='resnet18', freeze_backbone=True ) assert not all([param.requires_grad for param in model.model.parameters()]) diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py index a4d19fc98f2..ba3b5641d6b 100644 --- a/tests/trainers/test_moco.py +++ b/tests/trainers/test_moco.py @@ -32,55 +32,55 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: class TestMoCoTask: @pytest.mark.parametrize( - "name", + 'name', [ - "chesapeake_cvpr_prior_moco", - "seco_moco_1", - "seco_moco_2", - "ssl4eo_l_moco_1", - "ssl4eo_l_moco_2", - "ssl4eo_s12_moco_1", - "ssl4eo_s12_moco_2", + 'chesapeake_cvpr_prior_moco', + 'seco_moco_1', + 'seco_moco_2', + 'ssl4eo_l_moco_1', + 'ssl4eo_l_moco_2', + 'ssl4eo_s12_moco_1', + 'ssl4eo_s12_moco_2', ], ) def test_trainer( self, monkeypatch: 
MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - if name.startswith("seco"): - monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2) + if name.startswith('seco'): + monkeypatch.setattr(SeasonalContrastS2, '__len__', lambda self: 2) - if name.startswith("ssl4eo_s12"): - monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2) + if name.startswith('ssl4eo_s12'): + monkeypatch.setattr(SSL4EOS12, '__len__', lambda self: 2) - monkeypatch.setattr(timm, "create_model", create_model) + monkeypatch.setattr(timm, 'create_model', create_model) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) def test_version_warnings(self) -> None: - with pytest.warns(UserWarning, match="MoCo v1 uses a memory bank"): + with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'): MoCoTask(version=1, layers=2, memory_bank_size=0) - with pytest.warns(UserWarning, match="MoCo v2 only uses 2 layers"): + with pytest.warns(UserWarning, match='MoCo v2 only uses 2 layers'): MoCoTask(version=2, layers=3, memory_bank_size=10) - with pytest.warns(UserWarning, match="MoCo v2 uses a memory bank"): + with pytest.warns(UserWarning, match='MoCo v2 uses a memory bank'): MoCoTask(version=2, layers=2, memory_bank_size=0) - with pytest.warns(UserWarning, match="MoCo v3 uses 3 layers"): + with pytest.warns(UserWarning, match='MoCo v3 uses 3 layers'): MoCoTask(version=3, layers=2, memory_bank_size=0) - with pytest.warns(UserWarning, match="MoCo v3 does not use a memory bank"): + with pytest.warns(UserWarning, match='MoCo v3 does not use a memory bank'): MoCoTask(version=3, layers=3, memory_bank_size=10) @pytest.fixture @@ -91,53 +91,53 @@ def weights(self) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: - match = "num classes .* != num classes in pretrained model" + match = 'num classes .* != num classes in pretrained model' with pytest.warns(UserWarning, match=match): - MoCoTask(model="resnet18", weights=checkpoint) + MoCoTask(model='resnet18', weights=checkpoint) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: - match = "num classes .* != num classes in pretrained model" + match = 'num classes .* != num classes in pretrained model' with pytest.warns(UserWarning, match=match): MoCoTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], 
weights=mocked_weights, - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) def test_weight_str(self, mocked_weights: WeightsEnum) -> None: - match = "num classes .* != num classes in pretrained model" + match = 'num classes .* != num classes in pretrained model' with pytest.warns(UserWarning, match=match): MoCoTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=str(mocked_weights), - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: MoCoTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=weights, - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: MoCoTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=str(weights), - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 3e1aa2f5157..ef3c6164d98 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -43,7 +43,7 @@ def __init__(self, in_chans: int = 3, num_classes: int = 1, **kwargs: Any) -> No class PredictRegressionDataModule(TropicalCycloneDataModule): def setup(self, stage: str) -> None: - self.predict_dataset = TropicalCyclone(split="test", **self.kwargs) + self.predict_dataset = TropicalCyclone(split='test', **self.kwargs) def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: @@ -65,38 +65,38 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return RegressionTestModel(**kwargs) @pytest.mark.parametrize( - "name", ["cowc_counting", "cyclone", "sustainbench_crop_yield", "skippd"] + 'name', ['cowc_counting', 'cyclone', 'sustainbench_crop_yield', 'skippd'] ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name == "skippd": - pytest.importorskip("h5py", minversion="3") + if name == 'skippd': + pytest.importorskip('h5py', minversion='3') - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - monkeypatch.setattr(timm, "create_model", self.create_model) + monkeypatch.setattr(timm, 'create_model', self.create_model) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) try: - main(["test"] + args) + main(['test'] + args) except MisconfigurationException: pass try: - main(["predict"] + args) + main(['predict'] + args) except MisconfigurationException: pass @@ -108,62 +108,62 @@ def weights(self) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', 
str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: with pytest.warns(UserWarning): - RegressionTask(model="resnet18", weights=checkpoint) + RegressionTask(model='resnet18', weights=checkpoint) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: with pytest.warns(UserWarning): RegressionTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=mocked_weights, - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) def test_weight_str(self, mocked_weights: WeightsEnum) -> None: with pytest.warns(UserWarning): RegressionTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=str(mocked_weights), - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: RegressionTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=weights, - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: RegressionTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=str(weights), - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot) + monkeypatch.setattr(TropicalCycloneDataModule, 'plot', plot) datamodule = TropicalCycloneDataModule( - root="tests/data/cyclone", batch_size=1, num_workers=0 + root='tests/data/cyclone', batch_size=1, num_workers=0 ) - model = RegressionTask(model="resnet18") + model = RegressionTask(model='resnet18') trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -171,13 +171,13 @@ def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> N trainer.validate(model=model, datamodule=datamodule) def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot_missing_bands) + monkeypatch.setattr(TropicalCycloneDataModule, 'plot', plot_missing_bands) datamodule = TropicalCycloneDataModule( - root="tests/data/cyclone", batch_size=1, num_workers=0 + root='tests/data/cyclone', batch_size=1, num_workers=0 ) - model = RegressionTask(model="resnet18") + model = RegressionTask(model='resnet18') trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -186,11 +186,11 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictRegressionDataModule( - root="tests/data/cyclone", batch_size=1, num_workers=0 + root='tests/data/cyclone', batch_size=1, num_workers=0 ) - model = RegressionTask(model="resnet18") + model = RegressionTask(model='resnet18') trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -200,10 +200,10 @@ def test_predict(self, 
fast_dev_run: bool) -> None: def test_invalid_loss(self) -> None: match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): - RegressionTask(model="resnet18", loss="invalid_loss") + RegressionTask(model='resnet18', loss='invalid_loss') @pytest.mark.parametrize( - "model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"] + 'model_name', ['resnet18', 'efficientnetv2_s', 'vit_base_patch16_384'] ) def test_freeze_backbone(self, model_name: str) -> None: model = RegressionTask(model=model_name, freeze_backbone=True) @@ -218,42 +218,42 @@ class TestPixelwiseRegressionTask: def create_model(*args: Any, **kwargs: Any) -> Module: return PixelwiseRegressionTestModel(**kwargs) - @pytest.mark.parametrize("name", ["inria_unet", "inria_deeplab", "inria_fcn"]) + @pytest.mark.parametrize('name', ['inria_unet', 'inria_deeplab', 'inria_fcn']) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - monkeypatch.setattr(smp, "Unet", self.create_model) - monkeypatch.setattr(smp, "DeepLabV3Plus", self.create_model) + monkeypatch.setattr(smp, 'Unet', self.create_model) + monkeypatch.setattr(smp, 'DeepLabV3Plus', self.create_model) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) try: - main(["test"] + args) + main(['test'] + args) except MisconfigurationException: pass try: - main(["predict"] + args) + main(['predict'] + args) except MisconfigurationException: pass def test_invalid_model(self) -> None: match = "Model type 'invalid_model' is not valid." 
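# The mocked_weights fixtures repeated across these trainer tests share one recipe:
# build an untrained timm model matching weights.meta, torch.save its state_dict
# under tmp_path, repoint the enum's url at that local file, and swap torchvision's
# load_state_dict_from_url for the module-level load helper that reads the saved
# checkpoint from disk, so the pretrained-weights code paths run without network access.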
with pytest.raises(ValueError, match=match): - PixelwiseRegressionTask(model="invalid_model") + PixelwiseRegressionTask(model='invalid_model') @pytest.fixture def weights(self) -> WeightsEnum: @@ -263,58 +263,58 @@ def weights(self) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: - PixelwiseRegressionTask(model="unet", backbone="resnet18", weights=checkpoint) + PixelwiseRegressionTask(model='unet', backbone='resnet18', weights=checkpoint) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: PixelwiseRegressionTask( - model="unet", - backbone=mocked_weights.meta["model"], + model='unet', + backbone=mocked_weights.meta['model'], weights=mocked_weights, - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) def test_weight_str(self, mocked_weights: WeightsEnum) -> None: PixelwiseRegressionTask( - model="unet", - backbone=mocked_weights.meta["model"], + model='unet', + backbone=mocked_weights.meta['model'], weights=str(mocked_weights), - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: PixelwiseRegressionTask( - model="unet", - backbone=weights.meta["model"], + model='unet', + backbone=weights.meta['model'], weights=weights, - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: PixelwiseRegressionTask( - model="unet", - backbone=weights.meta["model"], + model='unet', + backbone=weights.meta['model'], weights=str(weights), - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) - @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"]) + @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+']) @pytest.mark.parametrize( - "backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"] + 'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0'] ) def test_freeze_backbone(self, model_name: str, backbone: str) -> None: model = PixelwiseRegressionTask( @@ -331,10 +331,10 @@ def test_freeze_backbone(self, model_name: str, backbone: str) -> None: ] ) - @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"]) + @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+']) def test_freeze_decoder(self, model_name: str) -> None: model = PixelwiseRegressionTask( - model=model_name, backbone="resnet18", freeze_decoder=True + model=model_name, backbone='resnet18', freeze_decoder=True ) assert all( [param.requires_grad is False for param in model.model.decoder.parameters()] diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 42fe7693e0b..e5c15269669 
100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -53,71 +53,71 @@ def plot_missing_bands(*args: Any, **kwargs: Any) -> None: class TestSemanticSegmentationTask: @pytest.mark.parametrize( - "name", + 'name', [ - "agrifieldnet", - "chabud", - "chesapeake_cvpr_5", - "chesapeake_cvpr_7", - "deepglobelandcover", - "etci2021", - "gid15", - "inria", - "l7irish", - "l8biome", - "landcoverai", - "loveda", - "naipchesapeake", - "potsdam2d", - "sen12ms_all", - "sen12ms_s1", - "sen12ms_s2_all", - "sen12ms_s2_reduced", - "sentinel2_cdl", - "sentinel2_eurocrops", - "sentinel2_nccm", - "sentinel2_south_america_soybean", - "spacenet1", - "ssl4eo_l_benchmark_cdl", - "ssl4eo_l_benchmark_nlcd", - "vaihingen2d", + 'agrifieldnet', + 'chabud', + 'chesapeake_cvpr_5', + 'chesapeake_cvpr_7', + 'deepglobelandcover', + 'etci2021', + 'gid15', + 'inria', + 'l7irish', + 'l8biome', + 'landcoverai', + 'loveda', + 'naipchesapeake', + 'potsdam2d', + 'sen12ms_all', + 'sen12ms_s1', + 'sen12ms_s2_all', + 'sen12ms_s2_reduced', + 'sentinel2_cdl', + 'sentinel2_eurocrops', + 'sentinel2_nccm', + 'sentinel2_south_america_soybean', + 'spacenet1', + 'ssl4eo_l_benchmark_cdl', + 'ssl4eo_l_benchmark_nlcd', + 'vaihingen2d', ], ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name == "naipchesapeake": - pytest.importorskip("zipfile_deflate64") + if name == 'naipchesapeake': + pytest.importorskip('zipfile_deflate64') - if name == "landcoverai": - sha256 = "ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b" - monkeypatch.setattr(LandCoverAI, "sha256", sha256) + if name == 'landcoverai': + sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' + monkeypatch.setattr(LandCoverAI, 'sha256', sha256) - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - monkeypatch.setattr(smp, "Unet", create_model) - monkeypatch.setattr(smp, "DeepLabV3Plus", create_model) + monkeypatch.setattr(smp, 'Unet', create_model) + monkeypatch.setattr(smp, 'DeepLabV3Plus', create_model) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) try: - main(["test"] + args) + main(['test'] + args) except MisconfigurationException: pass try: - main(["predict"] + args) + main(['predict'] + args) except MisconfigurationException: pass @@ -129,71 +129,71 @@ def weights(self) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def 
test_weight_file(self, checkpoint: str) -> None: - SemanticSegmentationTask(backbone="resnet18", weights=checkpoint, num_classes=6) + SemanticSegmentationTask(backbone='resnet18', weights=checkpoint, num_classes=6) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: SemanticSegmentationTask( - backbone=mocked_weights.meta["model"], + backbone=mocked_weights.meta['model'], weights=mocked_weights, - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) def test_weight_str(self, mocked_weights: WeightsEnum) -> None: SemanticSegmentationTask( - backbone=mocked_weights.meta["model"], + backbone=mocked_weights.meta['model'], weights=str(mocked_weights), - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: SemanticSegmentationTask( - backbone=weights.meta["model"], + backbone=weights.meta['model'], weights=weights, - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: SemanticSegmentationTask( - backbone=weights.meta["model"], + backbone=weights.meta['model'], weights=str(weights), - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) def test_invalid_model(self) -> None: match = "Model type 'invalid_model' is not valid." with pytest.raises(ValueError, match=match): - SemanticSegmentationTask(model="invalid_model") + SemanticSegmentationTask(model='invalid_model') def test_invalid_loss(self) -> None: match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): - SemanticSegmentationTask(loss="invalid_loss") + SemanticSegmentationTask(loss='invalid_loss') def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(SEN12MSDataModule, "plot", plot) + monkeypatch.setattr(SEN12MSDataModule, 'plot', plot) datamodule = SEN12MSDataModule( - root="tests/data/sen12ms", batch_size=1, num_workers=0 + root='tests/data/sen12ms', batch_size=1, num_workers=0 ) model = SemanticSegmentationTask( - backbone="resnet18", in_channels=15, num_classes=6 + backbone='resnet18', in_channels=15, num_classes=6 ) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, @@ -201,24 +201,24 @@ def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> N trainer.validate(model=model, datamodule=datamodule) def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(SEN12MSDataModule, "plot", plot_missing_bands) + monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands) datamodule = SEN12MSDataModule( - root="tests/data/sen12ms", batch_size=1, num_workers=0 + root='tests/data/sen12ms', batch_size=1, num_workers=0 ) model = SemanticSegmentationTask( - backbone="resnet18", in_channels=15, num_classes=6 + backbone='resnet18', in_channels=15, num_classes=6 ) trainer = Trainer( - accelerator="cpu", + accelerator='cpu', fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1, ) trainer.validate(model=model, datamodule=datamodule) - @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"]) + @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+']) @pytest.mark.parametrize( - "backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"] + 'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0'] ) 
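# As in the pixelwise-regression tests above, freezing a component is expected to
# set requires_grad to False on every parameter of that component (model.model.decoder
# in the freeze_decoder case); the decorated test below repeats the freeze_backbone
# check across both model types and several encoder backbones.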
def test_freeze_backbone(self, model_name: str, backbone: str) -> None: model = SemanticSegmentationTask( @@ -235,7 +235,7 @@ def test_freeze_backbone(self, model_name: str, backbone: str) -> None: ] ) - @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"]) + @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+']) def test_freeze_decoder(self, model_name: str) -> None: model = SemanticSegmentationTask(model=model_name, freeze_decoder=True) assert all( diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index 6d15931ef4d..b3cbee1fcab 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -32,53 +32,53 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: class TestSimCLRTask: @pytest.mark.parametrize( - "name", + 'name', [ - "chesapeake_cvpr_prior_simclr", - "seco_simclr_1", - "seco_simclr_2", - "ssl4eo_l_simclr_1", - "ssl4eo_l_simclr_2", - "ssl4eo_s12_simclr_1", - "ssl4eo_s12_simclr_2", + 'chesapeake_cvpr_prior_simclr', + 'seco_simclr_1', + 'seco_simclr_2', + 'ssl4eo_l_simclr_1', + 'ssl4eo_l_simclr_2', + 'ssl4eo_s12_simclr_1', + 'ssl4eo_s12_simclr_2', ], ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - config = os.path.join("tests", "conf", name + ".yaml") + config = os.path.join('tests', 'conf', name + '.yaml') - if name.startswith("seco"): - monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2) + if name.startswith('seco'): + monkeypatch.setattr(SeasonalContrastS2, '__len__', lambda self: 2) - if name.startswith("ssl4eo_s12"): - monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2) + if name.startswith('ssl4eo_s12'): + monkeypatch.setattr(SSL4EOS12, '__len__', lambda self: 2) - monkeypatch.setattr(timm, "create_model", create_model) + monkeypatch.setattr(timm, 'create_model', create_model) args = [ - "--config", + '--config', config, - "--trainer.accelerator", - "cpu", - "--trainer.fast_dev_run", + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', str(fast_dev_run), - "--trainer.max_epochs", - "1", - "--trainer.log_every_n_steps", - "1", + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', ] - main(["fit"] + args) + main(['fit'] + args) def test_version_warnings(self) -> None: - with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"): + with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'): SimCLRTask(version=1, layers=3, memory_bank_size=0) - with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"): + with pytest.warns(UserWarning, match='SimCLR v1 does not use a memory bank'): SimCLRTask(version=1, layers=2, memory_bank_size=10) - with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"): + with pytest.warns(UserWarning, match=r'SimCLR v2 uses 3\+ layers'): SimCLRTask(version=2, layers=2, memory_bank_size=10) - with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"): + with pytest.warns(UserWarning, match='SimCLR v2 uses a memory bank'): SimCLRTask(version=2, layers=3, memory_bank_size=0) @pytest.fixture @@ -89,53 +89,53 @@ def weights(self) -> WeightsEnum: def mocked_weights( self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum ) -> WeightsEnum: - path = tmp_path / f"{weights}.pth" + path = tmp_path / f'{weights}.pth' model = timm.create_model( - weights.meta["model"], in_chans=weights.meta["in_chans"] + weights.meta['model'], in_chans=weights.meta['in_chans'] ) torch.save(model.state_dict(), path) try: - 
monkeypatch.setattr(weights.value, "url", str(path)) + monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: - monkeypatch.setattr(weights, "url", str(path)) - monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + monkeypatch.setattr(weights, 'url', str(path)) + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: - match = "num classes .* != num classes in pretrained model" + match = 'num classes .* != num classes in pretrained model' with pytest.warns(UserWarning, match=match): - SimCLRTask(model="resnet18", weights=checkpoint) + SimCLRTask(model='resnet18', weights=checkpoint) def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: - match = "num classes .* != num classes in pretrained model" + match = 'num classes .* != num classes in pretrained model' with pytest.warns(UserWarning, match=match): SimCLRTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=mocked_weights, - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) def test_weight_str(self, mocked_weights: WeightsEnum) -> None: - match = "num classes .* != num classes in pretrained model" + match = 'num classes .* != num classes in pretrained model' with pytest.warns(UserWarning, match=match): SimCLRTask( - model=mocked_weights.meta["model"], + model=mocked_weights.meta['model'], weights=str(mocked_weights), - in_channels=mocked_weights.meta["in_chans"], + in_channels=mocked_weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_enum_download(self, weights: WeightsEnum) -> None: SimCLRTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=weights, - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) @pytest.mark.slow def test_weight_str_download(self, weights: WeightsEnum) -> None: SimCLRTask( - model=weights.meta["model"], + model=weights.meta['model'], weights=str(weights), - in_channels=weights.meta["in_chans"], + in_channels=weights.meta['in_chans'], ) diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index f324227a3d4..0b5fbe15b55 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -25,17 +25,17 @@ def test_extract_backbone(checkpoint: str) -> None: def test_extract_backbone_unsupported_model(tmp_path: Path) -> None: - checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}} - path = os.path.join(str(tmp_path), "dummy.ckpt") + checkpoint = {'hyper_parameters': {'some_unsupported_model': 'resnet18'}} + path = os.path.join(str(tmp_path), 'dummy.ckpt') torch.save(checkpoint, path) - err = "Unknown checkpoint task. Only backbone or model extraction is supported" + err = 'Unknown checkpoint task. 
Only backbone or model extraction is supported' with pytest.raises(ValueError, match=err): extract_backbone(path) def test_get_input_layer_name_and_module() -> None: - key, module = _get_input_layer_name_and_module(timm.create_model("resnet18")) - assert key == "conv1" + key, module = _get_input_layer_name_and_module(timm.create_model('resnet18')) + assert key == 'conv1' assert isinstance(module, nn.Conv2d) assert module.in_channels == 3 @@ -49,17 +49,17 @@ def test_load_state_dict_unequal_input_channels( monkeypatch: MonkeyPatch, checkpoint: str, model: Module ) -> None: _, state_dict = extract_backbone(checkpoint) - expected_in_channels = state_dict["conv1.weight"].shape[1] + expected_in_channels = state_dict['conv1.weight'].shape[1] in_channels = 7 conv1 = nn.Conv2d( in_channels, out_channels=64, kernel_size=7, stride=1, padding=2, bias=False ) - monkeypatch.setattr(model, "conv1", conv1) + monkeypatch.setattr(model, 'conv1', conv1) warning = ( - f"input channels {in_channels} != input channels in pretrained" - f" model {expected_in_channels}. Overriding with new input channels" + f'input channels {in_channels} != input channels in pretrained' + f' model {expected_in_channels}. Overriding with new input channels' ) with pytest.warns(UserWarning, match=warning): load_state_dict(model, state_dict) @@ -69,16 +69,16 @@ def test_load_state_dict_unequal_classes( monkeypatch: MonkeyPatch, checkpoint: str, model: Module ) -> None: _, state_dict = extract_backbone(checkpoint) - expected_num_classes = state_dict["fc.weight"].shape[0] + expected_num_classes = state_dict['fc.weight'].shape[0] num_classes = 10 in_features = cast(int, cast(nn.Module, model.fc).in_features) fc = nn.Linear(in_features, out_features=num_classes) - monkeypatch.setattr(model, "fc", fc) + monkeypatch.setattr(model, 'fc', fc) warning = ( - f"num classes {num_classes} != num classes in pretrained model" - f" {expected_num_classes}. Overriding with new num classes" + f'num classes {num_classes} != num classes in pretrained model' + f' {expected_num_classes}. 
Overriding with new num classes' ) with pytest.warns(UserWarning, match=warning): load_state_dict(model, state_dict) diff --git a/tests/transforms/test_color.py b/tests/transforms/test_color.py index f3bf554549d..2e271f89bc9 100644 --- a/tests/transforms/test_color.py +++ b/tests/transforms/test_color.py @@ -11,21 +11,21 @@ @pytest.fixture def sample() -> dict[str, Tensor]: return { - "image": torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4), - "mask": torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4), + 'image': torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4), + 'mask': torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4), } @pytest.fixture def batch() -> dict[str, Tensor]: return { - "image": torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4), - "mask": torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4), + 'image': torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4), + 'mask': torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4), } @pytest.mark.parametrize( - "weights", + 'weights', [ torch.tensor([1.0, 1.0, 1.0]), torch.tensor([0.299, 0.587, 0.114]), @@ -33,16 +33,16 @@ def batch() -> dict[str, Tensor]: ], ) def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=["image"]) + aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) output = aug(sample) - assert output["image"].shape == sample["image"].shape - assert output["image"].sum() == sample["image"].sum() + assert output['image'].shape == sample['image'].shape + assert output['image'].sum() == sample['image'].sum() for i in range(1, 3): - assert torch.allclose(output["image"][0, 0], output["image"][0, i]) + assert torch.allclose(output['image'][0, 0], output['image'][0, i]) @pytest.mark.parametrize( - "weights", + 'weights', [ torch.tensor([1.0, 1.0, 1.0]), torch.tensor([0.299, 0.587, 0.114]), @@ -50,9 +50,9 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> ], ) def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=["image"]) + aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) output = aug(batch) - assert output["image"].shape == batch["image"].shape - assert output["image"].sum() == batch["image"].sum() + assert output['image'].shape == batch['image'].shape + assert output['image'].sum() == batch['image'].sum() for i in range(1, 3): - assert torch.allclose(output["image"][0, 0], output["image"][0, i]) + assert torch.allclose(output['image'][0, 0], output['image'][0, i]) diff --git a/tests/transforms/test_indices.py b/tests/transforms/test_indices.py index 1c1f74e2561..3d83f857304 100644 --- a/tests/transforms/test_indices.py +++ b/tests/transforms/test_indices.py @@ -27,51 +27,51 @@ @pytest.fixture def sample() -> dict[str, Tensor]: return { - "image": torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4), - "mask": torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4), + 'image': torch.arange(3 * 4 * 4, dtype=torch.float).view(3, 4, 4), + 'mask': torch.arange(4 * 4, dtype=torch.long).view(1, 4, 4), } @pytest.fixture def batch() -> dict[str, Tensor]: return { - "image": torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4), - "mask": torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4), + 'image': torch.arange(2 * 3 * 4 * 4, dtype=torch.float).view(2, 3, 4, 4), + 
'mask': torch.arange(2 * 4 * 4, dtype=torch.long).view(2, 1, 4, 4), } def test_append_index_sample(sample: dict[str, Tensor]) -> None: - c, h, w = sample["image"].shape + c, h, w = sample['image'].shape aug = AugmentationSequential( AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) output = aug(sample) - assert output["image"].shape == (1, c + 1, h, w) + assert output['image'].shape == (1, c + 1, h, w) def test_append_index_batch(batch: dict[str, Tensor]) -> None: - b, c, h, w = batch["image"].shape + b, c, h, w = batch['image'].shape aug = AugmentationSequential( AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) output = aug(batch) - assert output["image"].shape == (b, c + 1, h, w) + assert output['image'].shape == (b, c + 1, h, w) def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None: - b, c, h, w = batch["image"].shape + b, c, h, w = batch['image'].shape aug = AugmentationSequential( AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) output = aug(batch) - assert output["image"].shape == (b, c + 1, h, w) + assert output['image'].shape == (b, c + 1, h, w) @pytest.mark.parametrize( - "index", + 'index', [ AppendBNDVI, AppendNBR, @@ -87,17 +87,17 @@ def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None: def test_append_normalized_difference_indices( sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex ) -> None: - c, h, w = sample["image"].shape - aug = AugmentationSequential(index(0, 1), data_keys=["image", "mask"]) + c, h, w = sample['image'].shape + aug = AugmentationSequential(index(0, 1), data_keys=['image', 'mask']) output = aug(sample) - assert output["image"].shape == (1, c + 1, h, w) + assert output['image'].shape == (1, c + 1, h, w) -@pytest.mark.parametrize("index", [AppendGBNDVI, AppendGRNDVI, AppendRBNDVI]) +@pytest.mark.parametrize('index', [AppendGBNDVI, AppendGRNDVI, AppendRBNDVI]) def test_append_tri_band_normalized_difference_indices( sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex ) -> None: - c, h, w = sample["image"].shape - aug = AugmentationSequential(index(0, 1, 2), data_keys=["image", "mask"]) + c, h, w = sample['image'].shape + aug = AugmentationSequential(index(0, 1, 2), data_keys=['image', 'mask']) output = aug(sample) - assert output["image"].shape == (1, c + 1, h, w) + assert output['image'].shape == (1, c + 1, h, w) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 3af8df3816e..1f2071ae812 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -21,17 +21,17 @@ @pytest.fixture def batch_gray() -> dict[str, Tensor]: return { - "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), - "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'image': torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), + 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } @pytest.fixture def batch_rgb() -> dict[str, Tensor]: return { - "image": torch.tensor( + 'image': torch.tensor( [ [ [[1, 2, 3], [4, 5, 6], [7, 8, 
9]], @@ -41,16 +41,16 @@ def batch_rgb() -> dict[str, Tensor]: ], dtype=torch.float, ), - "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } @pytest.fixture def batch_multispectral() -> dict[str, Tensor]: return { - "image": torch.tensor( + 'image': torch.tensor( [ [ [[1, 2, 3], [4, 5, 6], [7, 8, 9]], @@ -62,28 +62,28 @@ def batch_multispectral() -> dict[str, Tensor]: ], dtype=torch.float, ), - "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } def assert_matching(output: dict[str, Tensor], expected: dict[str, Tensor]) -> None: for key in expected: - err = f"output[{key}] != expected[{key}]" + err = f'output[{key}] != expected[{key}]' equal = torch.allclose(output[key], expected[key]) assert equal, err def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: expected = { - "image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), - "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'image': torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), + 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=["image", "mask", "boxes"] + K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] ) output = augs(batch_gray) assert_matching(output, expected) @@ -91,7 +91,7 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: expected = { - "image": torch.tensor( + 'image': torch.tensor( [ [ [[3, 2, 1], [6, 5, 4], [9, 8, 7]], @@ -101,12 +101,12 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: ], dtype=torch.float, ), - "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=["image", "mask", "boxes"] + K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] ) output = augs(batch_rgb) assert_matching(output, expected) @@ -116,7 +116,7 @@ def test_augmentation_sequential_multispectral( batch_multispectral: dict[str, Tensor], ) -> None: expected = { - "image": torch.tensor( + 'image': torch.tensor( [ [ [[3, 2, 1], [6, 5, 4], [9, 8, 7]], @@ -128,12 +128,12 @@ def test_augmentation_sequential_multispectral( ], dtype=torch.float, ), - "mask": 
torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=["image", "mask", "boxes"] + K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] ) output = augs(batch_multispectral) assert_matching(output, expected) @@ -143,7 +143,7 @@ def test_augmentation_sequential_image_only( batch_multispectral: dict[str, Tensor], ) -> None: expected = { - "image": torch.tensor( + 'image': torch.tensor( [ [ [[3, 2, 1], [6, 5, 4], [9, 8, 7]], @@ -155,12 +155,12 @@ def test_augmentation_sequential_image_only( ], dtype=torch.float, ), - "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=["image"] + K.RandomHorizontalFlip(p=1.0), data_keys=['image'] ) output = augs(batch_multispectral) assert_matching(output, expected) @@ -170,7 +170,7 @@ def test_sequential_transforms_augmentations( batch_multispectral: dict[str, Tensor], ) -> None: expected = { - "image": torch.tensor( + 'image': torch.tensor( [ [ [[3, 2, 1], [6, 5, 4], [9, 8, 7]], @@ -187,9 +187,9 @@ def test_sequential_transforms_augmentations( ], dtype=torch.float, ), - "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), - "labels": torch.tensor([[0, 1]]), + 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), + 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'labels': torch.tensor([[0, 1]]), } train_transforms = transforms.AugmentationSequential( indices.AppendNBR(index_nir=0, index_swir=0), @@ -198,7 +198,7 @@ def test_sequential_transforms_augmentations( indices.AppendNDVI(index_red=0, index_nir=0), indices.AppendNDWI(index_green=0, index_nir=0), K.RandomHorizontalFlip(p=1.0), - data_keys=["image", "mask", "boxes"], + data_keys=['image', 'mask', 'boxes'], ) output = train_transforms(batch_multispectral) assert_matching(output, expected) @@ -212,46 +212,46 @@ def test_extract_patches() -> None: # test default settings (when stride is not defined, s=p) batch = { - "image": torch.randn(size=(b, c, h, w)), - "mask": torch.randint(low=0, high=2, size=(b, h, w)), + 'image': torch.randn(size=(b, c, h, w)), + 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p), same_on_batch=True, data_keys=["image", "mask"] + _ExtractPatches(window_size=p), same_on_batch=True, data_keys=['image', 'mask'] ) output = train_transforms(batch) - assert batch["image"].shape == (b * num_patches, c, p, p) - assert batch["mask"].shape == (b * num_patches, p, p) + assert batch['image'].shape == (b * num_patches, c, p, p) + assert batch['mask'].shape == (b * num_patches, p, p) # Test different stride s = 16 num_patches = ((h - p + s) // s) * ((w - p + s) // s) batch = { - 
"image": torch.randn(size=(b, c, h, w)), - "mask": torch.randint(low=0, high=2, size=(b, h, w)), + 'image': torch.randn(size=(b, c, h, w)), + 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } train_transforms = transforms.AugmentationSequential( _ExtractPatches(window_size=p, stride=s), same_on_batch=True, - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) output = train_transforms(batch) - assert batch["image"].shape == (b * num_patches, c, p, p) - assert batch["mask"].shape == (b * num_patches, p, p) + assert batch['image'].shape == (b * num_patches, c, p, p) + assert batch['mask'].shape == (b * num_patches, p, p) # Test keepdim=False s = p num_patches = ((h - p + s) // s) * ((w - p + s) // s) batch = { - "image": torch.randn(size=(b, c, h, w)), - "mask": torch.randint(low=0, high=2, size=(b, h, w)), + 'image': torch.randn(size=(b, c, h, w)), + 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } train_transforms = transforms.AugmentationSequential( _ExtractPatches(window_size=p, stride=s, keepdim=False), same_on_batch=True, - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) output = train_transforms(batch) for k, v in output.items(): print(k, v.shape, v.dtype) - assert batch["image"].shape == (b, num_patches, c, p, p) - assert batch["mask"].shape == (b, num_patches, 1, p, p) + assert batch['image'].shape == (b, num_patches, c, p, p) + assert batch['mask'].shape == (b, num_patches, 1, p, p) diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py index 0c1615db840..21e4dc5ee3f 100644 --- a/torchgeo/__init__.py +++ b/torchgeo/__init__.py @@ -10,5 +10,5 @@ common image transformations for geospatial data. """ -__author__ = "Adam J. Stewart" -__version__ = "0.6.0.dev0" +__author__ = 'Adam J. Stewart' +__version__ = '0.6.0.dev0' diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 5ee3a47aaaa..96e78befc31 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -48,54 +48,54 @@ __all__ = ( # GeoDataset - "AgriFieldNetDataModule", - "ChesapeakeCVPRDataModule", - "L7IrishDataModule", - "L8BiomeDataModule", - "NAIPChesapeakeDataModule", - "Sentinel2CDLDataModule", - "Sentinel2EuroCropsDataModule", - "Sentinel2NCCMDataModule", - "Sentinel2SouthAmericaSoybeanDataModule", + 'AgriFieldNetDataModule', + 'ChesapeakeCVPRDataModule', + 'L7IrishDataModule', + 'L8BiomeDataModule', + 'NAIPChesapeakeDataModule', + 'Sentinel2CDLDataModule', + 'Sentinel2EuroCropsDataModule', + 'Sentinel2NCCMDataModule', + 'Sentinel2SouthAmericaSoybeanDataModule', # NonGeoDataset - "BigEarthNetDataModule", - "ChaBuDDataModule", - "COWCCountingDataModule", - "DeepGlobeLandCoverDataModule", - "ETCI2021DataModule", - "EuroSATDataModule", - "EuroSAT100DataModule", - "FAIR1MDataModule", - "FireRiskDataModule", - "GID15DataModule", - "InriaAerialImageLabelingDataModule", - "LandCoverAIDataModule", - "LEVIRCDDataModule", - "LEVIRCDPlusDataModule", - "LoveDADataModule", - "NASAMarineDebrisDataModule", - "OSCDDataModule", - "Potsdam2DDataModule", - "RESISC45DataModule", - "SeasonalContrastS2DataModule", - "SEN12MSDataModule", - "SKIPPDDataModule", - "So2SatDataModule", - "SpaceNet1DataModule", - "SSL4EOLBenchmarkDataModule", - "SSL4EOLDataModule", - "SSL4EOS12DataModule", - "SustainBenchCropYieldDataModule", - "TropicalCycloneDataModule", - "UCMercedDataModule", - "USAVarsDataModule", - "Vaihingen2DDataModule", - "VHR10DataModule", - "XView2DataModule", + 'BigEarthNetDataModule', + 'ChaBuDDataModule', + 'COWCCountingDataModule', + 
'DeepGlobeLandCoverDataModule', + 'ETCI2021DataModule', + 'EuroSATDataModule', + 'EuroSAT100DataModule', + 'FAIR1MDataModule', + 'FireRiskDataModule', + 'GID15DataModule', + 'InriaAerialImageLabelingDataModule', + 'LandCoverAIDataModule', + 'LEVIRCDDataModule', + 'LEVIRCDPlusDataModule', + 'LoveDADataModule', + 'NASAMarineDebrisDataModule', + 'OSCDDataModule', + 'Potsdam2DDataModule', + 'RESISC45DataModule', + 'SeasonalContrastS2DataModule', + 'SEN12MSDataModule', + 'SKIPPDDataModule', + 'So2SatDataModule', + 'SpaceNet1DataModule', + 'SSL4EOLBenchmarkDataModule', + 'SSL4EOLDataModule', + 'SSL4EOS12DataModule', + 'SustainBenchCropYieldDataModule', + 'TropicalCycloneDataModule', + 'UCMercedDataModule', + 'USAVarsDataModule', + 'Vaihingen2DDataModule', + 'VHR10DataModule', + 'XView2DataModule', # Base classes - "BaseDataModule", - "GeoDataModule", - "NonGeoDataModule", + 'BaseDataModule', + 'GeoDataModule', + 'NonGeoDataModule', # Utilities - "MisconfigurationException", + 'MisconfigurationException', ) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index fb485eec9b7..bed6365d4a2 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -54,9 +54,9 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) @@ -72,15 +72,15 @@ def setup(self, stage: str) -> None: random_bbox_assignment(dataset, [0.8, 0.1, 0.1], generator) ) - if stage in ["fit"]: + if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size ) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index b75c38af3dd..695de0bb635 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -76,11 +76,11 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.BigEarthNet`. """ - bands = kwargs.get("bands", "all") - if bands == "all": + bands = kwargs.get('bands', 'all') + if bands == 'all': mins = self.mins maxs = self.maxs - elif bands == "s1": + elif bands == 's1': mins = self.mins[:2] maxs = self.maxs[:2] else: diff --git a/torchgeo/datamodules/chabud.py b/torchgeo/datamodules/chabud.py index 2f50d77846d..ae210cac0ed 100644 --- a/torchgeo/datamodules/chabud.py +++ b/torchgeo/datamodules/chabud.py @@ -52,14 +52,14 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.ChaBuD`. 
""" - bands = kwargs.get("bands", ChaBuD.all_bands) + bands = kwargs.get('bands', ChaBuD.all_bands) band_indices = [ChaBuD.all_bands.index(b) for b in bands] mins = self.min[band_indices] maxs = self.max[band_indices] # Change detection, 2 images from different times - mins = repeat(mins, "c -> (t c)", t=2) - maxs = repeat(maxs, "c -> (t c)", t=2) + mins = repeat(mins, 'c -> (t c)', t=2) + maxs = repeat(maxs, 'c -> (t c)', t=2) self.mean = mins self.std = maxs - mins @@ -72,6 +72,6 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit", "validate"]: - self.train_dataset = ChaBuD(split="train", **self.kwargs) - self.val_dataset = ChaBuD(split="val", **self.kwargs) + if stage in ['fit', 'validate']: + self.train_dataset = ChaBuD(split='train', **self.kwargs) + self.val_dataset = ChaBuD(split='val', **self.kwargs) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 7c1d3846bc6..37a0d32edbd 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -38,14 +38,14 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]: Returns: Augmented sample. """ - for key in ["image", "mask"]: + for key in ['image', 'mask']: dtype = sample[key].dtype # All inputs must be float sample[key] = sample[key].float() sample[key] = self.aug(sample[key]) sample[key] = sample[key].to(dtype) # Kornia adds batch dimension - sample[key] = rearrange(sample[key], "() c h w -> c h w") + sample[key] = rearrange(sample[key], '() c h w -> c h w') return sample @@ -94,7 +94,7 @@ def __init__( # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. self.original_patch_size = patch_size * 3 - kwargs["transforms"] = _Transform(K.CenterCrop(patch_size)) + kwargs['transforms'] = _Transform(K.CenterCrop(patch_size)) super().__init__( ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs @@ -103,8 +103,8 @@ def __init__( assert class_set in [5, 7] if use_prior_labels and class_set == 7: raise ValueError( - "The pre-generated prior labels are only valid for the 5" - + " class set of labels" + 'The pre-generated prior labels are only valid for the 5' + + ' class set of labels' ) self.train_splits = train_splits @@ -116,14 +116,14 @@ def __init__( if self.use_prior_labels: self.layers = [ - "naip-new", - "prior_from_cooccurrences_101_31_no_osm_no_buildings", + 'naip-new', + 'prior_from_cooccurrences_101_31_no_osm_no_buildings', ] else: - self.layers = ["naip-new", "lc"] + self.layers = ['naip-new', 'lc'] self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] ) def setup(self, stage: str) -> None: @@ -132,7 +132,7 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit"]: + if stage in ['fit']: self.train_dataset = ChesapeakeCVPR( splits=self.train_splits, layers=self.layers, **self.kwargs ) @@ -142,14 +142,14 @@ def setup(self, stage: str) -> None: self.batch_size, self.length, ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_dataset = ChesapeakeCVPR( splits=self.val_splits, layers=self.layers, **self.kwargs ) self.val_sampler = GridGeoSampler( self.val_dataset, self.original_patch_size, self.original_patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_dataset = ChesapeakeCVPR( splits=self.test_splits, layers=self.layers, **self.kwargs ) @@ -170,13 +170,13 @@ def on_after_batch_transfer( A batch of data. """ if self.use_prior_labels: - batch["mask"] = F.normalize(batch["mask"].float(), p=1, dim=1) - batch["mask"] = F.normalize( - batch["mask"] + self.prior_smoothing_constant, p=1, dim=1 + batch['mask'] = F.normalize(batch['mask'].float(), p=1, dim=1) + batch['mask'] = F.normalize( + batch['mask'] + self.prior_smoothing_constant, p=1, dim=1 ).long() else: if self.class_set == 5: - batch["mask"][batch["mask"] == 5] = 4 - batch["mask"][batch["mask"] == 6] = 4 + batch['mask'][batch['mask'] == 5] = 4 + batch['mask'][batch['mask'] == 6] = 4 return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 799850f4e4a..bd24d58b49a 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -34,8 +34,8 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.dataset = COWCCounting(split="train", **self.kwargs) - self.test_dataset = COWCCounting(split="test", **self.kwargs) + self.dataset = COWCCounting(split='train', **self.kwargs) + self.test_dataset = COWCCounting(split='test', **self.kwargs) self.train_dataset, self.val_dataset = random_split( self.dataset, [len(self.dataset) - len(self.test_dataset), len(self.test_dataset)], diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 34446280286..39021fc2acc 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -42,12 +42,12 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit", "validate"]: - self.dataset = TropicalCyclone(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = TropicalCyclone(split='train', **self.kwargs) storm_ids = [] for item in self.dataset.collection: - storm_id = item["href"].split("/")[0].split("_")[-2] + storm_id = item['href'].split('/')[0].split('_')[-2] storm_ids.append(storm_id) train_indices, val_indices = group_shuffle_split( @@ -56,5 +56,5 @@ def setup(self, stage: str) -> None: self.train_dataset = Subset(self.dataset, train_indices) self.val_dataset = Subset(self.dataset, val_indices) - if stage in ["test"]: - self.test_dataset = TropicalCyclone(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = TropicalCyclone(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 2195aca3f0f..3e99d31e6b8 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -48,7 +48,7 @@ def __init__( self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) def setup(self, stage: str) -> None: @@ -57,10 +57,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit", "validate"]: - self.dataset = DeepGlobeLandCover(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = DeepGlobeLandCover(split='train', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, self.val_split_pct ) - if stage in ["test"]: - self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = DeepGlobeLandCover(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index ea78a0ebb79..233fa43261b 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -54,13 +54,13 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit"]: - self.train_dataset = ETCI2021(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = ETCI2021(split="val", **self.kwargs) - if stage in ["predict"]: + if stage in ['fit']: + self.train_dataset = ETCI2021(split='train', **self.kwargs) + if stage in ['fit', 'validate']: + self.val_dataset = ETCI2021(split='val', **self.kwargs) + if stage in ['predict']: # Test set masks are not public, use for prediction instead - self.predict_dataset = ETCI2021(split="test", **self.kwargs) + self.predict_dataset = ETCI2021(split='test', **self.kwargs) def on_after_batch_transfer( self, batch: dict[str, Tensor], dataloader_idx: int @@ -77,6 +77,6 @@ def on_after_batch_transfer( if self.trainer: if not self.trainer.predicting: # Evaluate against flood mask, not water mask - batch["mask"] = (batch["mask"][:, 1] > 0).long() + batch['mask'] = (batch['mask'][:, 1] > 0).long() return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index ccf2a90f691..99dd7709714 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -11,35 +11,35 @@ from .geo import NonGeoDataModule MEAN = { - "B01": 1354.40546513, - "B02": 1118.24399958, - "B03": 1042.92983953, - "B04": 947.62620298, - "B05": 1199.47283961, - "B06": 1999.79090914, - "B07": 2369.22292565, - "B08": 2296.82608323, - "B8A": 732.08340178, - "B09": 12.11327804, - "B10": 1819.01027855, - "B11": 1118.92391149, - "B12": 2594.14080798, + 'B01': 1354.40546513, + 'B02': 1118.24399958, + 'B03': 1042.92983953, + 'B04': 947.62620298, + 'B05': 1199.47283961, + 'B06': 1999.79090914, + 'B07': 2369.22292565, + 'B08': 2296.82608323, + 'B8A': 732.08340178, + 'B09': 12.11327804, + 'B10': 1819.01027855, + 'B11': 1118.92391149, + 'B12': 2594.14080798, } STD = { - "B01": 245.71762908, - "B02": 333.00778264, - "B03": 395.09249139, - "B04": 593.75055589, - "B05": 566.4170017, - "B06": 861.18399006, - "B07": 1086.63139075, - "B08": 1117.98170791, - "B8A": 404.91978886, - "B09": 4.77584468, - "B10": 1002.58768311, - "B11": 761.30323499, - "B12": 1231.58581042, + 'B01': 245.71762908, + 'B02': 333.00778264, + 'B03': 395.09249139, + 'B04': 593.75055589, + 'B05': 566.4170017, + 'B06': 861.18399006, + 'B07': 1086.63139075, + 'B08': 1117.98170791, + 'B8A': 404.91978886, + 'B09': 4.77584468, + 'B10': 1002.58768311, + 'B11': 761.30323499, + 'B12': 1231.58581042, } @@ -64,7 +64,7 @@ def __init__( """ super().__init__(EuroSAT, batch_size, num_workers, **kwargs) - bands = kwargs.get("bands", EuroSAT.all_band_names) + bands = kwargs.get('bands', EuroSAT.all_band_names) self.mean = torch.tensor([MEAN[b] for b in bands]) self.std = torch.tensor([STD[b] for b in bands]) @@ -90,6 +90,6 @@ def __init__( """ super().__init__(EuroSAT100, batch_size, num_workers, **kwargs) - bands = kwargs.get("bands", EuroSAT.all_band_names) + bands = kwargs.get('bands', EuroSAT.all_band_names) self.mean = torch.tensor([MEAN[b] for b in bands]) self.std = torch.tensor([STD[b] for b in bands]) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index f23dc64ec55..291dd617e04 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -24,12 +24,12 @@ def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: .. 
versionadded:: 0.5 """ output: dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) + output['image'] = torch.stack([sample['image'] for sample in batch]) - if "boxes" in batch[0]: - output["boxes"] = [sample["boxes"] for sample in batch] - if "label" in batch[0]: - output["label"] = [sample["label"] for sample in batch] + if 'boxes' in batch[0]: + output['boxes'] = [sample['boxes'] for sample in batch] + if 'label' in batch[0]: + output['label'] = [sample['label'] for sample in batch] return output @@ -63,10 +63,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit"]: - self.train_dataset = FAIR1M(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = FAIR1M(split="val", **self.kwargs) - if stage in ["predict"]: + if stage in ['fit']: + self.train_dataset = FAIR1M(split='train', **self.kwargs) + if stage in ['fit', 'validate']: + self.val_dataset = FAIR1M(split='val', **self.kwargs) + if stage in ['predict']: # Test set labels are not publicly available - self.predict_dataset = FAIR1M(split="test", **self.kwargs) + self.predict_dataset = FAIR1M(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py index 9bfe6b4f2f3..1a0d6c7c047 100644 --- a/torchgeo/datamodules/fire_risk.py +++ b/torchgeo/datamodules/fire_risk.py @@ -38,7 +38,7 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["image"], + data_keys=['image'], ) def setup(self, stage: str) -> None: @@ -47,7 +47,7 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit"]: - self.train_dataset = FireRisk(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = FireRisk(split="val", **self.kwargs) + if stage in ['fit']: + self.train_dataset = FireRisk(split='train', **self.kwargs) + if stage in ['fit', 'validate']: + self.val_dataset = FireRisk(split='val', **self.kwargs) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 06d09b37f40..e50e48dbfa1 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -71,7 +71,7 @@ def __init__( # Data augmentation Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] self.aug: Transform = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image'] ) self.train_aug: Transform | None = None self.val_aug: Transform | None = None @@ -85,7 +85,7 @@ def prepare_data(self) -> None: to avoid corrupted data. This method should not set state since it is not called on every device, use ``setup`` instead. """ - if self.kwargs.get("download", False): + if self.kwargs.get('download', False): self.dataset_class(**self.kwargs) def _valid_attribute(self, *args: str) -> Any: @@ -107,12 +107,12 @@ def _valid_attribute(self, *args: str) -> Any: continue if not obj: - msg = f"{self.__class__.__name__}.{arg} has length 0." + msg = f'{self.__class__.__name__}.{arg} has length 0.' raise MisconfigurationException(msg) return obj - msg = f"{self.__class__.__name__}.setup must define one of {args}." + msg = f'{self.__class__.__name__}.setup must define one of {args}.' 
raise MisconfigurationException(msg) def on_after_batch_transfer( @@ -129,15 +129,15 @@ def on_after_batch_transfer( """ if self.trainer: if self.trainer.training: - split = "train" + split = 'train' elif self.trainer.validating or self.trainer.sanity_checking: - split = "val" + split = 'val' elif self.trainer.testing: - split = "test" + split = 'test' elif self.trainer.predicting: - split = "predict" + split = 'predict' - aug = self._valid_attribute(f"{split}_aug", "aug") + aug = self._valid_attribute(f'{split}_aug', 'aug') batch = aug(batch) return batch @@ -158,7 +158,7 @@ def plot(self, *args: Any, **kwargs: Any) -> Figure | None: fig: Figure | None = None dataset = self.dataset or self.val_dataset if dataset is not None: - if hasattr(dataset, "plot"): + if hasattr(dataset, 'plot'): fig = dataset.plot(*args, **kwargs) return fig @@ -220,31 +220,31 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit"]: + if stage in ['fit']: self.train_dataset = cast( GeoDataset, self.dataset_class( # type: ignore[call-arg] - split="train", **self.kwargs + split='train', **self.kwargs ), ) self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_dataset = cast( GeoDataset, self.dataset_class( # type: ignore[call-arg] - split="val", **self.kwargs + split='val', **self.kwargs ), ) self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_dataset = cast( GeoDataset, self.dataset_class( # type: ignore[call-arg] - split="test", **self.kwargs + split='test', **self.kwargs ), ) self.test_sampler = GridGeoSampler( @@ -264,11 +264,11 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset or sampler, or if the dataset or sampler has length 0. """ - dataset = self._valid_attribute(f"{split}_dataset", "dataset") + dataset = self._valid_attribute(f'{split}_dataset', 'dataset') sampler = self._valid_attribute( - f"{split}_batch_sampler", f"{split}_sampler", "batch_sampler", "sampler" + f'{split}_batch_sampler', f'{split}_sampler', 'batch_sampler', 'sampler' ) - batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") + batch_size = self._valid_attribute(f'{split}_batch_size', 'batch_size') if isinstance(sampler, BatchGeoSampler): batch_size = 1 @@ -296,7 +296,7 @@ def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset or sampler, or if the dataset or sampler has length 0. """ - return self._dataloader_factory("train") + return self._dataloader_factory('train') def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for validation. @@ -308,7 +308,7 @@ def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset or sampler, or if the dataset or sampler has length 0. """ - return self._dataloader_factory("val") + return self._dataloader_factory('val') def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for testing. 
@@ -320,7 +320,7 @@ def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset or sampler, or if the dataset or sampler has length 0. """ - return self._dataloader_factory("test") + return self._dataloader_factory('test') def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for prediction. @@ -332,7 +332,7 @@ def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset or sampler, or if the dataset or sampler has length 0. """ - return self._dataloader_factory("predict") + return self._dataloader_factory('predict') def transfer_batch_to_device( self, batch: dict[str, Tensor], device: torch.device, dataloader_idx: int @@ -350,8 +350,8 @@ def transfer_batch_to_device( A reference to the data on the new device. """ # Non-Tensor values cannot be moved to a device - del batch["crs"] - del batch["bbox"] + del batch['crs'] + del batch['bbox'] batch = super().transfer_batch_to_device(batch, device, dataloader_idx) return batch @@ -393,17 +393,17 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit"]: + if stage in ['fit']: self.train_dataset = self.dataset_class( # type: ignore[call-arg] - split="train", **self.kwargs + split='train', **self.kwargs ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_dataset = self.dataset_class( # type: ignore[call-arg] - split="val", **self.kwargs + split='val', **self.kwargs ) - if stage in ["test"]: + if stage in ['test']: self.test_dataset = self.dataset_class( # type: ignore[call-arg] - split="test", **self.kwargs + split='test', **self.kwargs ) def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: @@ -419,12 +419,12 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset or sampler, or if the dataset or sampler has length 0. """ - dataset = self._valid_attribute(f"{split}_dataset", "dataset") - batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") + dataset = self._valid_attribute(f'{split}_dataset', 'dataset') + batch_size = self._valid_attribute(f'{split}_batch_size', 'batch_size') return DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=split == "train", + shuffle=split == 'train', num_workers=self.num_workers, collate_fn=self.collate_fn, ) @@ -439,7 +439,7 @@ def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset, or if the dataset has length 0. """ - return self._dataloader_factory("train") + return self._dataloader_factory('train') def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for validation. @@ -451,7 +451,7 @@ def val_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset, or if the dataset has length 0. """ - return self._dataloader_factory("val") + return self._dataloader_factory('val') def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for testing. @@ -463,7 +463,7 @@ def test_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset, or if the dataset has length 0. 
""" - return self._dataloader_factory("test") + return self._dataloader_factory('test') def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders for prediction. @@ -475,4 +475,4 @@ def predict_dataloader(self) -> DataLoader[dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a dataset, or if the dataset has length 0. """ - return self._dataloader_factory("predict") + return self._dataloader_factory('predict') diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 9f6b2f5da2f..72f800d301a 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -50,12 +50,12 @@ def __init__( self.train_aug = self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) self.predict_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image"], + data_keys=['image'], ) def setup(self, stage: str) -> None: @@ -64,11 +64,11 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit", "validate"]: - self.dataset = GID15(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = GID15(split='train', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, self.val_split_pct ) - if stage in ["test"]: + if stage in ['test']: # Test set masks are not public, use for prediction instead - self.predict_dataset = GID15(split="test", **self.kwargs) + self.predict_dataset = GID15(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index c8489441f60..39e8ede22c5 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -49,17 +49,17 @@ def __init__( K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) self.predict_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image"], + data_keys=['image'], ) def setup(self, stage: str) -> None: @@ -68,10 +68,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit"]: - self.train_dataset = InriaAerialImageLabeling(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = InriaAerialImageLabeling(split="val", **self.kwargs) - if stage in ["predict"]: + if stage in ['fit']: + self.train_dataset = InriaAerialImageLabeling(split='train', **self.kwargs) + if stage in ['fit', 'validate']: + self.val_dataset = InriaAerialImageLabeling(split='val', **self.kwargs) + if stage in ['predict']: # Test set masks are not public, use for prediction instead - self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs) + self.predict_dataset = InriaAerialImageLabeling(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index 82a6f10d956..35408feddbb 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -54,9 +54,9 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) @@ -72,15 +72,15 @@ def setup(self, stage: str) -> None: random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator) ) - if stage in ["fit"]: + if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size ) diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index 7201e69d135..ddc802a5ce3 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -54,9 +54,9 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) @@ -72,15 +72,15 @@ def setup(self, stage: str) -> None: random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator) ) - if stage in ["fit"]: + if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size ) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index c6994f99d54..d775cf21a04 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -38,8 +38,8 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + 
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] ) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 2107928d456..06abd837e60 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -46,15 +46,15 @@ def __init__( self.train_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image1", "image2", "mask"], + data_keys=['image1', 'image2', 'mask'], ) self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - data_keys=["image1", "image2", "mask"], + data_keys=['image1', 'image2', 'mask'], ) self.test_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - data_keys=["image1", "image2", "mask"], + data_keys=['image1', 'image2', 'mask'], ) @@ -94,15 +94,15 @@ def __init__( self.train_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image1", "image2", "mask"], + data_keys=['image1', 'image2', 'mask'], ) self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - data_keys=["image1", "image2", "mask"], + data_keys=['image1', 'image2', 'mask'], ) self.test_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - data_keys=["image1", "image2", "mask"], + data_keys=['image1', 'image2', 'mask'], ) def setup(self, stage: str) -> None: @@ -111,10 +111,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit", "validate"]: - self.dataset = LEVIRCDPlus(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = LEVIRCDPlus(split='train', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct ) - if stage in ["test"]: - self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = LEVIRCDPlus(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index f8462ff6588..41b885d9aae 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -36,10 +36,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit"]: - self.train_dataset = LoveDA(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = LoveDA(split="val", **self.kwargs) - if stage in ["predict"]: + if stage in ['fit']: + self.train_dataset = LoveDA(split='train', **self.kwargs) + if stage in ['fit', 'validate']: + self.val_dataset = LoveDA(split='val', **self.kwargs) + if stage in ['predict']: # Test set masks are not public, use for prediction instead - self.predict_dataset = LoveDA(split="test", **self.kwargs) + self.predict_dataset = LoveDA(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 5fb24a43aec..b414cc0991e 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -43,9 +43,9 @@ def __init__( self.naip_kwargs = {} self.chesapeake_kwargs = {} for key, val in kwargs.items(): - if key.startswith("naip_"): + if key.startswith('naip_'): self.naip_kwargs[key[5:]] = val - elif key.startswith("chesapeake_"): + elif key.startswith('chesapeake_'): self.chesapeake_kwargs[key[11:]] = val super().__init__( @@ -58,7 +58,7 @@ def __init__( ) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] ) def setup(self, stage: str) -> None: @@ -75,19 +75,19 @@ def setup(self, stage: str) -> None: midx = roi.minx + (roi.maxx - roi.minx) / 2 midy = roi.miny + (roi.maxy - roi.miny) / 2 - if stage in ["fit"]: + if stage in ['fit']: train_roi = BoundingBox( roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt ) self.train_batch_sampler = RandomBatchGeoSampler( self.dataset, self.patch_size, self.batch_size, self.length, train_roi ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) self.val_sampler = GridGeoSampler( self.dataset, self.patch_size, self.patch_size, val_roi ) - if stage in ["test"]: + if stage in ['test']: test_roi = BoundingBox( roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt ) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 76848bc4e4b..ac566c5b8b0 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -47,7 +47,7 @@ def __init__( self.aug = AugPipe( AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "boxes"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'boxes'] ), batch_size, ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 61c5ca9aa87..2cf696cc845 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -16,35 +16,35 @@ from .utils import dataset_split MEAN = { - "B01": 1583.0741, - "B02": 1374.3202, - "B03": 1294.1616, - "B04": 1325.6158, - "B05": 1478.7408, - "B06": 1933.0822, - "B07": 2166.0608, - "B08": 2076.4868, - "B8A": 2306.0652, - "B09": 690.9814, - "B10": 16.2360, - "B11": 2080.3347, - "B12": 1524.6930, + 'B01': 1583.0741, + 'B02': 1374.3202, + 'B03': 1294.1616, + 'B04': 1325.6158, + 'B05': 1478.7408, + 'B06': 1933.0822, + 'B07': 2166.0608, + 'B08': 2076.4868, + 'B8A': 2306.0652, + 'B09': 690.9814, + 'B10': 16.2360, + 'B11': 2080.3347, + 'B12': 1524.6930, } STD = { - "B01": 52.1937, - "B02": 83.4168, - "B03": 105.6966, - "B04": 151.1401, - "B05": 147.4615, - "B06": 115.9289, - "B07": 123.1974, - "B08": 114.6483, - "B8A": 141.4530, - "B09": 
73.2758, - "B10": 4.8368, - "B11": 213.4821, - "B12": 179.4793, + 'B01': 52.1937, + 'B02': 83.4168, + 'B03': 105.6966, + 'B04': 151.1401, + 'B05': 147.4615, + 'B06': 115.9289, + 'B07': 123.1974, + 'B08': 114.6483, + 'B8A': 141.4530, + 'B09': 73.2758, + 'B10': 4.8368, + 'B11': 213.4821, + 'B12': 179.4793, } @@ -81,14 +81,14 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.bands = kwargs.get("bands", OSCD.all_bands) + self.bands = kwargs.get('bands', OSCD.all_bands) self.mean = torch.tensor([MEAN[b] for b in self.bands]) self.std = torch.tensor([STD[b] for b in self.bands]) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image1", "image2", "mask"], + data_keys=['image1', 'image2', 'mask'], ) def setup(self, stage: str) -> None: @@ -97,10 +97,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit", "validate"]: - self.dataset = OSCD(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = OSCD(split='train', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct ) - if stage in ["test"]: - self.test_dataset = OSCD(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = OSCD(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 397ef25d7b5..8b19c0b3235 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -50,7 +50,7 @@ def __init__( self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) def setup(self, stage: str) -> None: @@ -59,10 +59,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit", "validate"]: - self.dataset = Potsdam2D(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = Potsdam2D(split='train', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, self.val_split_pct ) - if stage in ["test"]: - self.test_dataset = Potsdam2D(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = Potsdam2D(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 77d926c4c32..e88e139f481 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -44,5 +44,5 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["image"], + data_keys=['image'], ) diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index 529999fbfff..f1ed2346164 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -33,8 +33,8 @@ def __init__( """ super().__init__(SeasonalContrastS2, batch_size, num_workers, **kwargs) - bands = kwargs.get("bands", SeasonalContrastS2.rgb_bands) - seasons = kwargs.get("seasons", 1) + bands = kwargs.get('bands', SeasonalContrastS2.rgb_bands) + seasons = kwargs.get('seasons', 1) # Normalization only available for RGB dataset, defined here: # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 @@ -44,16 +44,16 @@ def __init__( _mean = torch.tensor([0.485, 0.456, 0.406]) _std = torch.tensor([0.229, 0.224, 0.225]) - _min = repeat(_min, "c -> (t c)", t=seasons) - _max = repeat(_max, "c -> (t c)", t=seasons) - _mean = repeat(_mean, "c -> (t c)", t=seasons) - _std = repeat(_std, "c -> (t c)", t=seasons) + _min = repeat(_min, 'c -> (t c)', t=seasons) + _max = repeat(_max, 'c -> (t c)', t=seasons) + _mean = repeat(_mean, 'c -> (t c)', t=seasons) + _std = repeat(_std, 'c -> (t c)', t=seasons) self.aug = AugmentationSequential( K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), K.Normalize(mean=_mean, std=_std), - data_keys=["image"], + data_keys=['image'], ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index a7c1710aaea..2ca50fb10ae 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -38,7 +38,7 @@ def __init__( self, batch_size: int = 64, num_workers: int = 0, - band_set: str = "all", + band_set: str = 'all', **kwargs: Any, ) -> None: """Initialize a new SEN12MSDataModule instance. @@ -52,13 +52,13 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.SEN12MS`. """ - kwargs["bands"] = SEN12MS.BAND_SETS[band_set] + kwargs['bands'] = SEN12MS.BAND_SETS[band_set] - if band_set == "s1": + if band_set == 's1': self.std = self.std[:2] - elif band_set == "s2-all": + elif band_set == 's2-all': self.std = self.std[2:] - elif band_set == "s2-reduced": + elif band_set == 's2-reduced': self.std = self.std[torch.tensor([3, 4, 5, 9, 12, 13])] super().__init__(SEN12MS, batch_size, num_workers, **kwargs) @@ -69,10 +69,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit", "validate"]: - season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} + if stage in ['fit', 'validate']: + season_to_int = {'winter': 0, 'spring': 1000, 'summer': 2000, 'fall': 3000} - self.dataset = SEN12MS(split="train", **self.kwargs) + self.dataset = SEN12MS(split='train', **self.kwargs) # A patch is a filename like: # "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" @@ -82,7 +82,7 @@ def setup(self, stage: str) -> None: # as (season_id + scene_id). scenes = [] for scene_fn in self.dataset.ids: - parts = scene_fn.split("_") + parts = scene_fn.split('_') season_id = season_to_int[parts[1]] scene_id = int(parts[3]) scenes.append(season_id + scene_id) @@ -93,8 +93,8 @@ def setup(self, stage: str) -> None: self.train_dataset = Subset(self.dataset, train_indices) self.val_dataset = Subset(self.dataset, val_indices) - if stage in ["test"]: - self.test_dataset = SEN12MS(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = SEN12MS(split='test', **self.kwargs) def on_after_batch_transfer( self, batch: dict[str, Tensor], dataloader_idx: int @@ -108,6 +108,6 @@ def on_after_batch_transfer( Returns: A batch of data. """ - batch["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, batch["mask"]) + batch['mask'] = torch.take(self.DFC2020_CLASS_MAPPING, batch['mask']) return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index b219d7315e2..97c3d05392e 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -44,8 +44,8 @@ def __init__( (prefix keys with ``sentinel2_``). """ # Define prefix for Cropland Data Layer (CDL) and Sentinel-2 arguments - cdl_signature = "cdl_" - sentinel2_signature = "sentinel2_" + cdl_signature = 'cdl_' + sentinel2_signature = 'sentinel2_' self.cdl_kwargs = {} self.sentinel2_kwargs = {} @@ -68,14 +68,14 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] ) def setup(self, stage: str) -> None: @@ -95,15 +95,15 @@ def setup(self, stage: str) -> None: self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator ) ) - if stage in ["fit"]: + if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size ) diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py index 076131f8a89..4e0893e4f8d 100644 --- a/torchgeo/datamodules/sentinel2_eurocrops.py +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -45,8 +45,8 @@ def __init__( and :class:`~torchgeo.datasets.Sentinel2` (prefix keys with ``sentinel2_``). 
""" - eurocrops_signature = "eurocrops_" - sentinel2_signature = "sentinel2_" + eurocrops_signature = 'eurocrops_' + sentinel2_signature = 'sentinel2_' self.eurocrops_kwargs = {} self.sentinel2_kwargs = {} for key, val in kwargs.items(): @@ -69,14 +69,14 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] ) def setup(self, stage: str) -> None: @@ -95,15 +95,15 @@ def setup(self, stage: str) -> None: self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator ) ) - if stage in ["fit"]: + if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size ) diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index 952bc2812c0..91b4f936fdc 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -44,8 +44,8 @@ def __init__( (prefix keys with ``sentinel2_``). """ # Define prefix for NCCM and Sentinel-2 arguments - nccm_signature = "nccm_" - sentinel2_signature = "sentinel2_" + nccm_signature = 'nccm_' + sentinel2_signature = 'sentinel2_' self.nccm_kwargs = {} self.sentinel2_kwargs = {} @@ -68,14 +68,14 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] ) def setup(self, stage: str) -> None: @@ -95,15 +95,15 @@ def setup(self, stage: str) -> None: self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator ) ) - if stage in ["fit"]: + if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size ) diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index adf4ba194c4..e3363e857f5 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -48,9 +48,9 @@ def __init__( self.south_america_soybean_kwargs = {} self.sentinel2_kwargs = {} for key, val in kwargs.items(): - if 
key.startswith("south_america_soybean_"): + if key.startswith('south_america_soybean_'): self.south_america_soybean_kwargs[key[22:]] = val - elif key.startswith("sentinel2_"): + elif key.startswith('sentinel2_'): self.sentinel2_kwargs[key[10:]] = val super().__init__( @@ -67,14 +67,14 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] ) def setup(self, stage: str) -> None: @@ -96,15 +96,15 @@ def setup(self, stage: str) -> None: ) ) - if stage in ["fit"]: + if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) - if stage in ["fit", "validate"]: + if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) - if stage in ["test"]: + if stage in ['test']: self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size ) diff --git a/torchgeo/datamodules/skippd.py b/torchgeo/datamodules/skippd.py index b76eb3e1e92..f058cfbffdd 100644 --- a/torchgeo/datamodules/skippd.py +++ b/torchgeo/datamodules/skippd.py @@ -45,10 +45,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit", "validate"]: - self.dataset = SKIPPD(split="trainval", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = SKIPPD(split='trainval', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct ) - if stage in ["test"]: - self.test_dataset = SKIPPD(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = SKIPPD(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index ec3e85097e9..64701cef519 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -22,7 +22,7 @@ class So2SatDataModule(NonGeoDataModule): """ means_per_version: dict[str, Tensor] = { - "2": torch.tensor( + '2': torch.tensor( [ -0.00003591224260, -0.00000765856128, @@ -44,7 +44,7 @@ class So2SatDataModule(NonGeoDataModule): 0.10905050699570, ] ), - "3_random": torch.tensor( + '3_random': torch.tensor( [ -0.00005541164581, -0.00001363245448, @@ -66,7 +66,7 @@ class So2SatDataModule(NonGeoDataModule): 0.11122536338577, ] ), - "3_block": torch.tensor( + '3_block': torch.tensor( [ -0.00004632368791, 0.00001260869365, @@ -89,10 +89,10 @@ class So2SatDataModule(NonGeoDataModule): ] ), } - means_per_version["3_culture_10"] = means_per_version["2"] + means_per_version['3_culture_10'] = means_per_version['2'] stds_per_version: dict[str, Tensor] = { - "2": torch.tensor( + '2': torch.tensor( [ 0.17555201, 0.17556463, @@ -114,7 +114,7 @@ class So2SatDataModule(NonGeoDataModule): 0.08780632, ] ), - "3_random": torch.tensor( + '3_random': torch.tensor( [ 0.1756914, 0.1761190, @@ -136,7 +136,7 @@ class So2SatDataModule(NonGeoDataModule): 0.0873386, ] ), - "3_block": torch.tensor( + '3_block': torch.tensor( [ 0.1751797, 0.1754073, @@ -159,13 +159,13 @@ class 
So2SatDataModule(NonGeoDataModule): ] ), } - stds_per_version["3_culture_10"] = stds_per_version["2"] + stds_per_version['3_culture_10'] = stds_per_version['2'] def __init__( self, batch_size: int = 64, num_workers: int = 0, - band_set: str = "all", + band_set: str = 'all', val_split_pct: float = 0.2, **kwargs: Any, ) -> None: @@ -184,18 +184,18 @@ def __init__( The *val_split_pct* parameter, and the 'rgb' argument to *band_set*. """ # https://github.com/Lightning-AI/lightning/issues/18616 - kwargs["version"] = str(kwargs.get("version", "2")) - version = kwargs["version"] - kwargs["bands"] = So2Sat.BAND_SETS[band_set] + kwargs['version'] = str(kwargs.get('version', '2')) + version = kwargs['version'] + kwargs['bands'] = So2Sat.BAND_SETS[band_set] self.val_split_pct = val_split_pct - if band_set == "s1": + if band_set == 's1': self.mean = self.means_per_version[version][:8] self.std = self.stds_per_version[version][:8] - elif band_set == "s2": + elif band_set == 's2': self.mean = self.means_per_version[version][8:] self.std = self.stds_per_version[version][8:] - elif band_set == "rgb": + elif band_set == 'rgb': self.mean = self.means_per_version[version][[10, 9, 8]] self.std = self.stds_per_version[version][[10, 9, 8]] @@ -211,16 +211,16 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if self.kwargs.get("version", "2") == "2": - if stage in ["fit"]: - self.train_dataset = So2Sat(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = So2Sat(split="validation", **self.kwargs) - if stage in ["test"]: - self.test_dataset = So2Sat(split="test", **self.kwargs) + if self.kwargs.get('version', '2') == '2': + if stage in ['fit']: + self.train_dataset = So2Sat(split='train', **self.kwargs) + if stage in ['fit', 'validate']: + self.val_dataset = So2Sat(split='validation', **self.kwargs) + if stage in ['test']: + self.test_dataset = So2Sat(split='test', **self.kwargs) else: - if stage in ["fit", "validate"]: - dataset = So2Sat(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + dataset = So2Sat(split='train', **self.kwargs) val_length = round(len(dataset) * self.val_split_pct) train_length = len(dataset) - val_length self.train_dataset, self.val_dataset = random_split( @@ -228,5 +228,5 @@ def setup(self, stage: str) -> None: [train_length, val_length], generator=Generator().manual_seed(0), ) - if stage in ["test"]: - self.test_dataset = So2Sat(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = So2Sat(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 3a3b7531cba..92039695533 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -53,12 +53,12 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) def setup(self, stage: str) -> None: @@ -87,6 +87,6 @@ def on_after_batch_transfer( # We add 1 to the mask to map the current {background, building} labels to # the values {1, 2}. This is necessary because we add 0 padding to the # mask that we want to ignore in the loss function. 
- batch["mask"] += 1 + batch['mask'] += 1 return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index d1b9d51844e..c9eb1d2e315 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -45,18 +45,18 @@ def __init__( K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) self.test_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) diff --git a/torchgeo/datamodules/sustainbench_crop_yield.py b/torchgeo/datamodules/sustainbench_crop_yield.py index b3283bd36e7..7009dfb378f 100644 --- a/torchgeo/datamodules/sustainbench_crop_yield.py +++ b/torchgeo/datamodules/sustainbench_crop_yield.py @@ -34,9 +34,9 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit"]: - self.train_dataset = SustainBenchCropYield(split="train", **self.kwargs) - if stage in ["fit", "validate"]: - self.val_dataset = SustainBenchCropYield(split="dev", **self.kwargs) - if stage in ["test"]: - self.test_dataset = SustainBenchCropYield(split="test", **self.kwargs) + if stage in ['fit']: + self.train_dataset = SustainBenchCropYield(split='train', **self.kwargs) + if stage in ['fit', 'validate']: + self.val_dataset = SustainBenchCropYield(split='dev', **self.kwargs) + if stage in ['test']: + self.test_dataset = SustainBenchCropYield(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index dc8f43b5828..59bb49444ee 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -34,5 +34,5 @@ def __init__( self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size=256), - data_keys=["image"], + data_keys=['image'], ) diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index a6069250d71..15cb796151d 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -50,28 +50,28 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: Returns: Augmented batch. 
""" - batch_len = len(batch["image"]) + batch_len = len(batch['image']) for bs in range(batch_len): batch_dict = { - "image": batch["image"][bs], - "labels": batch["labels"][bs], - "boxes": batch["boxes"][bs], + 'image': batch['image'][bs], + 'labels': batch['labels'][bs], + 'boxes': batch['boxes'][bs], } - if "masks" in batch: - batch_dict["masks"] = batch["masks"][bs] + if 'masks' in batch: + batch_dict['masks'] = batch['masks'][bs] batch_dict = self.augs(batch_dict) - batch["image"][bs] = batch_dict["image"] - batch["labels"][bs] = batch_dict["labels"] - batch["boxes"][bs] = batch_dict["boxes"] + batch['image'][bs] = batch_dict['image'] + batch['labels'][bs] = batch_dict['labels'] + batch['boxes'][bs] = batch_dict['boxes'] - if "masks" in batch: - batch["masks"][bs] = batch_dict["masks"] + if 'masks' in batch: + batch['masks'][bs] = batch_dict['masks'] # Stack images - batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") + batch['image'] = rearrange(batch['image'], 'b () c h w -> b c h w') return batch @@ -88,17 +88,17 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: .. versionadded:: 0.6 """ output: dict[str, Any] = {} - output["image"] = [sample["image"] for sample in batch] - output["boxes"] = [sample["boxes"].float() for sample in batch] - if "labels" in batch[0]: - output["labels"] = [sample["labels"] for sample in batch] + output['image'] = [sample['image'] for sample in batch] + output['boxes'] = [sample['boxes'].float() for sample in batch] + if 'labels' in batch[0]: + output['labels'] = [sample['labels'] for sample in batch] else: - output["labels"] = [ - torch.tensor([1] * len(sample["boxes"])) for sample in batch + output['labels'] = [ + torch.tensor([1] * len(sample['boxes'])) for sample in batch ] - if "masks" in batch[0]: - output["masks"] = [sample["masks"] for sample in batch] + if 'masks' in batch[0]: + output['masks'] = [sample['masks'] for sample in batch] return output @@ -169,11 +169,11 @@ def group_shuffle_split( ValueError if the number of training or testing groups turns out to be 0. """ if train_size is None and test_size is None: - raise ValueError("You must specify `train_size`, `test_size`, or both.") + raise ValueError('You must specify `train_size`, `test_size`, or both.') if (train_size is not None and test_size is not None) and ( not math.isclose(train_size + test_size, 1) ): - raise ValueError("`train_size` and `test_size` must sum to 1.") + raise ValueError('`train_size` and `test_size` must sum to 1.') if train_size is None and test_size is not None: train_size = 1 - test_size @@ -183,7 +183,7 @@ def group_shuffle_split( assert train_size is not None and test_size is not None if train_size <= 0 or train_size >= 1 or test_size <= 0 or test_size >= 1: - raise ValueError("`train_size` and `test_size` must be in the range (0,1).") + raise ValueError('`train_size` and `test_size` must be in the range (0,1).') group_vals = sorted(set(groups)) n_groups = len(group_vals) @@ -192,8 +192,8 @@ def group_shuffle_split( if n_train_groups == 0 or n_test_groups == 0: raise ValueError( - f"{n_groups} groups were found, however the current settings of " - + "`train_size` and `test_size` result in 0 training or testing groups." + f'{n_groups} groups were found, however the current settings of ' + + '`train_size` and `test_size` result in 0 training or testing groups.' 
) generator = np.random.default_rng(seed=random_state) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 441fafdd9d2..586e3d1cea1 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -50,7 +50,7 @@ def __init__( self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=["image", "mask"], + data_keys=['image', 'mask'], ) def setup(self, stage: str) -> None: @@ -59,10 +59,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if stage in ["fit", "validate"]: - self.dataset = Vaihingen2D(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = Vaihingen2D(split='train', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, self.val_split_pct ) - if stage in ["test"]: - self.test_dataset = Vaihingen2D(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = Vaihingen2D(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 0bafef71b27..b86b36b7d16 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -58,7 +58,7 @@ def __init__( K.RandomHorizontalFlip(), K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.7), K.RandomVerticalFlip(), - data_keys=["image", "boxes", "masks"], + data_keys=['image', 'boxes', 'masks'], ), batch_size, ) @@ -66,7 +66,7 @@ def __init__( AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(self.patch_size), - data_keys=["image", "boxes", "masks"], + data_keys=['image', 'boxes', 'masks'], ), batch_size, ) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 8f96d786bea..2faceebe915 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -44,10 +44,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - if stage in ["fit", "validate"]: - self.dataset = XView2(split="train", **self.kwargs) + if stage in ['fit', 'validate']: + self.dataset = XView2(split='train', **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct ) - if stage in ["test"]: - self.test_dataset = XView2(split="test", **self.kwargs) + if stage in ['test']: + self.test_dataset = XView2(split='test', **self.kwargs) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 5f3f974e2b2..9837bf88e63 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -140,142 +140,142 @@ __all__ = ( # GeoDataset - "AbovegroundLiveWoodyBiomassDensity", - "AgriFieldNet", - "Airphen", - "AsterGDEM", - "CanadianBuildingFootprints", - "CDL", - "Chesapeake", - "Chesapeake7", - "Chesapeake13", - "ChesapeakeDC", - "ChesapeakeDE", - "ChesapeakeMD", - "ChesapeakeNY", - "ChesapeakePA", - "ChesapeakeVA", - "ChesapeakeWV", - "ChesapeakeCVPR", - "CMSGlobalMangroveCanopy", - "CropHarvest", - "EDDMapS", - "Esri2020", - "EuroCrops", - "EUDEM", - "GBIF", - "GlobBiomass", - "INaturalist", - "L7Irish", - "L8Biome", - "LandCoverAIBase", - "LandCoverAIGeo", - "Landsat", - "Landsat1", - "Landsat2", - "Landsat3", - "Landsat4MSS", - "Landsat4TM", - "Landsat5MSS", - "Landsat5TM", - "Landsat7", - "Landsat8", - "Landsat9", - "NAIP", - "NCCM", - "NLCD", - "OpenBuildings", - "PRISMA", - "Sentinel", - "Sentinel1", - "Sentinel2", - "SouthAfricaCropType", - "SouthAmericaSoybean", + 'AbovegroundLiveWoodyBiomassDensity', + 'AgriFieldNet', + 'Airphen', + 'AsterGDEM', + 'CanadianBuildingFootprints', + 'CDL', + 'Chesapeake', + 'Chesapeake7', + 'Chesapeake13', + 'ChesapeakeDC', + 'ChesapeakeDE', + 'ChesapeakeMD', + 'ChesapeakeNY', + 'ChesapeakePA', + 'ChesapeakeVA', + 'ChesapeakeWV', + 'ChesapeakeCVPR', + 'CMSGlobalMangroveCanopy', + 'CropHarvest', + 'EDDMapS', + 'Esri2020', + 'EuroCrops', + 'EUDEM', + 'GBIF', + 'GlobBiomass', + 'INaturalist', + 'L7Irish', + 'L8Biome', + 'LandCoverAIBase', + 'LandCoverAIGeo', + 'Landsat', + 'Landsat1', + 'Landsat2', + 'Landsat3', + 'Landsat4MSS', + 'Landsat4TM', + 'Landsat5MSS', + 'Landsat5TM', + 'Landsat7', + 'Landsat8', + 'Landsat9', + 'NAIP', + 'NCCM', + 'NLCD', + 'OpenBuildings', + 'PRISMA', + 'Sentinel', + 'Sentinel1', + 'Sentinel2', + 'SouthAfricaCropType', + 'SouthAmericaSoybean', # NonGeoDataset - "ADVANCE", - "BeninSmallHolderCashews", - "BigEarthNet", - "BioMassters", - "ChaBuD", - "CloudCoverDetection", - "COWC", - "COWCCounting", - "COWCDetection", - "CV4AKenyaCropType", - "DeepGlobeLandCover", - "DFC2022", - "EnviroAtlas", - "ETCI2021", - "EuroSAT", - "EuroSAT100", - "FAIR1M", - "FireRisk", - "ForestDamage", - "GID15", - "IDTReeS", - "InriaAerialImageLabeling", - "LandCoverAI", - "LEVIRCD", - "LEVIRCDBase", - "LEVIRCDPlus", - "LoveDA", - "MapInWild", - "MillionAID", - "NASAMarineDebris", - "OSCD", - "PASTIS", - "PatternNet", - "Potsdam2D", - "RESISC45", - "ReforesTree", - "RwandaFieldBoundary", - "SeasonalContrastS2", - "SeasoNet", - "SEN12MS", - "SKIPPD", - "So2Sat", - "SpaceNet", - "SpaceNet1", - "SpaceNet2", - "SpaceNet3", - "SpaceNet4", - "SpaceNet5", - "SpaceNet6", - "SpaceNet7", - "SSL4EO", - "SSL4EOLBenchmark", - "SSL4EOL", - "SSL4EOS12", - "SustainBenchCropYield", - "TropicalCyclone", - "UCMerced", - "USAVars", - "Vaihingen2D", - "VHR10", - "WesternUSALiveFuelMoisture", - "XView2", - "ZueriCrop", + 'ADVANCE', + 'BeninSmallHolderCashews', + 'BigEarthNet', + 'BioMassters', + 'ChaBuD', + 'CloudCoverDetection', + 'COWC', 
+ 'COWCCounting', + 'COWCDetection', + 'CV4AKenyaCropType', + 'DeepGlobeLandCover', + 'DFC2022', + 'EnviroAtlas', + 'ETCI2021', + 'EuroSAT', + 'EuroSAT100', + 'FAIR1M', + 'FireRisk', + 'ForestDamage', + 'GID15', + 'IDTReeS', + 'InriaAerialImageLabeling', + 'LandCoverAI', + 'LEVIRCD', + 'LEVIRCDBase', + 'LEVIRCDPlus', + 'LoveDA', + 'MapInWild', + 'MillionAID', + 'NASAMarineDebris', + 'OSCD', + 'PASTIS', + 'PatternNet', + 'Potsdam2D', + 'RESISC45', + 'ReforesTree', + 'RwandaFieldBoundary', + 'SeasonalContrastS2', + 'SeasoNet', + 'SEN12MS', + 'SKIPPD', + 'So2Sat', + 'SpaceNet', + 'SpaceNet1', + 'SpaceNet2', + 'SpaceNet3', + 'SpaceNet4', + 'SpaceNet5', + 'SpaceNet6', + 'SpaceNet7', + 'SSL4EO', + 'SSL4EOLBenchmark', + 'SSL4EOL', + 'SSL4EOS12', + 'SustainBenchCropYield', + 'TropicalCyclone', + 'UCMerced', + 'USAVars', + 'Vaihingen2D', + 'VHR10', + 'WesternUSALiveFuelMoisture', + 'XView2', + 'ZueriCrop', # Base classes - "GeoDataset", - "IntersectionDataset", - "NonGeoClassificationDataset", - "NonGeoDataset", - "RasterDataset", - "UnionDataset", - "VectorDataset", + 'GeoDataset', + 'IntersectionDataset', + 'NonGeoClassificationDataset', + 'NonGeoDataset', + 'RasterDataset', + 'UnionDataset', + 'VectorDataset', # Utilities - "BoundingBox", - "concat_samples", - "merge_samples", - "stack_samples", - "unbind_samples", + 'BoundingBox', + 'concat_samples', + 'merge_samples', + 'stack_samples', + 'unbind_samples', # Splits - "random_bbox_assignment", - "random_bbox_splitting", - "random_grid_cell_assignment", - "roi_split", - "time_series_split", + 'random_bbox_assignment', + 'random_bbox_splitting', + 'random_grid_cell_assignment', + 'roi_split', + 'time_series_split', # Errors - "DatasetNotFoundError", - "RGBBandsMissingError", + 'DatasetNotFoundError', + 'RGBBandsMissingError', ) diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index 139bba0f404..a2fa1e9e4ea 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -63,31 +63,31 @@ class ADVANCE(NonGeoDataset): """ urls = [ - "https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1", - "https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1", + 'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1', + 'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1', ] - filenames = ["ADVANCE_vision.zip", "ADVANCE_sound.zip"] - md5s = ["a9e8748219ef5864d3b5a8979a67b471", "a2d12f2d2a64f5c3d3a9d8c09aaf1c31"] - directories = ["vision", "sound"] + filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip'] + md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31'] + directories = ['vision', 'sound'] classes = [ - "airport", - "beach", - "bridge", - "farmland", - "forest", - "grassland", - "harbour", - "lake", - "orchard", - "residential", - "sparse shrub land", - "sports land", - "train station", + 'airport', + 'beach', + 'bridge', + 'farmland', + 'forest', + 'grassland', + 'harbour', + 'lake', + 'orchard', + 'residential', + 'sparse shrub land', + 'sports land', + 'train station', ] def __init__( self, - root: str = "data", + root: str = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -115,7 +115,7 @@ def __init__( raise DatasetNotFoundError(self) self.files = self._load_files(self.root) - self.classes = sorted({f["cls"] for f in self.files}) + self.classes = sorted({f['cls'] for f in self.files}) self.class_to_idx: dict[str, int] = {c: i for i, c in 
enumerate(self.classes)} def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -128,11 +128,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image = self._load_image(files["image"]) - audio = self._load_target(files["audio"]) - cls_label = self.class_to_idx[files["cls"]] + image = self._load_image(files['image']) + audio = self._load_target(files['audio']) + cls_label = self.class_to_idx[files['cls']] label = torch.tensor(cls_label, dtype=torch.long) - sample = {"image": image, "audio": audio, "label": label} + sample = {'image': image, 'audio': audio, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -156,8 +156,8 @@ def _load_files(self, root: str) -> list[dict[str, str]]: Returns: list of dicts containing paths for each pair of image, audio, label """ - images = sorted(glob.glob(os.path.join(root, "vision", "**", "*.jpg"))) - wavs = sorted(glob.glob(os.path.join(root, "sound", "**", "*.wav"))) + images = sorted(glob.glob(os.path.join(root, 'vision', '**', '*.jpg'))) + wavs = sorted(glob.glob(os.path.join(root, 'sound', '**', '*.wav'))) labels = [image.split(os.sep)[-2] for image in images] files = [ dict(image=image, audio=wav, cls=label) @@ -175,7 +175,7 @@ def _load_image(self, path: str) -> Tensor: the image """ with Image.open(path) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -194,7 +194,7 @@ def _load_target(self, path: str) -> Tensor: from scipy.io import wavfile except ImportError: raise ImportError( - "scipy is not installed and is required to use this dataset" + 'scipy is not installed and is required to use this dataset' ) array = wavfile.read(path, mmap=True)[1] @@ -217,7 +217,7 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return for filename, url, md5 in zip(self.filenames, self.urls, self.md5s): @@ -243,22 +243,22 @@ def plot( .. 
versionadded:: 0.2 """ - image = np.rollaxis(sample["image"].numpy(), 0, 3) - label = cast(int, sample["label"].item()) + image = np.rollaxis(sample['image'].numpy(), 0, 3) + label = cast(int, sample['label'].item()) label_class = self.classes[label] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(int, sample["prediction"].item()) + prediction = cast(int, sample['prediction'].item()) prediction_class = self.classes[prediction] fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label_class}" + title = f'Label: {label_class}' if showing_predictions: - title += f"\nPrediction: {prediction_class}" + title += f'\nPrediction: {prediction_class}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index b12da8f50cf..4e094c516fc 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -44,11 +44,11 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): is_image = False - url = "https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326" # noqa: E501 + url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' # noqa: E501 - base_filename = "Aboveground_Live_Woody_Biomass_Density.geojson" + base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson' - filename_glob = "*N_*E.*" + filename_glob = '*N_*E.*' filename_regex = r"""^ (?P[0-9][0-9][A-Z])_ (?P[0-9][0-9][0-9][A-Z])* @@ -56,7 +56,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -110,11 +110,11 @@ def _download(self) -> None: with open(os.path.join(self.paths, self.base_filename)) as f: content = json.load(f) - for item in content["features"]: + for item in content['features']: download_url( - item["properties"]["Mg_px_1_download"], + item['properties']['Mg_px_1_download'], self.paths, - item["properties"]["tile_id"] + ".tif", + item['properties']['tile_id'] + '.tif', ) def plot( @@ -133,29 +133,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) if showing_predictions: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(pred) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow(mask) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index e47b93c5db4..b8e9437be4c 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ 
b/torchgeo/datasets/agrifieldnet.py @@ -78,20 +78,20 @@ class AgriFieldNet(RasterDataset): _(?PB[0-9A-Z]{2})_10m """ - rgb_bands = ["B04", "B03", "B02"] + rgb_bands = ['B04', 'B03', 'B02'] all_bands = [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', ] cmap = { @@ -113,7 +113,7 @@ class AgriFieldNet(RasterDataset): def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: Sequence[str] = all_bands, @@ -138,8 +138,8 @@ def __init__( """ assert ( set(classes) <= self.cmap.keys() - ), f"Only the following classes are valid: {list(self.cmap.keys())}." - assert 0 in classes, "Classes must include the background class: 0" + ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths self.classes = classes @@ -171,7 +171,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) data_list: list[Tensor] = [] @@ -183,9 +183,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: directory = os.path.dirname(filepath) match = re.match(filename_regex, filename) if match: - if "band" in match.groupdict(): - start = match.start("band") - end = match.end("band") + if 'band' in match.groupdict(): + start = match.start('band') + end = match.end('band') filename = filename[:start] + band + filename[end:] filepath = os.path.join(directory, filename) band_filepaths.append(filepath) @@ -193,9 +193,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: image = torch.cat(data_list) mask_filepaths = [] - for root, dirs, files in os.walk(os.path.join(self.paths, "train_labels")): + for root, dirs, files in os.walk(os.path.join(self.paths, 'train_labels')): for file in files: - if not file.endswith("_field_ids.tif") and file.endswith(".tif"): + if not file.endswith('_field_ids.tif') and file.endswith('.tif'): file_path = os.path.join(root, file) mask_filepaths.append(file_path) @@ -203,10 +203,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask = self.ordinal_map[mask.squeeze().long()] sample = { - "crs": self.crs, - "bbox": query, - "image": image.float(), - "mask": mask.long(), + 'crs': self.crs, + 'bbox': query, + 'image': image.float(), + 'mask': mask.long(), } if self.transforms is not None: @@ -240,31 +240,31 @@ def plot( else: raise RGBBandsMissingError() - image = sample["image"][rgb_indices].permute(1, 2, 0) + image = sample['image'][rgb_indices].permute(1, 2, 0) image = (image - image.min()) / (image.max() - image.min()) - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 2 - showing_prediction = "prediction" in sample + showing_prediction = 'prediction' in sample if showing_prediction: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) axs[0].imshow(image) - axs[0].axis("off") - axs[1].imshow(self.ordinal_cmap[mask], interpolation="none") - axs[1].axis("off") + axs[0].axis('off') + axs[1].imshow(self.ordinal_cmap[mask], interpolation='none') + axs[1].axis('off') if show_titles: - 
axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if showing_prediction: - axs[2].imshow(self.ordinal_cmap[pred], interpolation="none") - axs[2].axis("off") + axs[2].imshow(self.ordinal_cmap[pred], interpolation='none') + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/airphen.py b/torchgeo/datasets/airphen.py index a38db917a03..03e21b10e01 100644 --- a/torchgeo/datasets/airphen.py +++ b/torchgeo/datasets/airphen.py @@ -39,8 +39,8 @@ class Airphen(RasterDataset): # Each camera measures a custom set of spectral bands chosen at purchase time. # Hiphen offers 8 bands to choose from, sorted from short to long wavelength. - all_bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8"] - rgb_bands = ["B4", "B3", "B1"] + all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8'] + rgb_bands = ['B4', 'B3', 'B1'] def plot( self, @@ -68,15 +68,15 @@ def plot( else: raise RGBBandsMissingError() - image = sample["image"][rgb_indices].permute(1, 2, 0).float() + image = sample['image'][rgb_indices].permute(1, 2, 0).float() image = percentile_normalization(image, axis=(0, 1)) fig, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - ax.set_title("Image") + ax.set_title('Image') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index 801b24a52fe..e2e9d3c745e 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -38,7 +38,7 @@ class AsterGDEM(RasterDataset): """ is_image = False - filename_glob = "ASTGTMV003_*_dem*" + filename_glob = 'ASTGTMV003_*_dem*' filename_regex = r""" (?P[ASTGTMV003]{10}) _(?P[A-Z0-9]{7}) @@ -47,7 +47,7 @@ class AsterGDEM(RasterDataset): def __init__( self, - paths: str | list[str] = "data", + paths: str | list[str] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -102,29 +102,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = sample["prediction"].squeeze() + prediction = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4)) if showing_predictions: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(prediction) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow(mask) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 0aefc55d7a4..4c9a0e0fcd7 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -62,114 +62,114 @@ class BeninSmallHolderCashews(NonGeoDataset): imagery and labels from the Radiant Earth MLHub """ - dataset_id = "ts_cashew_benin" - collection_ids = ["ts_cashew_benin_source", "ts_cashew_benin_labels"] + dataset_id = 'ts_cashew_benin' + collection_ids = 
['ts_cashew_benin_source', 'ts_cashew_benin_labels'] image_meta = { - "filename": "ts_cashew_benin_source.tar.gz", - "md5": "957272c86e518a925a4e0d90dab4f92d", + 'filename': 'ts_cashew_benin_source.tar.gz', + 'md5': '957272c86e518a925a4e0d90dab4f92d', } target_meta = { - "filename": "ts_cashew_benin_labels.tar.gz", - "md5": "f9d3f0c671427d852fae9b52a0ae0051", + 'filename': 'ts_cashew_benin_labels.tar.gz', + 'md5': 'f9d3f0c671427d852fae9b52a0ae0051', } dates = ( - "2019_11_05", - "2019_11_10", - "2019_11_15", - "2019_11_20", - "2019_11_30", - "2019_12_05", - "2019_12_10", - "2019_12_15", - "2019_12_20", - "2019_12_25", - "2019_12_30", - "2020_01_04", - "2020_01_09", - "2020_01_14", - "2020_01_19", - "2020_01_24", - "2020_01_29", - "2020_02_08", - "2020_02_13", - "2020_02_18", - "2020_02_23", - "2020_02_28", - "2020_03_04", - "2020_03_09", - "2020_03_14", - "2020_03_19", - "2020_03_24", - "2020_03_29", - "2020_04_03", - "2020_04_08", - "2020_04_13", - "2020_04_18", - "2020_04_23", - "2020_04_28", - "2020_05_03", - "2020_05_08", - "2020_05_13", - "2020_05_18", - "2020_05_23", - "2020_05_28", - "2020_06_02", - "2020_06_07", - "2020_06_12", - "2020_06_17", - "2020_06_22", - "2020_06_27", - "2020_07_02", - "2020_07_07", - "2020_07_12", - "2020_07_17", - "2020_07_22", - "2020_07_27", - "2020_08_01", - "2020_08_06", - "2020_08_11", - "2020_08_16", - "2020_08_21", - "2020_08_26", - "2020_08_31", - "2020_09_05", - "2020_09_10", - "2020_09_15", - "2020_09_20", - "2020_09_25", - "2020_09_30", - "2020_10_10", - "2020_10_15", - "2020_10_20", - "2020_10_25", - "2020_10_30", + '2019_11_05', + '2019_11_10', + '2019_11_15', + '2019_11_20', + '2019_11_30', + '2019_12_05', + '2019_12_10', + '2019_12_15', + '2019_12_20', + '2019_12_25', + '2019_12_30', + '2020_01_04', + '2020_01_09', + '2020_01_14', + '2020_01_19', + '2020_01_24', + '2020_01_29', + '2020_02_08', + '2020_02_13', + '2020_02_18', + '2020_02_23', + '2020_02_28', + '2020_03_04', + '2020_03_09', + '2020_03_14', + '2020_03_19', + '2020_03_24', + '2020_03_29', + '2020_04_03', + '2020_04_08', + '2020_04_13', + '2020_04_18', + '2020_04_23', + '2020_04_28', + '2020_05_03', + '2020_05_08', + '2020_05_13', + '2020_05_18', + '2020_05_23', + '2020_05_28', + '2020_06_02', + '2020_06_07', + '2020_06_12', + '2020_06_17', + '2020_06_22', + '2020_06_27', + '2020_07_02', + '2020_07_07', + '2020_07_12', + '2020_07_17', + '2020_07_22', + '2020_07_27', + '2020_08_01', + '2020_08_06', + '2020_08_11', + '2020_08_16', + '2020_08_21', + '2020_08_26', + '2020_08_31', + '2020_09_05', + '2020_09_10', + '2020_09_15', + '2020_09_20', + '2020_09_25', + '2020_09_30', + '2020_10_10', + '2020_10_15', + '2020_10_20', + '2020_10_25', + '2020_10_30', ) all_bands = ( - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12", - "CLD", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', + 'CLD', ) - rgb_bands = ("B04", "B03", "B02") + rgb_bands = ('B04', 'B03', 'B02') classes = [ - "No data", - "Well-managed planatation", - "Poorly-managed planatation", - "Non-planatation", - "Residential", - "Background", - "Uncertain", + 'No data', + 'Well-managed planatation', + 'Poorly-managed planatation', + 'Non-planatation', + 'Residential', + 'Background', + 'Uncertain', ] # Same for all tiles @@ -178,7 +178,7 @@ class BeninSmallHolderCashews(NonGeoDataset): def __init__( self, - root: str = "data", + root: str = 'data', chip_size: int = 256, stride: int = 128, bands: tuple[str, ...] 
= all_bands, @@ -250,12 +250,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: labels = labels[y : y + self.chip_size, x : x + self.chip_size] sample = { - "image": img, - "mask": labels, - "x": torch.tensor(x), - "y": torch.tensor(y), - "transform": transform, - "crs": crs, + 'image': img, + 'mask': labels, + 'x': torch.tensor(x), + 'y': torch.tensor(y), + 'transform': transform, + 'crs': crs, } if self.transforms is not None: @@ -281,7 +281,7 @@ def _validate_bands(self, bands: tuple[str, ...]) -> None: AssertionError: if ``bands`` is not a tuple ValueError: if an invalid band name is provided """ - assert isinstance(bands, tuple), "The list of bands must be a tuple" + assert isinstance(bands, tuple), 'The list of bands must be a tuple' for band in bands: if band not in self.all_bands: raise ValueError(f"'{band}' is an invalid band name.") @@ -304,7 +304,7 @@ def _load_all_imagery( coordinate reference system of transform """ if self.verbose: - print("Loading all imagery") + print('Loading all imagery') img = torch.zeros( len(self.dates), @@ -343,7 +343,7 @@ def _load_single_scene( assert date in self.dates if self.verbose: - print(f"Loading imagery at {date}") + print(f'Loading imagery at {date}') img = torch.zeros( len(bands), self.tile_height, self.tile_width, dtype=torch.float32 @@ -351,9 +351,9 @@ def _load_single_scene( for band_index, band_name in enumerate(self.bands): filepath = os.path.join( self.root, - "ts_cashew_benin_source", - f"ts_cashew_benin_source_00_{date}", - f"{band_name}.tif", + 'ts_cashew_benin_source', + f'ts_cashew_benin_source_00_{date}', + f'{band_name}.tif', ) with rasterio.open(filepath) as src: transform = src.transform # same transform for every bands @@ -368,14 +368,14 @@ def _load_mask(self, transform: rasterio.Affine) -> Tensor: """Rasterizes the dataset's labels (in geojson format).""" # Create a mask layer out of the geojson mask_geojson_fn = os.path.join( - self.root, "ts_cashew_benin_labels", "_common", "labels.geojson" + self.root, 'ts_cashew_benin_labels', '_common', 'labels.geojson' ) with open(mask_geojson_fn) as f: geojson = json.load(f) labels = [ - (feature["geometry"], feature["properties"]["class"]) - for feature in geojson["features"] + (feature['geometry'], feature['properties']['class']) + for feature in geojson['features'] ] mask_data = rasterio.features.rasterize( @@ -397,13 +397,13 @@ def _check_integrity(self) -> bool: True if dataset files are found and/or MD5s match, else False """ images: bool = check_integrity( - os.path.join(self.root, self.image_meta["filename"]), - self.image_meta["md5"] if self.checksum else None, + os.path.join(self.root, self.image_meta['filename']), + self.image_meta['md5'] if self.checksum else None, ) targets: bool = check_integrity( - os.path.join(self.root, self.target_meta["filename"]), - self.target_meta["md5"] if self.checksum else None, + os.path.join(self.root, self.target_meta['filename']), + self.target_meta['md5'] if self.checksum else None, ) return images and targets @@ -418,14 +418,14 @@ def _download(self, api_key: str | None = None) -> None: RuntimeError: if download doesn't work correctly or checksums don't match """ if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return for collection_id in self.collection_ids: download_radiant_mlhub_collection(collection_id, self.root, api_key) - image_archive_path = os.path.join(self.root, self.image_meta["filename"]) - target_archive_path = 
os.path.join(self.root, self.target_meta["filename"]) + image_archive_path = os.path.join(self.root, self.image_meta['filename']) + target_archive_path = os.path.join(self.root, self.target_meta['filename']) for fn in [image_archive_path, target_archive_path]: extract_archive(fn, self.root) @@ -459,36 +459,36 @@ def plot( else: raise RGBBandsMissingError() - num_time_points = sample["image"].shape[0] + num_time_points = sample['image'].shape[0] assert time_step < num_time_points - image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3) + image = np.rollaxis(sample['image'][time_step, rgb_indices].numpy(), 0, 3) image = np.clip(image / 3000, 0, 1) - mask = sample["mask"].numpy() + mask = sample['mask'].numpy() num_panels = 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - predictions = sample["prediction"].numpy() + predictions = sample['prediction'].numpy() num_panels += 1 fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4)) axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') if show_titles: - axs[0].set_title(f"t={time_step}") + axs[0].set_title(f't={time_step}') - axs[1].imshow(mask, vmin=0, vmax=6, interpolation="none") - axs[1].axis("off") + axs[1].imshow(mask, vmin=0, vmax=6, interpolation='none') + axs[1].axis('off') if show_titles: - axs[1].set_title("Mask") + axs[1].set_title('Mask') if showing_predictions: - axs[2].imshow(predictions, vmin=0, vmax=6, interpolation="none") - axs[2].axis("off") + axs[2].imshow(predictions, vmin=0, vmax=6, interpolation='none') + axs[2].axis('off') if show_titles: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 8053e5692d6..01894da58f1 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -132,72 +132,72 @@ class BigEarthNet(NonGeoDataset): class_sets = { 19: [ - "Urban fabric", - "Industrial or commercial units", - "Arable land", - "Permanent crops", - "Pastures", - "Complex cultivation patterns", - "Land principally occupied by agriculture, with significant areas of" - " natural vegetation", - "Agro-forestry areas", - "Broad-leaved forest", - "Coniferous forest", - "Mixed forest", - "Natural grassland and sparsely vegetated areas", - "Moors, heathland and sclerophyllous vegetation", - "Transitional woodland, shrub", - "Beaches, dunes, sands", - "Inland wetlands", - "Coastal wetlands", - "Inland waters", - "Marine waters", + 'Urban fabric', + 'Industrial or commercial units', + 'Arable land', + 'Permanent crops', + 'Pastures', + 'Complex cultivation patterns', + 'Land principally occupied by agriculture, with significant areas of' + ' natural vegetation', + 'Agro-forestry areas', + 'Broad-leaved forest', + 'Coniferous forest', + 'Mixed forest', + 'Natural grassland and sparsely vegetated areas', + 'Moors, heathland and sclerophyllous vegetation', + 'Transitional woodland, shrub', + 'Beaches, dunes, sands', + 'Inland wetlands', + 'Coastal wetlands', + 'Inland waters', + 'Marine waters', ], 43: [ - "Continuous urban fabric", - "Discontinuous urban fabric", - "Industrial or commercial units", - "Road and rail networks and associated land", - "Port areas", - "Airports", - "Mineral extraction sites", - "Dump sites", - "Construction sites", - "Green urban areas", - "Sport and leisure facilities", - "Non-irrigated arable land", - "Permanently irrigated 
land", - "Rice fields", - "Vineyards", - "Fruit trees and berry plantations", - "Olive groves", - "Pastures", - "Annual crops associated with permanent crops", - "Complex cultivation patterns", - "Land principally occupied by agriculture, with significant areas of" - " natural vegetation", - "Agro-forestry areas", - "Broad-leaved forest", - "Coniferous forest", - "Mixed forest", - "Natural grassland", - "Moors and heathland", - "Sclerophyllous vegetation", - "Transitional woodland/shrub", - "Beaches, dunes, sands", - "Bare rock", - "Sparsely vegetated areas", - "Burnt areas", - "Inland marshes", - "Peatbogs", - "Salt marshes", - "Salines", - "Intertidal flats", - "Water courses", - "Water bodies", - "Coastal lagoons", - "Estuaries", - "Sea and ocean", + 'Continuous urban fabric', + 'Discontinuous urban fabric', + 'Industrial or commercial units', + 'Road and rail networks and associated land', + 'Port areas', + 'Airports', + 'Mineral extraction sites', + 'Dump sites', + 'Construction sites', + 'Green urban areas', + 'Sport and leisure facilities', + 'Non-irrigated arable land', + 'Permanently irrigated land', + 'Rice fields', + 'Vineyards', + 'Fruit trees and berry plantations', + 'Olive groves', + 'Pastures', + 'Annual crops associated with permanent crops', + 'Complex cultivation patterns', + 'Land principally occupied by agriculture, with significant areas of' + ' natural vegetation', + 'Agro-forestry areas', + 'Broad-leaved forest', + 'Coniferous forest', + 'Mixed forest', + 'Natural grassland', + 'Moors and heathland', + 'Sclerophyllous vegetation', + 'Transitional woodland/shrub', + 'Beaches, dunes, sands', + 'Bare rock', + 'Sparsely vegetated areas', + 'Burnt areas', + 'Inland marshes', + 'Peatbogs', + 'Salt marshes', + 'Salines', + 'Intertidal flats', + 'Water courses', + 'Water bodies', + 'Coastal lagoons', + 'Estuaries', + 'Sea and ocean', ], } @@ -237,43 +237,43 @@ class BigEarthNet(NonGeoDataset): } splits_metadata = { - "train": { - "url": "https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false", # noqa: E501 - "filename": "bigearthnet-train.csv", - "md5": "623e501b38ab7b12fe44f0083c00986d", + 'train': { + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', # noqa: E501 + 'filename': 'bigearthnet-train.csv', + 'md5': '623e501b38ab7b12fe44f0083c00986d', }, - "val": { - "url": "https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false", # noqa: E501 - "filename": "bigearthnet-val.csv", - "md5": "22efe8ed9cbd71fa10742ff7df2b7978", + 'val': { + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', # noqa: E501 + 'filename': 'bigearthnet-val.csv', + 'md5': '22efe8ed9cbd71fa10742ff7df2b7978', }, - "test": { - "url": "https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false", # noqa: E501 - "filename": "bigearthnet-test.csv", - "md5": "697fb90677e30571b9ac7699b7e5b432", + 'test': { + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', # noqa: E501 + 'filename': 'bigearthnet-test.csv', + 'md5': '697fb90677e30571b9ac7699b7e5b432', }, } metadata = { - "s1": { - 
"url": "https://bigearth.net/downloads/BigEarthNet-S1-v1.0.tar.gz", - "md5": "94ced73440dea8c7b9645ee738c5a172", - "filename": "BigEarthNet-S1-v1.0.tar.gz", - "directory": "BigEarthNet-S1-v1.0", + 's1': { + 'url': 'https://bigearth.net/downloads/BigEarthNet-S1-v1.0.tar.gz', + 'md5': '94ced73440dea8c7b9645ee738c5a172', + 'filename': 'BigEarthNet-S1-v1.0.tar.gz', + 'directory': 'BigEarthNet-S1-v1.0', }, - "s2": { - "url": "https://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz", - "md5": "5a64e9ce38deb036a435a7b59494924c", - "filename": "BigEarthNet-S2-v1.0.tar.gz", - "directory": "BigEarthNet-v1.0", + 's2': { + 'url': 'https://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz', + 'md5': '5a64e9ce38deb036a435a7b59494924c', + 'filename': 'BigEarthNet-S2-v1.0.tar.gz', + 'directory': 'BigEarthNet-v1.0', }, } image_size = (120, 120) def __init__( self, - root: str = "data", - split: str = "train", - bands: str = "all", + root: str = 'data', + split: str = 'train', + bands: str = 'all', num_classes: int = 19, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -295,7 +295,7 @@ def __init__( DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits_metadata - assert bands in ["s1", "s2", "all"] + assert bands in ['s1', 's2', 'all'] assert num_classes in [43, 19] self.root = root self.split = split @@ -319,7 +319,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) label = self._load_target(index) - sample: dict[str, Tensor] = {"image": image, "label": label} + sample: dict[str, Tensor] = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -340,18 +340,18 @@ def _load_folders(self) -> list[dict[str, str]]: Returns: list of dicts of s1 and s2 folder paths """ - filename = self.splits_metadata[self.split]["filename"] - dir_s1 = self.metadata["s1"]["directory"] - dir_s2 = self.metadata["s2"]["directory"] + filename = self.splits_metadata[self.split]['filename'] + dir_s1 = self.metadata['s1']['directory'] + dir_s2 = self.metadata['s2']['directory'] with open(os.path.join(self.root, filename)) as f: lines = f.read().strip().splitlines() - pairs = [line.split(",") for line in lines] + pairs = [line.split(',') for line in lines] folders = [ { - "s1": os.path.join(self.root, dir_s1, pair[1]), - "s2": os.path.join(self.root, dir_s2, pair[0]), + 's1': os.path.join(self.root, dir_s1, pair[1]), + 's2': os.path.join(self.root, dir_s2, pair[0]), } for pair in pairs ] @@ -366,21 +366,21 @@ def _load_paths(self, index: int) -> list[str]: Returns: list of file paths """ - if self.bands == "all": - folder_s1 = self.folders[index]["s1"] - folder_s2 = self.folders[index]["s2"] - paths_s1 = glob.glob(os.path.join(folder_s1, "*.tif")) - paths_s2 = glob.glob(os.path.join(folder_s2, "*.tif")) + if self.bands == 'all': + folder_s1 = self.folders[index]['s1'] + folder_s2 = self.folders[index]['s2'] + paths_s1 = glob.glob(os.path.join(folder_s1, '*.tif')) + paths_s2 = glob.glob(os.path.join(folder_s2, '*.tif')) paths_s1 = sorted(paths_s1) paths_s2 = sorted(paths_s2, key=sort_sentinel2_bands) paths = paths_s1 + paths_s2 - elif self.bands == "s1": - folder = self.folders[index]["s1"] - paths = glob.glob(os.path.join(folder, "*.tif")) + elif self.bands == 's1': + folder = self.folders[index]['s1'] + paths = glob.glob(os.path.join(folder, '*.tif')) paths = sorted(paths) else: - folder = self.folders[index]["s2"] - paths = 
glob.glob(os.path.join(folder, "*.tif")) + folder = self.folders[index]['s2'] + paths = glob.glob(os.path.join(folder, '*.tif')) paths = sorted(paths, key=sort_sentinel2_bands) return paths @@ -403,11 +403,11 @@ def _load_image(self, index: int) -> Tensor: array = dataset.read( indexes=1, out_shape=self.image_size, - out_dtype="int32", + out_dtype='int32', resampling=Resampling.bilinear, ) images.append(array) - arrays: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0) + arrays: 'np.typing.NDArray[np.int_]' = np.stack(images, axis=0) tensor = torch.from_numpy(arrays).float() return tensor @@ -420,14 +420,14 @@ def _load_target(self, index: int) -> Tensor: Returns: the target label """ - if self.bands == "s2": - folder = self.folders[index]["s2"] + if self.bands == 's2': + folder = self.folders[index]['s2'] else: - folder = self.folders[index]["s1"] + folder = self.folders[index]['s1'] - path = glob.glob(os.path.join(folder, "*.json"))[0] + path = glob.glob(os.path.join(folder, '*.json'))[0] with open(path) as f: - labels = json.load(f)["labels"] + labels = json.load(f)['labels'] # labels -> indices indices = [self.class2idx[label] for label in labels] @@ -443,15 +443,15 @@ def _load_target(self, index: int) -> Tensor: def _verify(self) -> None: """Verify the integrity of the dataset.""" - keys = ["s1", "s2"] if self.bands == "all" else [self.bands] - urls = [self.metadata[k]["url"] for k in keys] - md5s = [self.metadata[k]["md5"] for k in keys] - filenames = [self.metadata[k]["filename"] for k in keys] - directories = [self.metadata[k]["directory"] for k in keys] - urls.extend([self.splits_metadata[k]["url"] for k in self.splits_metadata]) - md5s.extend([self.splits_metadata[k]["md5"] for k in self.splits_metadata]) + keys = ['s1', 's2'] if self.bands == 'all' else [self.bands] + urls = [self.metadata[k]['url'] for k in keys] + md5s = [self.metadata[k]['md5'] for k in keys] + filenames = [self.metadata[k]['filename'] for k in keys] + directories = [self.metadata[k]['directory'] for k in keys] + urls.extend([self.splits_metadata[k]['url'] for k in self.splits_metadata]) + md5s.extend([self.splits_metadata[k]['md5'] for k in self.splits_metadata]) filenames_splits = [ - self.splits_metadata[k]["filename"] for k in self.splits_metadata + self.splits_metadata[k]['filename'] for k in self.splits_metadata ] filenames.extend(filenames_splits) @@ -509,11 +509,11 @@ def _extract(self, filepath: str) -> None: Args: filepath: path to file to be extracted """ - if not filepath.endswith(".csv"): + if not filepath.endswith('.csv'): extract_archive(filepath) def _onehot_labels_to_names( - self, label_mask: "np.typing.NDArray[np.bool_]" + self, label_mask: 'np.typing.NDArray[np.bool_]' ) -> list[str]: """Gets a list of class names given a label mask. @@ -547,26 +547,26 @@ def plot( .. 
versionadded:: 0.2 """ - if self.bands == "s2": - image = np.rollaxis(sample["image"][[3, 2, 1]].numpy(), 0, 3) + if self.bands == 's2': + image = np.rollaxis(sample['image'][[3, 2, 1]].numpy(), 0, 3) image = np.clip(image / 2000, 0, 1) - elif self.bands == "all": - image = np.rollaxis(sample["image"][[5, 4, 3]].numpy(), 0, 3) + elif self.bands == 'all': + image = np.rollaxis(sample['image'][[5, 4, 3]].numpy(), 0, 3) image = np.clip(image / 2000, 0, 1) - elif self.bands == "s1": - image = sample["image"][0].numpy() + elif self.bands == 's1': + image = sample['image'][0].numpy() - label_mask = sample["label"].numpy().astype(np.bool_) + label_mask = sample['label'].numpy().astype(np.bool_) labels = self._onehot_labels_to_names(label_mask) - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction_mask = sample["prediction"].numpy().astype(np.bool_) + prediction_mask = sample['prediction'].numpy().astype(np.bool_) predictions = self._onehot_labels_to_names(prediction_mask) fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: title = f"Labels: {', '.join(labels)}" if showing_predictions: diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index 4b73f58e8c9..d3cbeb15417 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -49,16 +49,16 @@ class BioMassters(NonGeoDataset): .. versionadded:: 0.5 """ - valid_splits = ["train", "test"] - valid_sensors = ("S1", "S2") + valid_splits = ['train', 'test'] + valid_sensors = ('S1', 'S2') - metadata_filename = "The_BioMassters_-_features_metadata.csv.csv" + metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv' def __init__( self, - root: str = "data", - split: str = "train", - sensors: Sequence[str] = ["S1", "S2"], + root: str = 'data', + split: str = 'train', + sensors: Sequence[str] = ['S1', 'S2'], as_time_series: bool = False, ) -> None: """Initialize a new instance of BioMassters dataset. @@ -82,12 +82,12 @@ def __init__( assert ( split in self.valid_splits - ), f"Please choose one of the valid splits: {self.valid_splits}." + ), f'Please choose one of the valid splits: {self.valid_splits}.' self.split = split assert set(sensors).issubset( set(self.valid_sensors) - ), f"Please choose a subset of valid sensors: {self.valid_sensors}." + ), f'Please choose a subset of valid sensors: {self.valid_sensors}.' 
self.sensors = sensors self.as_time_series = as_time_series @@ -97,34 +97,34 @@ def __init__( self.df = pd.read_csv(os.path.join(self.root, self.metadata_filename)) # filter sensors - self.df = self.df[self.df["satellite"].isin(self.sensors)] + self.df = self.df[self.df['satellite'].isin(self.sensors)] # filter split - self.df = self.df[self.df["split"] == self.split] + self.df = self.df[self.df['split'] == self.split] # generate numerical month from filename since first month is September # and has numerical index of 0 - self.df["num_month"] = ( - self.df["filename"] - .str.split("_", expand=True)[2] - .str.split(".", expand=True)[0] + self.df['num_month'] = ( + self.df['filename'] + .str.split('_', expand=True)[2] + .str.split('.', expand=True)[0] .astype(int) ) # set dataframe index depending on the task for easier indexing if self.as_time_series: - self.df["num_index"] = self.df.groupby(["chip_id"]).ngroup() + self.df['num_index'] = self.df.groupby(['chip_id']).ngroup() else: filter_df = ( - self.df.groupby(["chip_id", "month"])["satellite"].count().reset_index() + self.df.groupby(['chip_id', 'month'])['satellite'].count().reset_index() ) - filter_df = filter_df[filter_df["satellite"] == len(self.sensors)].drop( - "satellite", axis=1 + filter_df = filter_df[filter_df['satellite'] == len(self.sensors)].drop( + 'satellite', axis=1 ) # guarantee that each sample has corresponding number of images available - self.df = self.df.merge(filter_df, on=["chip_id", "month"], how="inner") + self.df = self.df.merge(filter_df, on=['chip_id', 'month'], how='inner') - self.df["num_index"] = self.df.groupby(["chip_id", "month"]).ngroup() + self.df['num_index'] = self.df.groupby(['chip_id', 'month']).ngroup() def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -138,22 +138,22 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Raises: IndexError: if index is out of range of the dataset """ - sample_df = self.df[self.df["num_index"] == index].copy() + sample_df = self.df[self.df['num_index'] == index].copy() # sort by satellite and month to return correct order sample_df.sort_values( - by=["satellite", "num_month"], inplace=True, ascending=True + by=['satellite', 'num_month'], inplace=True, ascending=True ) - filepaths = sample_df["filename"].tolist() + filepaths = sample_df['filename'].tolist() sample: dict[str, Tensor] = {} for sens in self.sensors: sens_filepaths = [fp for fp in filepaths if sens in fp] - sample[f"image_{sens}"] = self._load_input(sens_filepaths) + sample[f'image_{sens}'] = self._load_input(sens_filepaths) - if self.split == "train": - sample["label"] = self._load_target( - sample_df["corresponding_agbm"].unique()[0] + if self.split == 'train': + sample['label'] = self._load_target( + sample_df['corresponding_agbm'].unique()[0] ) return sample @@ -164,7 +164,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - return len(self.df["num_index"].unique()) + return len(self.df['num_index'].unique()) def _load_input(self, filenames: list[str]) -> Tensor: """Load the input imagery at the index. 
@@ -176,7 +176,7 @@ def _load_input(self, filenames: list[str]) -> Tensor: input image """ filepaths = [ - os.path.join(self.root, f"{self.split}_features", f) for f in filenames + os.path.join(self.root, f'{self.split}_features', f) for f in filenames ] arr_list = [rasterio.open(fp).read() for fp in filepaths] if self.as_time_series: @@ -194,8 +194,8 @@ def _load_target(self, filename: str) -> Tensor: Returns: target mask """ - with rasterio.open(os.path.join(self.root, "train_agbm", filename), "r") as src: - arr: "np.typing.NDArray[np.float_]" = src.read() + with rasterio.open(os.path.join(self.root, 'train_agbm', filename), 'r') as src: + arr: 'np.typing.NDArray[np.float_]' = src.read() target = torch.from_numpy(arr).float() return target @@ -205,7 +205,7 @@ def _verify(self) -> None: # Check if the extracted files already exist exists = [] - filenames = [f"{self.split}_features", self.metadata_filename] + filenames = [f'{self.split}_features', self.metadata_filename] for filename in filenames: pathname = os.path.join(self.root, filename) exists.append(os.path.exists(pathname)) @@ -232,17 +232,17 @@ def plot( """ ncols = len(self.sensors) + 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: ncols += 1 fig, axs = plt.subplots(1, ncols=ncols, figsize=(5 * ncols, 10)) for idx, sens in enumerate(self.sensors): - img = sample[f"image_{sens}"].numpy() + img = sample[f'image_{sens}'].numpy() if self.as_time_series: # plot last time step img = img[-1, ...] - if sens == "S2": + if sens == 'S2': img = img[[2, 1, 0], ...] img = percentile_normalization(img.transpose(1, 2, 0)) else: @@ -260,26 +260,26 @@ def plot( img = np.stack((co_polarization, cross_polarization, ratio), axis=-1) axs[idx].imshow(img) - axs[idx].axis("off") + axs[idx].axis('off') if show_titles: axs[idx].set_title(sens) if showing_predictions: pred = axs[ncols - 2].imshow( - sample["prediction"].permute(1, 2, 0), cmap="YlGn" + sample['prediction'].permute(1, 2, 0), cmap='YlGn' ) plt.colorbar(pred, ax=axs[ncols - 2], fraction=0.046, pad=0.04) - axs[ncols - 2].axis("off") + axs[ncols - 2].axis('off') if show_titles: - axs[ncols - 2].set_title("Prediction") + axs[ncols - 2].set_title('Prediction') # plot target / only available in train set - if "label" in sample: - target = axs[-1].imshow(sample["label"].permute(1, 2, 0), cmap="YlGn") + if 'label' in sample: + target = axs[-1].imshow(sample['label'].permute(1, 2, 0), cmap='YlGn') plt.colorbar(target, ax=axs[-1], fraction=0.046, pad=0.04) - axs[-1].axis("Off") + axs[-1].axis('Off') if show_titles: - axs[-1].set_title("Target") + axs[-1].set_title('Target') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index 63b76e592be..d7a42bfc7ad 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -27,41 +27,41 @@ class CanadianBuildingFootprints(VectorDataset): # TODO: how does one cite this dataset? 
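Since the BioMassters hunks above touch every default in its constructor, a quick usage sketch in the new single-quote style; the ./data layout is an assumption, and because the class has no download flag the feature archives and the metadata CSV must already be on disk:

    from torchgeo.datasets import BioMassters

    ds = BioMassters(root='data', split='train', sensors=['S1', 'S2'], as_time_series=False)
    sample = ds[0]
    # one stacked tensor per requested sensor, plus the AGBM target in the train split
    print(sample['image_S1'].shape, sample['image_S2'].shape, sample['label'].shape)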
# https://github.com/microsoft/CanadianBuildingFootprints/issues/11 - url = "https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/" + url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/' provinces_territories = [ - "Alberta", - "BritishColumbia", - "Manitoba", - "NewBrunswick", - "NewfoundlandAndLabrador", - "NorthwestTerritories", - "NovaScotia", - "Nunavut", - "Ontario", - "PrinceEdwardIsland", - "Quebec", - "Saskatchewan", - "YukonTerritory", + 'Alberta', + 'BritishColumbia', + 'Manitoba', + 'NewBrunswick', + 'NewfoundlandAndLabrador', + 'NorthwestTerritories', + 'NovaScotia', + 'Nunavut', + 'Ontario', + 'PrinceEdwardIsland', + 'Quebec', + 'Saskatchewan', + 'YukonTerritory', ] md5s = [ - "8b4190424e57bb0902bd8ecb95a9235b", - "fea05d6eb0006710729c675de63db839", - "adf11187362624d68f9c69aaa693c46f", - "44269d4ec89521735389ef9752ee8642", - "65dd92b1f3f5f7222ae5edfad616d266", - "346d70a682b95b451b81b47f660fd0e2", - "bd57cb1a7822d72610215fca20a12602", - "c1f29b73cdff9a6a9dd7d086b31ef2cf", - "76ba4b7059c5717989ce34977cad42b2", - "2e4a3fa47b3558503e61572c59ac5963", - "9ff4417ae00354d39a0cf193c8df592c", - "a51078d8e60082c7d3a3818240da6dd5", - "c11f3bd914ecabd7cac2cb2871ec0261", + '8b4190424e57bb0902bd8ecb95a9235b', + 'fea05d6eb0006710729c675de63db839', + 'adf11187362624d68f9c69aaa693c46f', + '44269d4ec89521735389ef9752ee8642', + '65dd92b1f3f5f7222ae5edfad616d266', + '346d70a682b95b451b81b47f660fd0e2', + 'bd57cb1a7822d72610215fca20a12602', + 'c1f29b73cdff9a6a9dd7d086b31ef2cf', + '76ba4b7059c5717989ce34977cad42b2', + '2e4a3fa47b3558503e61572c59ac5963', + '9ff4417ae00354d39a0cf193c8df592c', + 'a51078d8e60082c7d3a3818240da6dd5', + 'c11f3bd914ecabd7cac2cb2871ec0261', ] def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float = 0.00001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -105,7 +105,7 @@ def _check_integrity(self) -> bool: """ assert isinstance(self.paths, str) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): - filepath = os.path.join(self.paths, prov_terr + ".zip") + filepath = os.path.join(self.paths, prov_terr + '.zip') if not check_integrity(filepath, md5 if self.checksum else None): return False return True @@ -113,12 +113,12 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return assert isinstance(self.paths, str) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): download_and_extract_archive( - self.url + prov_terr + ".zip", + self.url + prov_terr + '.zip', self.paths, md5=md5 if self.checksum else None, ) @@ -143,29 +143,29 @@ def plot( Method now takes a sample dict, not a Tensor. Additionally, it is possible to show subplot titles and/or use a custom suptitle. 
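For the CanadianBuildingFootprints hunks around this point, a minimal construction sketch; only the paths argument visible in the diff is used, and it assumes the provincial .zip files have already been fetched into ./data:

    from torchgeo.datasets import CanadianBuildingFootprints

    ds = CanadianBuildingFootprints(paths='data')
    print(ds.bounds)  # VectorDataset, so samples are rasterized masks queried by bounding box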
""" - image = sample["mask"].squeeze(0) + image = sample['mask'].squeeze(0) ncols = 1 - showing_prediction = "prediction" in sample + showing_prediction = 'prediction' in sample if showing_prediction: - pred = sample["prediction"].squeeze(0) + pred = sample['prediction'].squeeze(0) ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4)) if showing_prediction: axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(pred) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow(image) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 87ebc9029cf..e4430cc7daa 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -38,33 +38,33 @@ class CDL(RasterDataset): * https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0 """ # noqa: E501 - filename_glob = "*_30m_cdls.tif" + filename_glob = '*_30m_cdls.tif' filename_regex = r""" ^(?P\d+) _30m_cdls\..*$ """ - zipfile_glob = "*_30m_cdls.zip" - date_format = "%Y" + zipfile_glob = '*_30m_cdls.zip' + date_format = '%Y' is_image = False - url = "https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip" # noqa: E501 + url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' # noqa: E501 md5s = { - 2023: "8c7685d6278d50c554f934b16a6076b7", - 2022: "754cf50670cdfee511937554785de3e6", - 2021: "27606eab08fe975aa138baad3e5dfcd8", - 2020: "483ee48c503aa81b684225179b402d42", - 2019: "a5168a2fc93acbeaa93e24eee3d8c696", - 2018: "4ad0d7802a9bb751685eb239b0fa8609", - 2017: "d173f942a70f94622f9b8290e7548684", - 2016: "fddc5dff0bccc617d70a12864c993e51", - 2015: "2e92038ab62ba75e1687f60eecbdd055", - 2014: "50bdf9da84ebd0457ddd9e0bf9bbcc1f", - 2013: "7be66c650416dc7c4a945dd7fd93c5b7", - 2012: "286504ff0512e9fe1a1975c635a1bec2", - 2011: "517bad1a99beec45d90abb651fb1f0e3", - 2010: "98d354c5a62c9e3e40ccadce265c721c", - 2009: "663c8a5fdd92ebfc0d6bee008586d19a", - 2008: "0610f2f17ab60a9fbb3baeb7543993a4", + 2023: '8c7685d6278d50c554f934b16a6076b7', + 2022: '754cf50670cdfee511937554785de3e6', + 2021: '27606eab08fe975aa138baad3e5dfcd8', + 2020: '483ee48c503aa81b684225179b402d42', + 2019: 'a5168a2fc93acbeaa93e24eee3d8c696', + 2018: '4ad0d7802a9bb751685eb239b0fa8609', + 2017: 'd173f942a70f94622f9b8290e7548684', + 2016: 'fddc5dff0bccc617d70a12864c993e51', + 2015: '2e92038ab62ba75e1687f60eecbdd055', + 2014: '50bdf9da84ebd0457ddd9e0bf9bbcc1f', + 2013: '7be66c650416dc7c4a945dd7fd93c5b7', + 2012: '286504ff0512e9fe1a1975c635a1bec2', + 2011: '517bad1a99beec45d90abb651fb1f0e3', + 2010: '98d354c5a62c9e3e40ccadce265c721c', + 2009: '663c8a5fdd92ebfc0d6bee008586d19a', + 2008: '0610f2f17ab60a9fbb3baeb7543993a4', } cmap = { @@ -206,7 +206,7 @@ class CDL(RasterDataset): def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2023], @@ -244,13 +244,13 @@ def __init__( *root* was renamed to *paths*. """ assert set(years) <= self.md5s.keys(), ( - "CDL data product only exists for the following years: " - f"{list(self.md5s.keys())}." + 'CDL data product only exists for the following years: ' + f'{list(self.md5s.keys())}.' 
) assert ( set(classes) <= self.cmap.keys() - ), f"Only the following classes are valid: {list(self.cmap.keys())}." - assert 0 in classes, "Classes must include the background class: 0" + ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths self.years = years @@ -282,7 +282,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ sample = super().__getitem__(query) - sample["mask"] = self.ordinal_map[sample["mask"]] + sample['mask'] = self.ordinal_map[sample['mask']] return sample def _verify(self) -> None: @@ -296,7 +296,7 @@ def _verify(self) -> None: assert isinstance(self.paths, str) for year in self.years: pathname = os.path.join( - self.paths, self.zipfile_glob.replace("*", str(year)) + self.paths, self.zipfile_glob.replace('*', str(year)) ) if os.path.exists(pathname): exists.append(True) @@ -328,7 +328,7 @@ def _extract(self) -> None: """Extract the dataset.""" assert isinstance(self.paths, str) for year in self.years: - zipfile_name = self.zipfile_glob.replace("*", str(year)) + zipfile_name = self.zipfile_glob.replace('*', str(year)) pathname = os.path.join(self.paths, zipfile_name) extract_archive(pathname, self.paths) @@ -352,29 +352,29 @@ def plot( Method now takes a sample dict, not a Tensor. Additionally, possible to show subplot titles and/or use a custom suptitle. """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots( nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False ) - axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none") - axs[0, 0].axis("off") + axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation='none') + axs[0, 0].axis('off') if show_titles: - axs[0, 0].set_title("Mask") + axs[0, 0].set_title('Mask') if showing_predictions: - axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none") - axs[0, 1].axis("off") + axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation='none') + axs[0, 1].axis('off') if show_titles: - axs[0, 1].set_title("Prediction") + axs[0, 1].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py index 21362bb9f8e..b97f30091ee 100644 --- a/torchgeo/datasets/chabud.py +++ b/torchgeo/datasets/chabud.py @@ -53,29 +53,29 @@ class ChaBuD(NonGeoDataset): """ all_bands = [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', ] - rgb_bands = ["B04", "B03", "B02"] - folds = {"train": [1, 2, 3, 4], "val": [0]} - url = "https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5" # noqa: E501 - filename = "train_eval.hdf5" - md5 = "15d78fb825f9a81dad600db828d22c08" + rgb_bands = ['B04', 'B03', 'B02'] + folds = {'train': [1, 2, 3, 4], 'val': [0]} + url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' # noqa: E501 + filename = 'train_eval.hdf5' + md5 = '15d78fb825f9a81dad600db828d22c08' def __init__( self, - root: str = "data", - split: 
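The CDL hunks above are likewise quote-only; as a sanity check of the constructor they touch, a short sketch (the local ./data path is an assumption, and years must stay within the md5s table listed above):

    from torchgeo.datasets import CDL

    ds = CDL(paths='data', years=[2023])
    # RasterDataset: samples are fetched with BoundingBox queries, and masks come
    # back remapped through ordinal_map as in the __getitem__ hunk above
    print(ds.bounds)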
str = "train", + root: str = 'data', + split: str = 'train', bands: list[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -114,7 +114,7 @@ def __init__( import h5py # noqa: F401 except ImportError: raise ImportError( - "h5py is not installed and is required to use this dataset" + 'h5py is not installed and is required to use this dataset' ) self.uuids = self._load_uuids() @@ -131,7 +131,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = self._load_image(index) mask = self._load_target(index) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -155,9 +155,9 @@ def _load_uuids(self) -> list[str]: import h5py uuids = [] - with h5py.File(self.filepath, "r") as f: + with h5py.File(self.filepath, 'r') as f: for k, v in f.items(): - if v.attrs["fold"] in self.folds[self.split] and "pre_fire" in v: + if v.attrs['fold'] in self.folds[self.split] and 'pre_fire' in v: uuids.append(k) uuids = sorted(uuids) @@ -175,9 +175,9 @@ def _load_image(self, index: int) -> Tensor: import h5py uuid = self.uuids[index] - with h5py.File(self.filepath, "r") as f: - pre_array = f[uuid]["pre_fire"][:] - post_array = f[uuid]["post_fire"][:] + with h5py.File(self.filepath, 'r') as f: + pre_array = f[uuid]['pre_fire'][:] + post_array = f[uuid]['post_fire'][:] # index specified bands and concatenate pre_array = pre_array[..., self.band_indices] @@ -201,8 +201,8 @@ def _load_target(self, index: int) -> Tensor: import h5py uuid = self.uuids[index] - with h5py.File(self.filepath, "r") as f: - array = f[uuid]["mask"][:].astype(np.int32).squeeze(axis=-1) + with h5py.File(self.filepath, 'r') as f: + array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) @@ -254,38 +254,38 @@ def plot( else: raise ValueError("Dataset doesn't contain some of the RGB bands") - mask = sample["mask"].numpy() - image_pre = sample["image"][: len(self.bands)][rgb_indices].numpy() - image_post = sample["image"][len(self.bands) :][rgb_indices].numpy() + mask = sample['mask'].numpy() + image_pre = sample['image'][: len(self.bands)][rgb_indices].numpy() + image_post = sample['image'][len(self.bands) :][rgb_indices].numpy() image_pre = percentile_normalization(image_pre) image_post = percentile_normalization(image_post) ncols = 3 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = sample["prediction"] + prediction = sample['prediction'] ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) axs[0].imshow(np.transpose(image_pre, (1, 2, 0))) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(np.transpose(image_post, (1, 2, 0))) - axs[1].axis("off") + axs[1].axis('off') axs[2].imshow(mask) - axs[2].axis("off") + axs[2].axis('off') if showing_predictions: axs[3].imshow(prediction) - axs[3].axis("off") + axs[3].axis('off') if show_titles: - axs[0].set_title("Image Pre") - axs[1].set_title("Image Post") - axs[2].set_title("Mask") + axs[0].set_title('Image Pre') + axs[1].set_title('Image Post') + axs[2].set_title('Mask') if showing_predictions: - axs[3].set_title("Prediction") + axs[3].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index a0509302ea6..f0ccc30c3f4 100644 --- 
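A usage sketch for the ChaBuD hunks above, under the new quoting; h5py must be installed, and download=True relies on the Hugging Face URL given in the class attributes:

    from torchgeo.datasets import ChaBuD

    ds = ChaBuD(root='data', split='train', download=True)
    sample = ds[0]
    # 'image' stacks the pre-fire and post-fire acquisitions along the channel axis
    print(sample['image'].shape, sample['mask'].shape)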
a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -84,13 +84,13 @@ def md5(self) -> str: @property def url(self) -> str: """URL to download dataset from.""" - url = "https://cicwebresources.blob.core.windows.net/chesapeakebaylandcover" - url += f"/{self.base_folder}/{self.zipfile}" + url = 'https://cicwebresources.blob.core.windows.net/chesapeakebaylandcover' + url += f'/{self.base_folder}/{self.zipfile}' return url def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -186,12 +186,12 @@ def plot( Method now takes a sample dict, not a Tensor. Additionally, possible to show subplot titles and/or use a custom suptitle. """ - mask = sample["mask"].squeeze(0) + mask = sample['mask'].squeeze(0) ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze(0) + pred = sample['prediction'].squeeze(0) ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4)) @@ -202,20 +202,20 @@ def plot( vmin=0, vmax=self._cmap.N - 1, cmap=self._cmap, - interpolation="none", + interpolation='none', ) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow( pred, vmin=0, vmax=self._cmap.N - 1, cmap=self._cmap, - interpolation="none", + interpolation='none', ) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow( @@ -223,11 +223,11 @@ def plot( vmin=0, vmax=self._cmap.N - 1, cmap=self._cmap, - interpolation="none", + interpolation='none', ) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) @@ -250,11 +250,11 @@ class Chesapeake7(Chesapeake): 7. Aberdeen Proving Ground: U.S. Army facility with no labels """ - base_folder = "BAYWIDE" - filename = "Baywide_7class_20132014.tif" + base_folder = 'BAYWIDE' + filename = 'Baywide_7class_20132014.tif' filename_glob = filename - zipfile = "Baywide_7Class_20132014.zip" - md5 = "61a4e948fb2551840b6557ef195c2084" + zipfile = 'Baywide_7Class_20132014.zip' + md5 = '61a4e948fb2551840b6557ef195c2084' cmap = { 0: (0, 0, 0, 0), @@ -289,31 +289,31 @@ class Chesapeake13(Chesapeake): 13. Aberdeen Proving Ground: U.S. 
Army facility with no labels """ - base_folder = "BAYWIDE" - filename = "Baywide_13Class_20132014.tif" + base_folder = 'BAYWIDE' + filename = 'Baywide_13Class_20132014.tif' filename_glob = filename - zipfile = "Baywide_13Class_20132014.zip" - md5 = "7e51118923c91e80e6e268156d25a4b9" + zipfile = 'Baywide_13Class_20132014.zip' + md5 = '7e51118923c91e80e6e268156d25a4b9' class ChesapeakeDC(Chesapeake): """This subset of the dataset contains data only for Washington, D.C.""" - base_folder = "DC" - filename = os.path.join("DC_11001", "DC_11001.img") + base_folder = 'DC' + filename = os.path.join('DC_11001', 'DC_11001.img') filename_glob = filename - zipfile = "DC_11001.zip" - md5 = "ed06ba7570d2955e8857d7d846c53b06" + zipfile = 'DC_11001.zip' + md5 = 'ed06ba7570d2955e8857d7d846c53b06' class ChesapeakeDE(Chesapeake): """This subset of the dataset contains data only for Delaware.""" - base_folder = "DE" - filename = "DE_STATEWIDE.tif" + base_folder = 'DE' + filename = 'DE_STATEWIDE.tif' filename_glob = filename - zipfile = "_DE_STATEWIDE.zip" - md5 = "5e12eff3b6950c01092c7e480b38e544" + zipfile = '_DE_STATEWIDE.zip' + md5 = '5e12eff3b6950c01092c7e480b38e544' class ChesapeakeMD(Chesapeake): @@ -327,11 +327,11 @@ class ChesapeakeMD(Chesapeake): the proprietary deflate64 compressed zip file. """ - base_folder = "MD" - filename = "MD_STATEWIDE.tif" + base_folder = 'MD' + filename = 'MD_STATEWIDE.tif' filename_glob = filename - zipfile = "_MD_STATEWIDE.zip" - md5 = "40c7cd697a887f2ffdb601b5c114e567" + zipfile = '_MD_STATEWIDE.zip' + md5 = '40c7cd697a887f2ffdb601b5c114e567' class ChesapeakeNY(Chesapeake): @@ -345,21 +345,21 @@ class ChesapeakeNY(Chesapeake): the proprietary deflate64 compressed zip file. """ - base_folder = "NY" - filename = "NY_STATEWIDE.tif" + base_folder = 'NY' + filename = 'NY_STATEWIDE.tif' filename_glob = filename - zipfile = "_NY_STATEWIDE.zip" - md5 = "1100078c526616454ef2e508affda915" + zipfile = '_NY_STATEWIDE.zip' + md5 = '1100078c526616454ef2e508affda915' class ChesapeakePA(Chesapeake): """This subset of the dataset contains data only for Pennsylvania.""" - base_folder = "PA" - filename = "PA_STATEWIDE.tif" + base_folder = 'PA' + filename = 'PA_STATEWIDE.tif' filename_glob = filename - zipfile = "_PA_STATEWIDE.zip" - md5 = "20a2a857c527a4dbadd6beed8b47e5ab" + zipfile = '_PA_STATEWIDE.zip' + md5 = '20a2a857c527a4dbadd6beed8b47e5ab' class ChesapeakeVA(Chesapeake): @@ -373,21 +373,21 @@ class ChesapeakeVA(Chesapeake): the proprietary deflate64 compressed zip file. 
""" - base_folder = "VA" - filename = "CIC2014_VA_STATEWIDE.tif" + base_folder = 'VA' + filename = 'CIC2014_VA_STATEWIDE.tif' filename_glob = filename - zipfile = "_VA_STATEWIDE.zip" - md5 = "6f2c97deaf73bb3e1ea9b21bd7a3fc8e" + zipfile = '_VA_STATEWIDE.zip' + md5 = '6f2c97deaf73bb3e1ea9b21bd7a3fc8e' class ChesapeakeWV(Chesapeake): """This subset of the dataset contains data only for West Virginia.""" - base_folder = "WV" - filename = "WV_STATEWIDE.tif" + base_folder = 'WV' + filename = 'WV_STATEWIDE.tif' filename_glob = filename - zipfile = "_WV_STATEWIDE.zip" - md5 = "350621ea293651fbc557a1c3e3c64cc3" + zipfile = '_WV_STATEWIDE.zip' + md5 = '350621ea293651fbc557a1c3e3c64cc3' class ChesapeakeCVPR(GeoDataset): @@ -412,18 +412,18 @@ class ChesapeakeCVPR(GeoDataset): * https://doi.org/10.1109/cvpr.2019.01301 """ - subdatasets = ["base", "prior_extension"] + subdatasets = ['base', 'prior_extension'] urls = { - "base": "https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip", # noqa: E501 - "prior_extension": "https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1", # noqa: E501 + 'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip', # noqa: E501 + 'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', # noqa: E501 } filenames = { - "base": "cvpr_chesapeake_landcover.zip", - "prior_extension": "cvpr_chesapeake_landcover_prior_extension.zip", + 'base': 'cvpr_chesapeake_landcover.zip', + 'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip', } md5s = { - "base": "1225ccbb9590e9396875f221e5031514", - "prior_extension": "402f41d07823c8faf7ea6960d7c4e17a", + 'base': '1225ccbb9590e9396875f221e5031514', + 'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a', } crs = CRS.from_epsg(3857) @@ -450,68 +450,68 @@ class ChesapeakeCVPR(GeoDataset): ) valid_layers = [ - "naip-new", - "naip-old", - "landsat-leaf-on", - "landsat-leaf-off", - "nlcd", - "lc", - "buildings", - "prior_from_cooccurrences_101_31_no_osm_no_buildings", + 'naip-new', + 'naip-old', + 'landsat-leaf-on', + 'landsat-leaf-off', + 'nlcd', + 'lc', + 'buildings', + 'prior_from_cooccurrences_101_31_no_osm_no_buildings', ] - states = ["de", "md", "va", "wv", "pa", "ny"] + states = ['de', 'md', 'va', 'wv', 'pa', 'ny'] splits = ( - [f"{state}-train" for state in states] - + [f"{state}-val" for state in states] - + [f"{state}-test" for state in states] + [f'{state}-train' for state in states] + + [f'{state}-val' for state in states] + + [f'{state}-test' for state in states] ) # these are used to check the integrity of the dataset _files = [ - "de_1m_2013_extended-debuffered-test_tiles", - "de_1m_2013_extended-debuffered-train_tiles", - "de_1m_2013_extended-debuffered-val_tiles", - "md_1m_2013_extended-debuffered-test_tiles", - "md_1m_2013_extended-debuffered-train_tiles", - "md_1m_2013_extended-debuffered-val_tiles", - "ny_1m_2013_extended-debuffered-test_tiles", - "ny_1m_2013_extended-debuffered-train_tiles", - "ny_1m_2013_extended-debuffered-val_tiles", - "pa_1m_2013_extended-debuffered-test_tiles", - "pa_1m_2013_extended-debuffered-train_tiles", - "pa_1m_2013_extended-debuffered-val_tiles", - "va_1m_2014_extended-debuffered-test_tiles", - "va_1m_2014_extended-debuffered-train_tiles", - "va_1m_2014_extended-debuffered-val_tiles", - "wv_1m_2014_extended-debuffered-test_tiles", - "wv_1m_2014_extended-debuffered-train_tiles", - 
"wv_1m_2014_extended-debuffered-val_tiles", - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif", - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif", # noqa: E501 - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif", # noqa: E501 - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif", - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif", - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif", - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif", - "wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif", # noqa: E501 - "spatial_index.geojson", + 'de_1m_2013_extended-debuffered-test_tiles', + 'de_1m_2013_extended-debuffered-train_tiles', + 'de_1m_2013_extended-debuffered-val_tiles', + 'md_1m_2013_extended-debuffered-test_tiles', + 'md_1m_2013_extended-debuffered-train_tiles', + 'md_1m_2013_extended-debuffered-val_tiles', + 'ny_1m_2013_extended-debuffered-test_tiles', + 'ny_1m_2013_extended-debuffered-train_tiles', + 'ny_1m_2013_extended-debuffered-val_tiles', + 'pa_1m_2013_extended-debuffered-test_tiles', + 'pa_1m_2013_extended-debuffered-train_tiles', + 'pa_1m_2013_extended-debuffered-val_tiles', + 'va_1m_2014_extended-debuffered-test_tiles', + 'va_1m_2014_extended-debuffered-train_tiles', + 'va_1m_2014_extended-debuffered-val_tiles', + 'wv_1m_2014_extended-debuffered-test_tiles', + 'wv_1m_2014_extended-debuffered-train_tiles', + 'wv_1m_2014_extended-debuffered-val_tiles', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', # noqa: E501 + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', # noqa: E501 + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501 + 'spatial_index.geojson', ] - p_src_crs = pyproj.CRS("epsg:3857") + p_src_crs = pyproj.CRS('epsg:3857') p_transformers = { - "epsg:26917": pyproj.Transformer.from_crs( - p_src_crs, pyproj.CRS("epsg:26917"), always_xy=True + 'epsg:26917': pyproj.Transformer.from_crs( + p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True ).transform, - "epsg:26918": pyproj.Transformer.from_crs( - p_src_crs, pyproj.CRS("epsg:26918"), always_xy=True + 'epsg:26918': pyproj.Transformer.from_crs( + p_src_crs, pyproj.CRS('epsg:26918'), always_xy=True ).transform, } def __init__( self, - root: str = "data", - splits: Sequence[str] = ["de-train"], - layers: Sequence[str] = ["naip-new", "lc"], + root: str = 'data', + splits: Sequence[str] = ['de-train'], + layers: Sequence[str] = ['naip-new', 'lc'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, @@ -563,30 +563,30 @@ def __init__( # Add all tiles into the index in epsg:3857 based on the included geojson mint: float = 0 maxt: float = sys.maxsize - with fiona.open(os.path.join(root, "spatial_index.geojson"), "r") as f: + with fiona.open(os.path.join(root, 'spatial_index.geojson'), 'r') as f: for i, row in enumerate(f): - 
if row["properties"]["split"] in splits: - box = shapely.geometry.shape(row["geometry"]) + if row['properties']['split'] in splits: + box = shapely.geometry.shape(row['geometry']) minx, miny, maxx, maxy = box.bounds coords = (minx, maxx, miny, maxy, mint, maxt) - prior_fn = row["properties"]["lc"].replace( - "lc.tif", - "prior_from_cooccurrences_101_31_no_osm_no_buildings.tif", + prior_fn = row['properties']['lc'].replace( + 'lc.tif', + 'prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', ) self.index.insert( i, coords, { - "naip-new": row["properties"]["naip-new"], - "naip-old": row["properties"]["naip-old"], - "landsat-leaf-on": row["properties"]["landsat-leaf-on"], - "landsat-leaf-off": row["properties"]["landsat-leaf-off"], - "lc": row["properties"]["lc"], - "nlcd": row["properties"]["nlcd"], - "buildings": row["properties"]["buildings"], - "prior_from_cooccurrences_101_31_no_osm_no_buildings": prior_fn, # noqa: E501 + 'naip-new': row['properties']['naip-new'], + 'naip-old': row['properties']['naip-old'], + 'landsat-leaf-on': row['properties']['landsat-leaf-on'], + 'landsat-leaf-off': row['properties']['landsat-leaf-off'], + 'lc': row['properties']['lc'], + 'nlcd': row['properties']['nlcd'], + 'buildings': row['properties']['buildings'], + 'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, # noqa: E501 }, ) @@ -605,11 +605,11 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) - sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query} + sample = {'image': [], 'mask': [], 'crs': self.crs, 'bbox': query} if len(filepaths) == 0: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) elif len(filepaths) == 1: filenames = filepaths[0] @@ -637,27 +637,27 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: ) if layer in [ - "naip-new", - "naip-old", - "landsat-leaf-on", - "landsat-leaf-off", + 'naip-new', + 'naip-old', + 'landsat-leaf-on', + 'landsat-leaf-off', ]: - sample["image"].append(data) + sample['image'].append(data) elif layer in [ - "lc", - "nlcd", - "buildings", - "prior_from_cooccurrences_101_31_no_osm_no_buildings", + 'lc', + 'nlcd', + 'buildings', + 'prior_from_cooccurrences_101_31_no_osm_no_buildings', ]: - sample["mask"].append(data) + sample['mask'].append(data) else: - raise IndexError(f"query: {query} spans multiple tiles which is not valid") + raise IndexError(f'query: {query} spans multiple tiles which is not valid') - sample["image"] = np.concatenate(sample["image"], axis=0) - sample["mask"] = np.concatenate(sample["mask"], axis=0) + sample['image'] = np.concatenate(sample['image'], axis=0) + sample['mask'] = np.concatenate(sample['mask'], axis=0) - sample["image"] = torch.from_numpy(sample["image"]).float() - sample["mask"] = torch.from_numpy(sample["mask"]).long() + sample['image'] = torch.from_numpy(sample['image']).float() + sample['mask'] = torch.from_numpy(sample['mask']).long() if self.transforms is not None: sample = self.transforms(sample) @@ -725,72 +725,72 @@ def plot( .. 
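ChesapeakeCVPR is a GeoDataset, so the __getitem__ touched above takes a BoundingBox query rather than an integer index. A hedged sketch of that call pattern; the 1 km window below is a placeholder and must land inside a single tile, otherwise the IndexError raised in the hunk above fires:

    from torchgeo.datasets import BoundingBox, ChesapeakeCVPR

    ds = ChesapeakeCVPR(root='data', splits=['de-train'], layers=['naip-new', 'lc'])
    b = ds.bounds
    # placeholder window anchored at one corner of the index; adjust so it falls inside a tile
    query = BoundingBox(b.minx, b.minx + 1000, b.miny, b.miny + 1000, b.mint, b.maxt)
    sample = ds[query]
    print(sample['image'].shape, sample['mask'].shape)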
versionadded:: 0.4 """ - image = np.rollaxis(sample["image"].numpy(), 0, 3) - mask = sample["mask"].numpy() + image = np.rollaxis(sample['image'].numpy(), 0, 3) + mask = sample['mask'].numpy() if mask.ndim == 3: mask = np.rollaxis(mask, 0, 3) else: mask = np.expand_dims(mask, 2) num_panels = len(self.layers) - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - predictions = sample["prediction"].numpy() + predictions = sample['prediction'].numpy() num_panels += 1 fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) i = 0 for layer in self.layers: - if layer == "naip-new" or layer == "naip-old": + if layer == 'naip-new' or layer == 'naip-old': img = image[:, :, :3] / 255 image = image[:, :, 4:] - axs[i].axis("off") + axs[i].axis('off') axs[i].imshow(img) - elif layer == "landsat-leaf-on" or layer == "landsat-leaf-off": + elif layer == 'landsat-leaf-on' or layer == 'landsat-leaf-off': img = image[:, :, [3, 2, 1]] / 3000 image = image[:, :, 9:] - axs[i].axis("off") + axs[i].axis('off') axs[i].imshow(img) - elif layer == "nlcd": + elif layer == 'nlcd': img = mask[:, :, 0] mask = mask[:, :, 1:] axs[i].imshow( - img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation="none" + img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation='none' ) - axs[i].axis("off") - elif layer == "lc": + axs[i].axis('off') + elif layer == 'lc': img = mask[:, :, 0] mask = mask[:, :, 1:] axs[i].imshow( - img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none" + img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation='none' ) - axs[i].axis("off") - elif layer == "buildings": + axs[i].axis('off') + elif layer == 'buildings': img = mask[:, :, 0] mask = mask[:, :, 1:] - axs[i].imshow(img, vmin=0, vmax=1, cmap="gray", interpolation="none") - axs[i].axis("off") - elif layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings": + axs[i].imshow(img, vmin=0, vmax=1, cmap='gray', interpolation='none') + axs[i].axis('off') + elif layer == 'prior_from_cooccurrences_101_31_no_osm_no_buildings': img = (mask[:, :, :4] @ self.prior_color_matrix) / 255 mask = mask[:, :, 4:] axs[i].imshow(img) - axs[i].axis("off") + axs[i].axis('off') if show_titles: - if layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings": - axs[i].set_title("prior") + if layer == 'prior_from_cooccurrences_101_31_no_osm_no_buildings': + axs[i].set_title('prior') else: axs[i].set_title(layer) i += 1 if showing_predictions: axs[i].imshow( - predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none" + predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation='none' ) - axs[i].axis("off") + axs[i].axis('off') if show_titles: - axs[i].set_title("Predictions") + axs[i].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index d785a73ff28..6ee54d8f24e 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -63,53 +63,53 @@ class CloudCoverDetection(NonGeoDataset): """ collection_ids = [ - "ref_cloud_cover_detection_challenge_v1_train_source", - "ref_cloud_cover_detection_challenge_v1_train_labels", - "ref_cloud_cover_detection_challenge_v1_test_source", - "ref_cloud_cover_detection_challenge_v1_test_labels", + 'ref_cloud_cover_detection_challenge_v1_train_source', + 'ref_cloud_cover_detection_challenge_v1_train_labels', + 'ref_cloud_cover_detection_challenge_v1_test_source', + 
'ref_cloud_cover_detection_challenge_v1_test_labels', ] image_meta = { - "train": { - "filename": "ref_cloud_cover_detection_challenge_v1_train_source.tar.gz", - "md5": "32cfe38e313bcedc09dca3f0f9575eea", + 'train': { + 'filename': 'ref_cloud_cover_detection_challenge_v1_train_source.tar.gz', + 'md5': '32cfe38e313bcedc09dca3f0f9575eea', }, - "test": { - "filename": "ref_cloud_cover_detection_challenge_v1_test_source.tar.gz", - "md5": "6c67edae18716598d47298f24992db6c", + 'test': { + 'filename': 'ref_cloud_cover_detection_challenge_v1_test_source.tar.gz', + 'md5': '6c67edae18716598d47298f24992db6c', }, } target_meta = { - "train": { - "filename": "ref_cloud_cover_detection_challenge_v1_train_labels.tar.gz", - "md5": "695dfb1034924c10fbb17f9293815671", + 'train': { + 'filename': 'ref_cloud_cover_detection_challenge_v1_train_labels.tar.gz', + 'md5': '695dfb1034924c10fbb17f9293815671', }, - "test": { - "filename": "ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz", - "md5": "ec2b42bb43e9a03a01ae096f9e09db9c", + 'test': { + 'filename': 'ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz', + 'md5': 'ec2b42bb43e9a03a01ae096f9e09db9c', }, } collection_names = { - "train": [ - "ref_cloud_cover_detection_challenge_v1_train_source", - "ref_cloud_cover_detection_challenge_v1_train_labels", + 'train': [ + 'ref_cloud_cover_detection_challenge_v1_train_source', + 'ref_cloud_cover_detection_challenge_v1_train_labels', ], - "test": [ - "ref_cloud_cover_detection_challenge_v1_test_source", - "ref_cloud_cover_detection_challenge_v1_test_labels", + 'test': [ + 'ref_cloud_cover_detection_challenge_v1_test_source', + 'ref_cloud_cover_detection_challenge_v1_test_labels', ], } - band_names = ["B02", "B03", "B04", "B08"] + band_names = ['B02', 'B03', 'B04', 'B08'] - rgb_bands = ["B04", "B03", "B02"] + rgb_bands = ['B04', 'B03', 'B02'] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', bands: Sequence[str] = band_names, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -166,7 +166,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) label = self._load_target(index) - sample: dict[str, Tensor] = {"image": image, "mask": label} + sample: dict[str, Tensor] = {'image': image, 'mask': label} if self.transforms is not None: sample = self.transforms(sample) @@ -182,13 +182,13 @@ def _load_image(self, index: int) -> Tensor: Returns: a tensor of stacked source image data """ - source_asset_paths = self.chip_paths[index]["source"] + source_asset_paths = self.chip_paths[index]['source'] images = [] for path in source_asset_paths: with rasterio.open(path) as image_data: image_array = image_data.read(1).astype(np.int32) images.append(image_array) - image_stack: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0) + image_stack: 'np.typing.NDArray[np.int_]' = np.stack(images, axis=0) image_tensor = torch.from_numpy(image_stack) return image_tensor @@ -201,11 +201,11 @@ def _load_target(self, index: int) -> Tensor: Returns: a tensor of the label image data """ - label_asset_path = self.chip_paths[index]["target"][0] + label_asset_path = self.chip_paths[index]['target'][0] with rasterio.open(label_asset_path) as target_data: target_img = target_data.read(1).astype(np.int32) - target_array: "np.typing.NDArray[np.int_]" = np.array(target_img) + target_array: 'np.typing.NDArray[np.int_]' = np.array(target_img) target_tensor = torch.from_numpy(target_array) 
return target_tensor @@ -237,15 +237,15 @@ def _load_items(self, item_json: str) -> dict[str, list[str]]: label_data = self._read_json_data(item_json) label_asset_path = os.path.join( - os.path.split(item_json)[0], label_data["assets"]["labels"]["href"] + os.path.split(item_json)[0], label_data['assets']['labels']['href'] ) - item_meta["target"] = [label_asset_path] + item_meta['target'] = [label_asset_path] source_item_hrefs = [] - for link in label_data["links"]: - if link["rel"] == "source": + for link in label_data['links']: + if link['rel'] == 'source': source_item_hrefs.append( - os.path.join(self.root, link["href"].replace("../../", "")) + os.path.join(self.root, link['href'].replace('../../', '')) ) source_item_hrefs = sorted(source_item_hrefs) @@ -255,16 +255,16 @@ def _load_items(self, item_json: str) -> dict[str, list[str]]: source_item_path = os.path.split(item_href)[0] source_data = self._read_json_data(item_href) source_item_assets = [] - for asset_key, asset_value in source_data["assets"].items(): + for asset_key, asset_value in source_data['assets'].items(): if asset_key in self.bands: source_item_assets.append( - os.path.join(source_item_path, asset_value["href"]) + os.path.join(source_item_path, asset_value['href']) ) source_item_assets = sorted(source_item_assets) for source_item_asset in source_item_assets: source_item_paths.append(source_item_asset) - item_meta["source"] = source_item_paths + item_meta['source'] = source_item_paths return item_meta def _load_collections(self) -> list[dict[str, Any]]: @@ -279,15 +279,15 @@ def _load_collections(self) -> list[dict[str, Any]]: indexed_chips = [] label_collection: list[str] = [] for c in self.collection_names[self.split]: - if "label" in c: + if 'label' in c: label_collection.append(c) label_collection_path = os.path.join(self.root, label_collection[0]) - label_collection_json = os.path.join(label_collection_path, "collection.json") + label_collection_json = os.path.join(label_collection_path, 'collection.json') label_collection_item_hrefs = [] - for link in self._read_json_data(label_collection_json)["links"]: - if link["rel"] == "item": - label_collection_item_hrefs.append(link["href"]) + for link in self._read_json_data(label_collection_json)['links']: + if link['rel'] == 'item': + label_collection_item_hrefs.append(link['href']) label_collection_item_hrefs = sorted(label_collection_item_hrefs) @@ -318,13 +318,13 @@ def _check_integrity(self) -> bool: True if dataset files are found and/or MD5s match, else False """ images: bool = check_integrity( - os.path.join(self.root, self.image_meta[self.split]["filename"]), - self.image_meta[self.split]["md5"] if self.checksum else None, + os.path.join(self.root, self.image_meta[self.split]['filename']), + self.image_meta[self.split]['md5'] if self.checksum else None, ) targets: bool = check_integrity( - os.path.join(self.root, self.target_meta[self.split]["filename"]), - self.target_meta[self.split]["md5"] if self.checksum else None, + os.path.join(self.root, self.target_meta[self.split]['filename']), + self.target_meta[self.split]['md5'] if self.checksum else None, ) return images and targets @@ -336,17 +336,17 @@ def _download(self, api_key: str | None = None) -> None: api_key: a RadiantEarth MLHub API key to use for downloading the dataset """ if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return for collection_id in self.collection_ids: download_radiant_mlhub_collection(collection_id, self.root, 
api_key) image_archive_path = os.path.join( - self.root, self.image_meta[self.split]["filename"] + self.root, self.image_meta[self.split]['filename'] ) target_archive_path = os.path.join( - self.root, self.target_meta[self.split]["filename"] + self.root, self.target_meta[self.split]['filename'] ) for fn in [image_archive_path, target_archive_path]: extract_archive(fn, self.root) @@ -378,30 +378,30 @@ def plot( else: raise RGBBandsMissingError() - if "prediction" in sample: - prediction = sample["prediction"] + if 'prediction' in sample: + prediction = sample['prediction'] n_cols = 3 else: n_cols = 2 - image, mask = sample["image"] / 3000, sample["mask"] + image, mask = sample['image'] / 3000, sample['mask'] fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5)) axs[0].imshow(image.permute(1, 2, 0)) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask) - axs[1].axis("off") + axs[1].axis('off') - if "prediction" in sample: + if 'prediction' in sample: axs[2].imshow(prediction) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index 415a5c5c102..32697b224bc 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -37,141 +37,141 @@ class CMSGlobalMangroveCanopy(RasterDataset): _(?P[A-Za-z][^.]*) """ - zipfile = "CMS_Global_Map_Mangrove_Canopy_1665.zip" - md5 = "3e7f9f23bf971c25e828b36e6c5496e3" + zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip' + md5 = '3e7f9f23bf971c25e828b36e6c5496e3' all_countries = [ - "AndamanAndNicobar", - "Angola", - "Anguilla", - "AntiguaAndBarbuda", - "Aruba", - "Australia", - "Bahamas", - "Bahrain", - "Bangladesh", - "Barbados", - "Belize", - "Benin", - "Brazil", - "BritishVirginIslands", - "Brunei", - "Cambodia", - "Cameroon", - "CarribeanCaymanIslands", - "China", - "Colombia", - "Comoros", - "CostaRica", - "Cote", - "CoteDivoire", - "CotedIvoire", - "Cuba", - "DemocraticRepublicOfCongo", - "Djibouti", - "DominicanRepublic", - "EcuadorWithGalapagos", - "Egypt", - "ElSalvador", - "EquatorialGuinea", - "Eritrea", - "EuropaIsland", - "Fiji", - "Fiji2", - "FrenchGuiana", - "FrenchGuyana", - "FrenchPolynesia", - "Gabon", - "Gambia", - "Ghana", - "Grenada", - "Guadeloupe", - "Guam", - "Guatemala", - "Guinea", - "GuineaBissau", - "Guyana", - "Haiti", - "Hawaii", - "Honduras", - "HongKong", - "India", - "Indonesia", - "Iran", - "Jamaica", - "Japan", - "Kenya", - "Liberia", - "Macau", - "Madagascar", - "Malaysia", - "Martinique", - "Mauritania", - "Mayotte", - "Mexico", - "Micronesia", - "Mozambique", - "Myanmar", - "NewCaledonia", - "NewZealand", - "Newzealand", - "Nicaragua", - "Nigeria", - "NorthernMarianaIslands", - "Oman", - "Pakistan", - "Palau", - "Panama", - "PapuaNewGuinea", - "Peru", - "Philipines", - "PuertoRico", - "Qatar", - "ReunionAndMauritius", - "SaintKittsAndNevis", - "SaintLucia", - "SaintVincentAndTheGrenadines", - "Samoa", - "SaudiArabia", - "Senegal", - "Seychelles", - "SierraLeone", - "Singapore", - "SolomonIslands", - "Somalia", - "Somalia2", - "Soudan", - "SouthAfrica", - "SriLanka", - "Sudan", - "Suriname", - "Taiwan", - "Tanzania", - "Thailand", - "TimorLeste", - "Togo", - "Tonga", - "TrinidadAndTobago", - "TurksAndCaicosIslands", - 
"Tuvalu", - "UnitedArabEmirates", - "UnitedStates", - "Vanuatu", - "Venezuela", - "Vietnam", - "VirginIslandsUs", - "WallisAndFutuna", - "Yemen", + 'AndamanAndNicobar', + 'Angola', + 'Anguilla', + 'AntiguaAndBarbuda', + 'Aruba', + 'Australia', + 'Bahamas', + 'Bahrain', + 'Bangladesh', + 'Barbados', + 'Belize', + 'Benin', + 'Brazil', + 'BritishVirginIslands', + 'Brunei', + 'Cambodia', + 'Cameroon', + 'CarribeanCaymanIslands', + 'China', + 'Colombia', + 'Comoros', + 'CostaRica', + 'Cote', + 'CoteDivoire', + 'CotedIvoire', + 'Cuba', + 'DemocraticRepublicOfCongo', + 'Djibouti', + 'DominicanRepublic', + 'EcuadorWithGalapagos', + 'Egypt', + 'ElSalvador', + 'EquatorialGuinea', + 'Eritrea', + 'EuropaIsland', + 'Fiji', + 'Fiji2', + 'FrenchGuiana', + 'FrenchGuyana', + 'FrenchPolynesia', + 'Gabon', + 'Gambia', + 'Ghana', + 'Grenada', + 'Guadeloupe', + 'Guam', + 'Guatemala', + 'Guinea', + 'GuineaBissau', + 'Guyana', + 'Haiti', + 'Hawaii', + 'Honduras', + 'HongKong', + 'India', + 'Indonesia', + 'Iran', + 'Jamaica', + 'Japan', + 'Kenya', + 'Liberia', + 'Macau', + 'Madagascar', + 'Malaysia', + 'Martinique', + 'Mauritania', + 'Mayotte', + 'Mexico', + 'Micronesia', + 'Mozambique', + 'Myanmar', + 'NewCaledonia', + 'NewZealand', + 'Newzealand', + 'Nicaragua', + 'Nigeria', + 'NorthernMarianaIslands', + 'Oman', + 'Pakistan', + 'Palau', + 'Panama', + 'PapuaNewGuinea', + 'Peru', + 'Philipines', + 'PuertoRico', + 'Qatar', + 'ReunionAndMauritius', + 'SaintKittsAndNevis', + 'SaintLucia', + 'SaintVincentAndTheGrenadines', + 'Samoa', + 'SaudiArabia', + 'Senegal', + 'Seychelles', + 'SierraLeone', + 'Singapore', + 'SolomonIslands', + 'Somalia', + 'Somalia2', + 'Soudan', + 'SouthAfrica', + 'SriLanka', + 'Sudan', + 'Suriname', + 'Taiwan', + 'Tanzania', + 'Thailand', + 'TimorLeste', + 'Togo', + 'Tonga', + 'TrinidadAndTobago', + 'TurksAndCaicosIslands', + 'Tuvalu', + 'UnitedArabEmirates', + 'UnitedStates', + 'Vanuatu', + 'Venezuela', + 'Vietnam', + 'VirginIslandsUs', + 'WallisAndFutuna', + 'Yemen', ] - measurements = ["agb", "hba95", "hmax95"] + measurements = ['agb', 'hba95', 'hmax95'] def __init__( self, - paths: str | list[str] = "data", + paths: str | list[str] = 'data', crs: CRS | None = None, res: float | None = None, - measurement: str = "agb", + measurement: str = 'agb', country: str = all_countries[0], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, @@ -202,19 +202,19 @@ def __init__( self.paths = paths self.checksum = checksum - assert isinstance(country, str), "Country argument must be a str." + assert isinstance(country, str), 'Country argument must be a str.' assert ( country in self.all_countries - ), f"You have selected an invalid country, please choose one of {self.all_countries}" + ), f'You have selected an invalid country, please choose one of {self.all_countries}' self.country = country - assert isinstance(measurement, str), "Measurement must be a string." + assert isinstance(measurement, str), 'Measurement must be a string.' assert ( measurement in self.measurements - ), f"You have entered an invalid measurement, please choose one of {self.measurements}." + ), f'You have entered an invalid measurement, please choose one of {self.measurements}.' 
self.measurement = measurement - self.filename_glob = f"**/Mangrove_{self.measurement}_{self.country}*" + self.filename_glob = f'**/Mangrove_{self.measurement}_{self.country}*' self._verify() @@ -231,7 +231,7 @@ def _verify(self) -> None: pathname = os.path.join(self.paths, self.zipfile) if os.path.exists(pathname): if self.checksum and not check_integrity(pathname, self.md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') self._extract() return @@ -259,29 +259,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) if showing_predictions: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(pred) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow(mask) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 902cbdcb49d..5ff49510b66 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -64,8 +64,8 @@ def filename(self) -> str: def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -84,7 +84,7 @@ def __init__( AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert split in ["train", "test"] + assert split in ['train', 'test'] self.root = root self.split = split @@ -101,10 +101,10 @@ def __init__( self.targets = [] with open( os.path.join(self.root, self.filename.format(split)), - encoding="utf-8-sig", - newline="", + encoding='utf-8-sig', + newline='', ) as f: - reader = csv.reader(f, delimiter=" ") + reader = csv.reader(f, delimiter=' ') for row in reader: self.images.append(row[0]) self.targets.append(row[1]) @@ -118,7 +118,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at that index """ - sample = {"image": self._load_image(index), "label": self._load_target(index)} + sample = {'image': self._load_image(index), 'label': self._load_target(index)} if self.transforms is not None: sample = self.transforms(sample) @@ -144,7 +144,7 @@ def _load_image(self, index: int) -> Tensor: """ filename = os.path.join(self.root, self.images[index]) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img) + array: 'np.typing.NDArray[np.int_]' = np.array(img) tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -178,7 +178,7 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return for filename, md5 in zip(self.filenames, self.md5s): @@ -207,23 +207,23 @@ def plot( .. 
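For the CMSGlobalMangroveCanopy hunks above, a construction sketch using only arguments visible in the diff; 'Angola' is taken from the country list above, and the CMS_Global_Map_Mangrove_Canopy_1665.zip archive is assumed to already sit in ./data, since no download flag appears in these hunks:

    from torchgeo.datasets import CMSGlobalMangroveCanopy

    ds = CMSGlobalMangroveCanopy(paths='data', measurement='agb', country='Angola')
    print(ds.bounds)  # aboveground-biomass raster coverage for the selected country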
versionadded:: 0.2 """ - image = sample["image"] - label = cast(str, sample["label"].item()) + image = sample['image'] + label = cast(str, sample['label'].item()) - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(str, sample["prediction"].item()) + prediction = cast(str, sample['prediction'].item()) else: prediction = None fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(image.permute(1, 2, 0)) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label}" + title = f'Label: {label}' if prediction is not None: - title += f"\nPrediction: {prediction}" + title += f'\nPrediction: {prediction}' ax.set_title(title) if suptitle is not None: @@ -236,58 +236,58 @@ class COWCCounting(COWC): """COWC Dataset for car counting.""" base_url = ( - "https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/" + 'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/' ) filenames = [ - "COWC_train_list_64_class.txt.bz2", - "COWC_test_list_64_class.txt.bz2", - "COWC_Counting_Toronto_ISPRS.tbz", - "COWC_Counting_Selwyn_LINZ.tbz", - "COWC_Counting_Potsdam_ISPRS.tbz", - "COWC_Counting_Vaihingen_ISPRS.tbz", - "COWC_Counting_Columbus_CSUAV_AFRL.tbz", - "COWC_Counting_Utah_AGRC.tbz", + 'COWC_train_list_64_class.txt.bz2', + 'COWC_test_list_64_class.txt.bz2', + 'COWC_Counting_Toronto_ISPRS.tbz', + 'COWC_Counting_Selwyn_LINZ.tbz', + 'COWC_Counting_Potsdam_ISPRS.tbz', + 'COWC_Counting_Vaihingen_ISPRS.tbz', + 'COWC_Counting_Columbus_CSUAV_AFRL.tbz', + 'COWC_Counting_Utah_AGRC.tbz', ] md5s = [ - "187543d20fa6d591b8da51136e8ef8fb", - "930cfd6e160a7b36db03146282178807", - "bc2613196dfa93e66d324ae43e7c1fdb", - "ea842ae055f5c74d0d933d2194764545", - "19a77ab9932b722ef52b197d70e68ce7", - "4009c1e420566390746f5b4db02afdb9", - "daf8033c4e8ceebbf2c3cac3fabb8b10", - "777ec107ed2a3d54597a739ce74f95ad", + '187543d20fa6d591b8da51136e8ef8fb', + '930cfd6e160a7b36db03146282178807', + 'bc2613196dfa93e66d324ae43e7c1fdb', + 'ea842ae055f5c74d0d933d2194764545', + '19a77ab9932b722ef52b197d70e68ce7', + '4009c1e420566390746f5b4db02afdb9', + 'daf8033c4e8ceebbf2c3cac3fabb8b10', + '777ec107ed2a3d54597a739ce74f95ad', ] - filename = "COWC_{}_list_64_class.txt" + filename = 'COWC_{}_list_64_class.txt' class COWCDetection(COWC): """COWC Dataset for car detection.""" base_url = ( - "https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/" + 'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/' ) filenames = [ - "COWC_train_list_detection.txt.bz2", - "COWC_test_list_detection.txt.bz2", - "COWC_Detection_Toronto_ISPRS.tbz", - "COWC_Detection_Selwyn_LINZ.tbz", - "COWC_Detection_Potsdam_ISPRS.tbz", - "COWC_Detection_Vaihingen_ISPRS.tbz", - "COWC_Detection_Columbus_CSUAV_AFRL.tbz", - "COWC_Detection_Utah_AGRC.tbz", + 'COWC_train_list_detection.txt.bz2', + 'COWC_test_list_detection.txt.bz2', + 'COWC_Detection_Toronto_ISPRS.tbz', + 'COWC_Detection_Selwyn_LINZ.tbz', + 'COWC_Detection_Potsdam_ISPRS.tbz', + 'COWC_Detection_Vaihingen_ISPRS.tbz', + 'COWC_Detection_Columbus_CSUAV_AFRL.tbz', + 'COWC_Detection_Utah_AGRC.tbz', ] md5s = [ - "c954a5a3dac08c220b10cfbeec83893c", - "c6c2d0a78f12a2ad88b286b724a57c1a", - "11af24f43b198b0f13c8e94814008a48", - "22fd37a86961010f5d519a7da0e1fc72", - "bf053545cc1915d8b6597415b746fe48", - "23945d5b22455450a938382ccc2a8b27", - "f40522dc97bea41b10117d4a5b946a6f", - "195da7c9443a939a468c9f232fd86ee3", + 'c954a5a3dac08c220b10cfbeec83893c', + 
'c6c2d0a78f12a2ad88b286b724a57c1a', + '11af24f43b198b0f13c8e94814008a48', + '22fd37a86961010f5d519a7da0e1fc72', + 'bf053545cc1915d8b6597415b746fe48', + '23945d5b22455450a938382ccc2a8b27', + 'f40522dc97bea41b10117d4a5b946a6f', + '195da7c9443a939a468c9f232fd86ee3', ] - filename = "COWC_{}_list_detection.txt" + filename = 'COWC_{}_list_detection.txt' # TODO: add COCW-M datasets: diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index 06eb484338f..9887d71db6b 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -55,47 +55,47 @@ class CropHarvest(NonGeoDataset): # https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py all_bands = [ - "VV", - "VH", - "B2", - "B3", - "B4", - "B5", - "B6", - "B7", - "B8", - "B8A", - "B9", - "B11", - "B12", - "temperature_2m", - "total_precipitation", - "elevation", - "slope", - "NDVI", + 'VV', + 'VH', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8', + 'B8A', + 'B9', + 'B11', + 'B12', + 'temperature_2m', + 'total_precipitation', + 'elevation', + 'slope', + 'NDVI', ] - rgb_bands = ["B4", "B3", "B2"] + rgb_bands = ['B4', 'B3', 'B2'] - features_url = "https://zenodo.org/records/7257688/files/features.tar.gz?download=1" - labels_url = "https://zenodo.org/records/7257688/files/labels.geojson?download=1" + features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1' + labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1' file_dict = { - "features": { - "url": features_url, - "filename": "features.tar.gz", - "extracted_filename": os.path.join("features", "arrays"), - "md5": "cad4df655c75caac805a80435e46ee3e", + 'features': { + 'url': features_url, + 'filename': 'features.tar.gz', + 'extracted_filename': os.path.join('features', 'arrays'), + 'md5': 'cad4df655c75caac805a80435e46ee3e', }, - "labels": { - "url": labels_url, - "filename": "labels.geojson", - "extracted_filename": "labels.geojson", - "md5": "bf7bae6812fc7213481aff6a2e34517d", + 'labels': { + 'url': labels_url, + 'filename': 'labels.geojson', + 'extracted_filename': 'labels.geojson', + 'md5': 'bf7bae6812fc7213481aff6a2e34517d', }, } def __init__( self, - root: str = "data", + root: str = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -117,7 +117,7 @@ def __init__( import h5py # noqa: F401 except ImportError: raise ImportError( - "h5py is not installed and is required to use this dataset" + 'h5py is not installed and is required to use this dataset' ) self.root = root @@ -129,9 +129,9 @@ def __init__( self.files = self._load_features(self.root) self.labels = self._load_labels(self.root) - self.classes = self.labels["properties.label"].unique() + self.classes = self.labels['properties.label'].unique() self.classes = self.classes[self.classes != np.array(None)] - self.classes = np.insert(self.classes, 0, ["None", "Other"]) + self.classes = np.insert(self.classes, 0, ['None', 'Other']) def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. 
@@ -143,10 +143,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: single pixel time-series array and label at that index """ files = self.files[index] - data = self._load_array(files["chip"]) + data = self._load_array(files['chip']) - label = self._load_label(files["index"], files["dataset"]) - sample = {"array": data, "label": label} + label = self._load_label(files['index'], files['dataset']) + sample = {'array': data, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -173,15 +173,15 @@ def _load_features(self, root: str) -> list[dict[str, str]]: """ files = [] chips = glob.glob( - os.path.join(root, self.file_dict["features"]["extracted_filename"], "*.h5") + os.path.join(root, self.file_dict['features']['extracted_filename'], '*.h5') ) chips = sorted(os.path.basename(chip) for chip in chips) for chip in chips: chip_path = os.path.join( - root, self.file_dict["features"]["extracted_filename"], chip + root, self.file_dict['features']['extracted_filename'], chip ) - index = chip.split("_")[0] - dataset = chip.split("_")[1][:-3] + index = chip.split('_')[0] + dataset = chip.split('_')[1][:-3] files.append(dict(chip=chip_path, index=index, dataset=dataset)) return files @@ -194,10 +194,10 @@ def _load_labels(self, root: str) -> pd.DataFrame: Returns: pandas dataframe containing label data for each feature """ - filename = self.file_dict["labels"]["extracted_filename"] - with open(os.path.join(root, filename), encoding="utf8") as f: + filename = self.file_dict['labels']['extracted_filename'] + with open(os.path.join(root, filename), encoding='utf8') as f: data = json.load(f) - df = pd.json_normalize(data["features"]) + df = pd.json_normalize(data['features']) return df def _load_array(self, path: str) -> Tensor: @@ -212,8 +212,8 @@ def _load_array(self, path: str) -> Tensor: import h5py filename = os.path.join(path) - with h5py.File(filename, "r") as f: - array = f.get("array")[()] + with h5py.File(filename, 'r') as f: + array = f.get('array')[()] tensor = torch.from_numpy(array) return tensor @@ -229,15 +229,15 @@ def _load_label(self, idx: str, dataset: str) -> Tensor: """ index = int(idx) row = self.labels[ - (self.labels["properties.index"] == index) - & (self.labels["properties.dataset"] == dataset) + (self.labels['properties.index'] == index) + & (self.labels['properties.dataset'] == dataset) ] - row = row.to_dict(orient="records")[0] - label = "None" - if row["properties.label"]: - label = row["properties.label"] - elif row["properties.is_crop"] == 1: - label = "Other" + row = row.to_dict(orient='records')[0] + label = 'None' + if row['properties.label']: + label = row['properties.label'] + elif row['properties.is_crop'] == 1: + label = 'Other' return torch.tensor(np.where(self.classes == label)[0][0]) @@ -245,13 +245,13 @@ def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if feature files already exist feature_path = os.path.join( - self.root, self.file_dict["features"]["extracted_filename"] + self.root, self.file_dict['features']['extracted_filename'] ) feature_path_zip = os.path.join( - self.root, self.file_dict["features"]["filename"] + self.root, self.file_dict['features']['filename'] ) label_path = os.path.join( - self.root, self.file_dict["labels"]["extracted_filename"] + self.root, self.file_dict['labels']['extracted_filename'] ) # Check if labels exist if os.path.exists(label_path): @@ -273,24 +273,24 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset and extract it.""" - 
features_path = os.path.join(self.file_dict["features"]["filename"]) + features_path = os.path.join(self.file_dict['features']['filename']) download_url( - self.file_dict["features"]["url"], + self.file_dict['features']['url'], self.root, filename=features_path, - md5=self.file_dict["features"]["md5"] if self.checksum else None, + md5=self.file_dict['features']['md5'] if self.checksum else None, ) download_url( - self.file_dict["labels"]["url"], + self.file_dict['labels']['url'], self.root, - filename=os.path.join(self.file_dict["labels"]["filename"]), - md5=self.file_dict["labels"]["md5"] if self.checksum else None, + filename=os.path.join(self.file_dict['labels']['filename']), + md5=self.file_dict['labels']['md5'] if self.checksum else None, ) def _extract(self) -> None: """Extract the dataset.""" - features_path = os.path.join(self.root, self.file_dict["features"]["filename"]) + features_path = os.path.join(self.root, self.file_dict['features']['filename']) extract_archive(features_path) def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: @@ -305,13 +305,13 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure """ fig, axs = plt.subplots() bands = [self.all_bands.index(band) for band in self.rgb_bands] - rgb = np.array(sample["array"])[:, bands] / 3000 + rgb = np.array(sample['array'])[:, bands] / 3000 axs.imshow(rgb[None, ...]) axs.set_title(f'Crop type: {self.classes[sample["label"]]}') axs.set_xticks(np.arange(12)) axs.set_xticklabels(np.arange(12) + 1) axs.set_yticks([]) - axs.set_xlabel("Month") + axs.set_xlabel('Month') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index 97dd8a3b616..7595e6a851a 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -64,56 +64,56 @@ class CV4AKenyaCropType(NonGeoDataset): """ collection_ids = [ - "ref_african_crops_kenya_02_labels", - "ref_african_crops_kenya_02_source", + 'ref_african_crops_kenya_02_labels', + 'ref_african_crops_kenya_02_source', ] image_meta = { - "filename": "ref_african_crops_kenya_02_source.tar.gz", - "md5": "9c2004782f6dc83abb1bf45ba4d0da46", + 'filename': 'ref_african_crops_kenya_02_source.tar.gz', + 'md5': '9c2004782f6dc83abb1bf45ba4d0da46', } target_meta = { - "filename": "ref_african_crops_kenya_02_labels.tar.gz", - "md5": "93949abd0ae82ba564f5a933cefd8215", + 'filename': 'ref_african_crops_kenya_02_labels.tar.gz', + 'md5': '93949abd0ae82ba564f5a933cefd8215', } tile_names = [ - "ref_african_crops_kenya_02_tile_00", - "ref_african_crops_kenya_02_tile_01", - "ref_african_crops_kenya_02_tile_02", - "ref_african_crops_kenya_02_tile_03", + 'ref_african_crops_kenya_02_tile_00', + 'ref_african_crops_kenya_02_tile_01', + 'ref_african_crops_kenya_02_tile_02', + 'ref_african_crops_kenya_02_tile_03', ] dates = [ - "20190606", - "20190701", - "20190706", - "20190711", - "20190721", - "20190805", - "20190815", - "20190825", - "20190909", - "20190919", - "20190924", - "20191004", - "20191103", + '20190606', + '20190701', + '20190706', + '20190711', + '20190721', + '20190805', + '20190815', + '20190825', + '20190909', + '20190919', + '20190924', + '20191004', + '20191103', ] band_names = ( - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12", - "CLD", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', + 'CLD', ) - 
rgb_bands = ["B04", "B03", "B02"] + rgb_bands = ['B04', 'B03', 'B02'] # Same for all tiles tile_height = 3035 @@ -121,7 +121,7 @@ class CV4AKenyaCropType(NonGeoDataset): def __init__( self, - root: str = "data", + root: str = 'data', chip_size: int = 256, stride: int = 128, bands: tuple[str, ...] = band_names, @@ -196,12 +196,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size] sample = { - "image": img, - "mask": labels, - "field_ids": field_ids, - "tile_index": torch.tensor(tile_index), - "x": torch.tensor(x), - "y": torch.tensor(y), + 'image': img, + 'mask': labels, + 'field_ids': field_ids, + 'tile_index': torch.tensor(tile_index), + 'x': torch.tensor(x), + 'y': torch.tensor(y), } if self.transforms is not None: @@ -233,17 +233,17 @@ def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]: assert tile_name in self.tile_names if self.verbose: - print(f"Loading labels/field_ids for {tile_name}") + print(f'Loading labels/field_ids for {tile_name}') directory = os.path.join( - self.root, "ref_african_crops_kenya_02_labels", tile_name + "_label" + self.root, 'ref_african_crops_kenya_02_labels', tile_name + '_label' ) - with Image.open(os.path.join(directory, "labels.tif")) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img) + with Image.open(os.path.join(directory, 'labels.tif')) as img: + array: 'np.typing.NDArray[np.int_]' = np.array(img) labels = torch.from_numpy(array) - with Image.open(os.path.join(directory, "field_ids.tif")) as img: + with Image.open(os.path.join(directory, 'field_ids.tif')) as img: array = np.array(img) field_ids = torch.from_numpy(array) @@ -259,7 +259,7 @@ def _validate_bands(self, bands: tuple[str, ...]) -> None: AssertionError: if ``bands`` is not a tuple ValueError: if an invalid band name is provided """ - assert isinstance(bands, tuple), "The list of bands must be a tuple" + assert isinstance(bands, tuple), 'The list of bands must be a tuple' for band in bands: if band not in self.band_names: raise ValueError(f"'{band}' is an invalid band name.") @@ -286,7 +286,7 @@ def _load_all_image_tiles( assert tile_name in self.tile_names if self.verbose: - print(f"Loading all imagery for {tile_name}") + print(f'Loading all imagery for {tile_name}') img = torch.zeros( len(self.dates), @@ -324,7 +324,7 @@ def _load_single_image_tile( assert date in self.dates if self.verbose: - print(f"Loading imagery for {tile_name} at {date}") + print(f'Loading imagery for {tile_name} at {date}') img = torch.zeros( len(bands), self.tile_height, self.tile_width, dtype=torch.float32 @@ -332,12 +332,12 @@ def _load_single_image_tile( for band_index, band_name in enumerate(self.bands): filepath = os.path.join( self.root, - "ref_african_crops_kenya_02_source", - f"{tile_name}_{date}", - f"{band_name}.tif", + 'ref_african_crops_kenya_02_source', + f'{tile_name}_{date}', + f'{band_name}.tif', ) with Image.open(filepath) as band_img: - array: "np.typing.NDArray[np.int_]" = np.array(band_img) + array: 'np.typing.NDArray[np.int_]' = np.array(band_img) img[band_index] = torch.from_numpy(array) return img @@ -349,13 +349,13 @@ def _check_integrity(self) -> bool: True if dataset files are found and/or MD5s match, else False """ images: bool = check_integrity( - os.path.join(self.root, self.image_meta["filename"]), - self.image_meta["md5"] if self.checksum else None, + os.path.join(self.root, self.image_meta['filename']), + self.image_meta['md5'] if self.checksum else None, ) targets: bool = 
check_integrity( - os.path.join(self.root, self.target_meta["filename"]), - self.target_meta["md5"] if self.checksum else None, + os.path.join(self.root, self.target_meta['filename']), + self.target_meta['md5'] if self.checksum else None, ) return images and targets @@ -370,12 +370,12 @@ def get_splits(self) -> tuple[list[int], list[int]]: test_field_ids = [] splits_fn = os.path.join( self.root, - "ref_african_crops_kenya_02_labels", - "_common", - "field_train_test_ids.csv", + 'ref_african_crops_kenya_02_labels', + '_common', + 'field_train_test_ids.csv', ) - with open(splits_fn, newline="") as f: + with open(splits_fn, newline='') as f: reader = csv.reader(f) # Skip header row @@ -395,14 +395,14 @@ def _download(self, api_key: str | None = None) -> None: api_key: a RadiantEarth MLHub API key to use for downloading the dataset """ if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return for collection_id in self.collection_ids: download_radiant_mlhub_collection(collection_id, self.root, api_key) - image_archive_path = os.path.join(self.root, self.image_meta["filename"]) - target_archive_path = os.path.join(self.root, self.target_meta["filename"]) + image_archive_path = os.path.join(self.root, self.image_meta['filename']) + target_archive_path = os.path.join(self.root, self.target_meta['filename']) for fn in [image_archive_path, target_archive_path]: extract_archive(fn, self.root) @@ -436,18 +436,18 @@ def plot( else: raise RGBBandsMissingError() - if "prediction" in sample: - prediction = sample["prediction"] + if 'prediction' in sample: + prediction = sample['prediction'] n_cols = 3 else: n_cols = 2 - image, mask = sample["image"], sample["mask"] + image, mask = sample['image'], sample['mask'] assert time_step <= image.shape[0] - 1, ( - "The specified time step" - f" does not exist, image only contains {image.shape[0]} time" - " instances." + 'The specified time step' + f' does not exist, image only contains {image.shape[0]} time' + ' instances.' ) image = image[time_step, rgb_indices, :, :] @@ -455,19 +455,19 @@ def plot( fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5)) axs[0].imshow(image.permute(1, 2, 0)) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask) - axs[1].axis("off") + axs[1].axis('off') - if "prediction" in sample: + if 'prediction' in sample: axs[2].imshow(prediction) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 46f2e9d6f00..3b0d8fd0278 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -51,29 +51,29 @@ class TropicalCyclone(NonGeoDataset): to be consistent with TropicalCycloneDataModule. 
""" - collection_id = "nasa_tropical_storm_competition" + collection_id = 'nasa_tropical_storm_competition' collection_ids = [ - "nasa_tropical_storm_competition_train_source", - "nasa_tropical_storm_competition_test_source", - "nasa_tropical_storm_competition_train_labels", - "nasa_tropical_storm_competition_test_labels", + 'nasa_tropical_storm_competition_train_source', + 'nasa_tropical_storm_competition_test_source', + 'nasa_tropical_storm_competition_train_labels', + 'nasa_tropical_storm_competition_test_labels', ] md5s = { - "train": { - "source": "97e913667a398704ea8d28196d91dad6", - "labels": "97d02608b74c82ffe7496a9404a30413", + 'train': { + 'source': '97e913667a398704ea8d28196d91dad6', + 'labels': '97d02608b74c82ffe7496a9404a30413', }, - "test": { - "source": "8d88099e4b310feb7781d776a6e1dcef", - "labels": "d910c430f90153c1f78a99cbc08e7bd0", + 'test': { + 'source': '8d88099e4b310feb7781d776a6e1dcef', + 'labels': 'd910c430f90153c1f78a99cbc08e7bd0', }, } size = 366 def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, api_key: str | None = None, @@ -107,10 +107,10 @@ def __init__( if not self._check_integrity(): raise DatasetNotFoundError(self) - output_dir = "_".join([self.collection_id, split, "source"]) - filename = os.path.join(root, output_dir, "collection.json") + output_dir = '_'.join([self.collection_id, split, 'source']) + filename = os.path.join(root, output_dir, 'collection.json') with open(filename) as f: - self.collection = json.load(f)["links"] + self.collection = json.load(f)['links'] def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. @@ -121,14 +121,14 @@ def __getitem__(self, index: int) -> dict[str, Any]: Returns: data, labels, field ids, and metadata at that index """ - source_id = os.path.split(self.collection[index]["href"])[0] + source_id = os.path.split(self.collection[index]['href'])[0] directory = os.path.join( self.root, - "_".join([self.collection_id, self.split, "{0}"]), - source_id.replace("source", "{0}"), + '_'.join([self.collection_id, self.split, '{0}']), + source_id.replace('source', '{0}'), ) - sample: dict[str, Any] = {"image": self._load_image(directory)} + sample: dict[str, Any] = {'image': self._load_image(directory)} sample.update(self._load_features(directory)) if self.transforms is not None: @@ -154,7 +154,7 @@ def _load_image(self, directory: str) -> Tensor: Returns: the image """ - filename = os.path.join(directory.format("source"), "image.jpg") + filename = os.path.join(directory.format('source'), 'image.jpg') with Image.open(filename) as img: if img.height != self.size or img.width != self.size: # Moved in PIL 9.1.0 @@ -163,7 +163,7 @@ def _load_image(self, directory: str) -> Tensor: except AttributeError: resample = Image.BILINEAR # type: ignore[attr-defined] img = img.resize(size=(self.size, self.size), resample=resample) - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) tensor = tensor.permute((2, 0, 1)).float() return tensor @@ -177,17 +177,17 @@ def _load_features(self, directory: str) -> dict[str, Any]: Returns: the features """ - filename = os.path.join(directory.format("source"), "features.json") + filename = os.path.join(directory.format('source'), 'features.json') with open(filename) as f: features: dict[str, Any] = 
json.load(f) - filename = os.path.join(directory.format("labels"), "labels.json") + filename = os.path.join(directory.format('labels'), 'labels.json') with open(filename) as f: features.update(json.load(f)) - features["relative_time"] = int(features["relative_time"]) - features["ocean"] = int(features["ocean"]) - features["label"] = torch.tensor(int(features["wind_speed"])).float() + features['relative_time'] = int(features['relative_time']) + features['ocean'] = int(features['ocean']) + features['label'] = torch.tensor(int(features['wind_speed'])).float() return features @@ -199,8 +199,8 @@ def _check_integrity(self) -> bool: """ for split, resources in self.md5s.items(): for resource_type, md5 in resources.items(): - filename = "_".join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename + ".tar.gz") + filename = '_'.join([self.collection_id, split, resource_type]) + filename = os.path.join(self.root, filename + '.tar.gz') if not check_integrity(filename, md5 if self.checksum else None): return False return True @@ -212,7 +212,7 @@ def _download(self, api_key: str | None = None) -> None: api_key: a RadiantEarth MLHub API key to use for downloading the dataset """ if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return for collection_id in self.collection_ids: @@ -220,8 +220,8 @@ def _download(self, api_key: str | None = None) -> None: for split, resources in self.md5s.items(): for resource_type in resources: - filename = "_".join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename) + ".tar.gz" + filename = '_'.join([self.collection_id, split, resource_type]) + filename = os.path.join(self.root, filename) + '.tar.gz' extract_archive(filename, self.root) def plot( @@ -242,21 +242,21 @@ def plot( .. versionadded:: 0.2 """ - image, label = sample["image"], sample["label"] + image, label = sample['image'], sample['label'] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = sample["prediction"].item() + prediction = sample['prediction'].item() fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(image.permute(1, 2, 0)) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label}" + title = f'Label: {label}' if showing_predictions: - title += f"\nPrediction: {prediction}" + title += f'\nPrediction: {prediction}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index 03b9a91ff29..da7d8ef2ae4 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -75,18 +75,18 @@ class DeepGlobeLandCover(NonGeoDataset): .. 
versionadded:: 0.3 """ # noqa: E501 - filename = "data.zip" - data_root = "data" - md5 = "f32684b0b2bf6f8d604cd359a399c061" - splits = ["train", "test"] + filename = 'data.zip' + data_root = 'data' + md5 = 'f32684b0b2bf6f8d604cd359a399c061' + splits = ['train', 'test'] classes = [ - "Urban land", - "Agriculture land", - "Rangeland", - "Forest land", - "Water", - "Barren land", - "Unknown", + 'Urban land', + 'Agriculture land', + 'Rangeland', + 'Forest land', + 'Water', + 'Barren land', + 'Unknown', ] colormap = [ (0, 255, 255), @@ -100,8 +100,8 @@ class DeepGlobeLandCover(NonGeoDataset): def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -124,23 +124,23 @@ def __init__( self.checksum = checksum self._verify() - if split == "train": - split_folder = "training_data" + if split == 'train': + split_folder = 'training_data' else: - split_folder = "test_data" + split_folder = 'test_data' self.image_fns = [] self.mask_fns = [] for image in sorted( - os.listdir(os.path.join(root, self.data_root, split_folder, "images")) + os.listdir(os.path.join(root, self.data_root, split_folder, 'images')) ): - if image.endswith(".jpg"): + if image.endswith('.jpg'): id = image[:-8] image_path = os.path.join( - root, self.data_root, split_folder, "images", image + root, self.data_root, split_folder, 'images', image ) mask_path = os.path.join( - root, self.data_root, split_folder, "masks", str(id) + "_mask.png" + root, self.data_root, split_folder, 'masks', str(id) + '_mask.png' ) self.image_fns.append(image_path) @@ -157,7 +157,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -184,7 +184,7 @@ def _load_image(self, index: int) -> Tensor: path = self.image_fns[index] with Image.open(path) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img) + array: 'np.typing.NDArray[np.int_]' = np.array(img) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)).to(torch.float32) @@ -201,7 +201,7 @@ def _load_target(self, index: int) -> Tensor: """ path = self.mask_fns[index] with Image.open(path) as img: - array: "np.typing.NDArray[np.uint8]" = np.array(img) + array: 'np.typing.NDArray[np.uint8]' = np.array(img) array = rgb_to_mask(array, self.colormap) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW @@ -219,7 +219,7 @@ def _verify(self) -> None: if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, self.md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(filepath) return @@ -245,12 +245,12 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample["image"], sample["mask"], alpha=alpha, colors=self.colormap + sample['image'], sample['mask'], alpha=alpha, colors=self.colormap ) - if "prediction" in sample: + if 'prediction' in sample: ncols += 1 image2 = draw_semantic_segmentation_masks( - sample["image"], sample["prediction"], alpha=alpha, colors=self.colormap + sample['image'], sample['prediction'], alpha=alpha, colors=self.colormap ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) @@ -260,15 +260,15 @@ def plot( ax0 = axs 
ax0.imshow(image1) - ax0.axis("off") + ax0.axis('off') if ncols > 1: ax1.imshow(image2) - ax1.axis("off") + ax1.axis('off') if show_titles: - ax0.set_title("Ground Truth") + ax0.set_title('Ground Truth') if ncols > 1: - ax1.set_title("Predictions") + ax1.set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index 40e0ad9b6c3..c78151e91ce 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -82,67 +82,67 @@ class DFC2022(NonGeoDataset): """ # noqa: E501 classes = [ - "No information", - "Urban fabric", - "Industrial, commercial, public, military, private and transport units", - "Mine, dump and construction sites", - "Artificial non-agricultural vegetated areas", - "Arable land (annual crops)", - "Permanent crops", - "Pastures", - "Complex and mixed cultivation patterns", - "Orchards at the fringe of urban classes", - "Forests", - "Herbaceous vegetation associations", - "Open spaces with little or no vegetation", - "Wetlands", - "Water", - "Clouds and Shadows", + 'No information', + 'Urban fabric', + 'Industrial, commercial, public, military, private and transport units', + 'Mine, dump and construction sites', + 'Artificial non-agricultural vegetated areas', + 'Arable land (annual crops)', + 'Permanent crops', + 'Pastures', + 'Complex and mixed cultivation patterns', + 'Orchards at the fringe of urban classes', + 'Forests', + 'Herbaceous vegetation associations', + 'Open spaces with little or no vegetation', + 'Wetlands', + 'Water', + 'Clouds and Shadows', ] colormap = [ - "#231F20", - "#DB5F57", - "#DB9757", - "#DBD057", - "#ADDB57", - "#75DB57", - "#7BC47B", - "#58B158", - "#D4F6D4", - "#B0E2B0", - "#008000", - "#58B0A7", - "#995D13", - "#579BDB", - "#0062FF", - "#231F20", + '#231F20', + '#DB5F57', + '#DB9757', + '#DBD057', + '#ADDB57', + '#75DB57', + '#7BC47B', + '#58B158', + '#D4F6D4', + '#B0E2B0', + '#008000', + '#58B0A7', + '#995D13', + '#579BDB', + '#0062FF', + '#231F20', ] metadata = { - "train": { - "filename": "labeled_train.zip", - "md5": "2e87d6a218e466dd0566797d7298c7a9", - "directory": "labeled_train", + 'train': { + 'filename': 'labeled_train.zip', + 'md5': '2e87d6a218e466dd0566797d7298c7a9', + 'directory': 'labeled_train', }, - "train-unlabeled": { - "filename": "unlabeled_train.zip", - "md5": "1016d724bc494b8c50760ae56bb0585e", - "directory": "unlabeled_train", + 'train-unlabeled': { + 'filename': 'unlabeled_train.zip', + 'md5': '1016d724bc494b8c50760ae56bb0585e', + 'directory': 'unlabeled_train', }, - "val": { - "filename": "val.zip", - "md5": "6ddd9c0f89d8e74b94ea352d4002073f", - "directory": "val", + 'val': { + 'filename': 'val.zip', + 'md5': '6ddd9c0f89d8e74b94ea352d4002073f', + 'directory': 'val', }, } - image_root = "BDORTHO" - dem_root = "RGEALTI" - target_root = "UrbanAtlas" + image_root = 'BDORTHO' + dem_root = 'RGEALTI' + target_root = 'UrbanAtlas' def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -180,15 +180,15 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image = self._load_image(files["image"]) - dem = self._load_image(files["dem"], shape=image.shape[1:]) + image = self._load_image(files['image']) + dem = self._load_image(files['dem'], shape=image.shape[1:]) image = torch.cat(tensors=[image, dem], 
dim=0) - sample = {"image": image} + sample = {'image': image} - if self.split == "train": - mask = self._load_target(files["target"]) - sample["mask"] = mask + if self.split == 'train': + mask = self._load_target(files['target']) + sample['mask'] = mask if self.transforms is not None: sample = self.transforms(sample) @@ -209,19 +209,19 @@ def _load_files(self) -> list[dict[str, str]]: Returns: list of dicts containing paths for each pair of image/dem/mask """ - directory = os.path.join(self.root, self.metadata[self.split]["directory"]) + directory = os.path.join(self.root, self.metadata[self.split]['directory']) images = glob.glob( - os.path.join(directory, "**", self.image_root, "*.tif"), recursive=True + os.path.join(directory, '**', self.image_root, '*.tif'), recursive=True ) files = [] for image in sorted(images): dem = image.replace(self.image_root, self.dem_root) - dem = f"{os.path.splitext(dem)[0]}_RGEALTI.tif" + dem = f'{os.path.splitext(dem)[0]}_RGEALTI.tif' - if self.split == "train": + if self.split == 'train': target = image.replace(self.image_root, self.target_root) - target = f"{os.path.splitext(target)[0]}_UA2012.tif" + target = f'{os.path.splitext(target)[0]}_UA2012.tif' files.append(dict(image=image, dem=dem, target=target)) else: files.append(dict(image=image, dem=dem)) @@ -239,8 +239,8 @@ def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor: the image """ with rasterio.open(path) as f: - array: "np.typing.NDArray[np.float_]" = f.read( - out_shape=shape, out_dtype="float32", resampling=Resampling.bilinear + array: 'np.typing.NDArray[np.float_]' = f.read( + out_shape=shape, out_dtype='float32', resampling=Resampling.bilinear ) tensor = torch.from_numpy(array) return tensor @@ -255,8 +255,8 @@ def _load_target(self, path: str) -> Tensor: the target mask """ with rasterio.open(path) as f: - array: "np.typing.NDArray[np.int_]" = f.read( - indexes=1, out_dtype="int32", resampling=Resampling.bilinear + array: 'np.typing.NDArray[np.int_]' = f.read( + indexes=1, out_dtype='int32', resampling=Resampling.bilinear ) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) @@ -268,7 +268,7 @@ def _verify(self) -> None: exists = [] for split_info in self.metadata.values(): exists.append( - os.path.exists(os.path.join(self.root, split_info["directory"])) + os.path.exists(os.path.join(self.root, split_info['directory'])) ) if all(exists): @@ -277,10 +277,10 @@ def _verify(self) -> None: # Check if .zip files already exists (if so then extract) exists = [] for split_info in self.metadata.values(): - filepath = os.path.join(self.root, split_info["filename"]) + filepath = os.path.join(self.root, split_info['filename']) if os.path.isfile(filepath): - if self.checksum and not check_integrity(filepath, split_info["md5"]): - raise RuntimeError("Dataset found, but corrupted.") + if self.checksum and not check_integrity(filepath, split_info['md5']): + raise RuntimeError('Dataset found, but corrupted.') exists.append(True) extract_archive(filepath) else: @@ -308,51 +308,51 @@ def plot( a matplotlib Figure with the rendered sample """ ncols = 2 - image = sample["image"][:3] + image = sample['image'][:3] image = image.to(torch.uint8) image = image.permute(1, 2, 0).numpy() - dem = sample["image"][-1].numpy() + dem = sample['image'][-1].numpy() dem = percentile_normalization(dem, lower=0, upper=100, axis=(0, 1)) - showing_mask = "mask" in sample - showing_prediction = "prediction" in sample + showing_mask = 'mask' in sample + showing_prediction = 'prediction' in 
sample cmap = colors.ListedColormap(self.colormap) if showing_mask: - mask = sample["mask"].numpy() + mask = sample['mask'].numpy() ncols += 1 if showing_prediction: - pred = sample["prediction"].numpy() + pred = sample['prediction'].numpy() ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10)) axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(dem) - axs[1].axis("off") + axs[1].axis('off') if showing_mask: - axs[2].imshow(mask, cmap=cmap, interpolation="none") - axs[2].axis("off") + axs[2].imshow(mask, cmap=cmap, interpolation='none') + axs[2].axis('off') if showing_prediction: - axs[3].imshow(pred, cmap=cmap, interpolation="none") - axs[3].axis("off") + axs[3].imshow(pred, cmap=cmap, interpolation='none') + axs[3].axis('off') elif showing_prediction: - axs[2].imshow(pred, cmap=cmap, interpolation="none") - axs[2].axis("off") + axs[2].imshow(pred, cmap=cmap, interpolation='none') + axs[2].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("DEM") + axs[0].set_title('Image') + axs[1].set_title('DEM') if showing_mask: - axs[2].set_title("Ground Truth") + axs[2].set_title('Ground Truth') if showing_prediction: - axs[3].set_title("Predictions") + axs[3].set_title('Predictions') elif showing_prediction: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py index 94e409c4d07..599b1e091b1 100644 --- a/torchgeo/datasets/eddmaps.py +++ b/torchgeo/datasets/eddmaps.py @@ -41,7 +41,7 @@ class EDDMapS(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = "data") -> None: + def __init__(self, root: str = 'data') -> None: """Initialize a new Dataset instance. Args: @@ -54,13 +54,13 @@ def __init__(self, root: str = "data") -> None: self.root = root - filepath = os.path.join(root, "mappings.csv") + filepath = os.path.join(root, 'mappings.csv') if not os.path.exists(filepath): raise DatasetNotFoundError(self) # Read CSV file data = pd.read_csv( - filepath, engine="c", usecols=["ObsDate", "Latitude", "Longitude"] + filepath, engine='c', usecols=['ObsDate', 'Latitude', 'Longitude'] ) # Convert from pandas DataFrame to rtree Index @@ -71,7 +71,7 @@ def __init__(self, root: str = "data") -> None: continue if not pd.isna(date): - mint, maxt = disambiguate_timestamp(date, "%m-%d-%y") + mint, maxt = disambiguate_timestamp(date, '%m-%d-%y') else: mint, maxt = 0, sys.maxsize @@ -96,9 +96,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not bboxes: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {"crs": self.crs, "bbox": bboxes} + sample = {'crs': self.crs, 'bbox': bboxes} return sample diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 0518ffeb8d3..7ea2f366195 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -46,83 +46,83 @@ class EnviroAtlas(GeoDataset): .. 
versionadded:: 0.3 """ - url = "https://zenodo.org/record/5778193/files/enviroatlas_lotp.zip?download=1" - filename = "enviroatlas_lotp.zip" - md5 = "bfe601be21c7c001315fc6154be8ef14" + url = 'https://zenodo.org/record/5778193/files/enviroatlas_lotp.zip?download=1' + filename = 'enviroatlas_lotp.zip' + md5 = 'bfe601be21c7c001315fc6154be8ef14' crs = CRS.from_epsg(3857) res = 1 - valid_prior_layers = ["prior", "prior_no_osm_no_buildings"] + valid_prior_layers = ['prior', 'prior_no_osm_no_buildings'] valid_layers = [ - "naip", - "nlcd", - "roads", - "water", - "waterways", - "waterbodies", - "buildings", - "lc", + 'naip', + 'nlcd', + 'roads', + 'water', + 'waterways', + 'waterbodies', + 'buildings', + 'lc', ] + valid_prior_layers cities = [ - "pittsburgh_pa-2010_1m", - "durham_nc-2012_1m", - "austin_tx-2012_1m", - "phoenix_az-2010_1m", + 'pittsburgh_pa-2010_1m', + 'durham_nc-2012_1m', + 'austin_tx-2012_1m', + 'phoenix_az-2010_1m', ] splits = ( - [f"{state}-train" for state in cities[:1]] - + [f"{state}-val" for state in cities[:1]] - + [f"{state}-test" for state in cities] - + [f"{state}-val5" for state in cities] + [f'{state}-train' for state in cities[:1]] + + [f'{state}-val' for state in cities[:1]] + + [f'{state}-test' for state in cities] + + [f'{state}-val5' for state in cities] ) # these are used to check the integrity of the dataset _files = [ - "austin_tx-2012_1m-test_tiles-debuffered", - "austin_tx-2012_1m-val5_tiles-debuffered", - "durham_nc-2012_1m-test_tiles-debuffered", - "durham_nc-2012_1m-val5_tiles-debuffered", - "phoenix_az-2010_1m-test_tiles-debuffered", - "phoenix_az-2010_1m-val5_tiles-debuffered", - "pittsburgh_pa-2010_1m-test_tiles-debuffered", - "pittsburgh_pa-2010_1m-train_tiles-debuffered", - "pittsburgh_pa-2010_1m-val5_tiles-debuffered", - "pittsburgh_pa-2010_1m-val_tiles-debuffered", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_a_naip.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_b_nlcd.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_c_roads.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d1_waterways.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d2_waterbodies.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif", - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif", # noqa: E501 - "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif", # noqa: E501 - "spatial_index.geojson", + 'austin_tx-2012_1m-test_tiles-debuffered', + 'austin_tx-2012_1m-val5_tiles-debuffered', + 'durham_nc-2012_1m-test_tiles-debuffered', + 'durham_nc-2012_1m-val5_tiles-debuffered', + 'phoenix_az-2010_1m-test_tiles-debuffered', + 'phoenix_az-2010_1m-val5_tiles-debuffered', + 'pittsburgh_pa-2010_1m-test_tiles-debuffered', + 'pittsburgh_pa-2010_1m-train_tiles-debuffered', + 'pittsburgh_pa-2010_1m-val5_tiles-debuffered', + 'pittsburgh_pa-2010_1m-val_tiles-debuffered', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_a_naip.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_b_nlcd.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_c_roads.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d1_waterways.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d2_waterbodies.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif', + 
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', # noqa: E501 + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501 + 'spatial_index.geojson', ] - p_src_crs = pyproj.CRS("epsg:3857") + p_src_crs = pyproj.CRS('epsg:3857') p_transformers = { - "epsg:26917": pyproj.Transformer.from_crs( - p_src_crs, pyproj.CRS("epsg:26917"), always_xy=True + 'epsg:26917': pyproj.Transformer.from_crs( + p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True ).transform, - "epsg:26918": pyproj.Transformer.from_crs( - p_src_crs, pyproj.CRS("epsg:26918"), always_xy=True + 'epsg:26918': pyproj.Transformer.from_crs( + p_src_crs, pyproj.CRS('epsg:26918'), always_xy=True ).transform, - "epsg:26914": pyproj.Transformer.from_crs( - p_src_crs, pyproj.CRS("epsg:26914"), always_xy=True + 'epsg:26914': pyproj.Transformer.from_crs( + p_src_crs, pyproj.CRS('epsg:26914'), always_xy=True ).transform, - "epsg:26912": pyproj.Transformer.from_crs( - p_src_crs, pyproj.CRS("epsg:26912"), always_xy=True + 'epsg:26912': pyproj.Transformer.from_crs( + p_src_crs, pyproj.CRS('epsg:26912'), always_xy=True ).transform, } # used to convert the 10 high-res classes labeled as [0, 10, 20, 30, 40, 52, 70, 80, # 82, 91, 92] to sequential labels [0, ..., 10] - raw_enviroatlas_to_idx_map: "np.typing.NDArray[np.uint8]" = np.array( + raw_enviroatlas_to_idx_map: 'np.typing.NDArray[np.uint8]' = np.array( [ 0, 0, @@ -222,17 +222,17 @@ class EnviroAtlas(GeoDataset): ) highres_classes = [ - "Unclassified", - "Water", - "Impervious Surface", - "Soil and Barren", - "Trees and Forest", - "Shrubs", - "Grass and Herbaceous", - "Agriculture", - "Orchards", - "Woody Wetlands", - "Emergent Wetlands", + 'Unclassified', + 'Water', + 'Impervious Surface', + 'Soil and Barren', + 'Trees and Forest', + 'Shrubs', + 'Grass and Herbaceous', + 'Agriculture', + 'Orchards', + 'Woody Wetlands', + 'Emergent Wetlands', ] highres_cmap = ListedColormap( [ @@ -252,9 +252,9 @@ class EnviroAtlas(GeoDataset): def __init__( self, - root: str = "data", - splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"], - layers: Sequence[str] = ["naip", "prior"], + root: str = 'data', + splits: Sequence[str] = ['pittsburgh_pa-2010_1m-train'], + layers: Sequence[str] = ['naip', 'prior'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, prior_as_input: bool = False, cache: bool = True, @@ -299,11 +299,11 @@ def __init__( mint: float = 0 maxt: float = sys.maxsize with fiona.open( - os.path.join(root, "enviroatlas_lotp", "spatial_index.geojson"), "r" + os.path.join(root, 'enviroatlas_lotp', 'spatial_index.geojson'), 'r' ) as f: for i, row in enumerate(f): - if row["properties"]["split"] in splits: - box = shapely.geometry.shape(row["geometry"]) + if row['properties']['split'] in splits: + box = shapely.geometry.shape(row['geometry']) minx, miny, maxx, maxy = box.bounds coords = (minx, maxx, miny, maxy, mint, maxt) @@ -311,22 +311,22 @@ def __init__( i, coords, { - "naip": row["properties"]["naip"], - "nlcd": row["properties"]["nlcd"], - "roads": row["properties"]["roads"], - "water": row["properties"]["water"], - "waterways": row["properties"]["waterways"], - "waterbodies": row["properties"]["waterbodies"], - "buildings": row["properties"]["buildings"], - "lc": row["properties"]["lc"], - "prior_no_osm_no_buildings": 
row["properties"][ - "naip" + 'naip': row['properties']['naip'], + 'nlcd': row['properties']['nlcd'], + 'roads': row['properties']['roads'], + 'water': row['properties']['water'], + 'waterways': row['properties']['waterways'], + 'waterbodies': row['properties']['waterbodies'], + 'buildings': row['properties']['buildings'], + 'lc': row['properties']['lc'], + 'prior_no_osm_no_buildings': row['properties'][ + 'naip' ].replace( - "a_naip", - "prior_from_cooccurrences_101_31_no_osm_no_buildings", + 'a_naip', + 'prior_from_cooccurrences_101_31_no_osm_no_buildings', ), - "prior": row["properties"]["naip"].replace( - "a_naip", "prior_from_cooccurrences_101_31" + 'prior': row['properties']['naip'].replace( + 'a_naip', 'prior_from_cooccurrences_101_31' ), }, ) @@ -346,11 +346,11 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) - sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query} + sample = {'image': [], 'mask': [], 'crs': self.crs, 'bbox': query} if len(filepaths) == 0: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) elif len(filepaths) == 1: filenames = filepaths[0] @@ -363,7 +363,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: fn = filenames[layer] with rasterio.open( - os.path.join(self.root, "enviroatlas_lotp", fn) + os.path.join(self.root, 'enviroatlas_lotp', fn) ) as f: dst_crs = f.crs.to_string().lower() @@ -380,30 +380,30 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: ) if layer in [ - "naip", - "buildings", - "roads", - "waterways", - "waterbodies", - "water", + 'naip', + 'buildings', + 'roads', + 'waterways', + 'waterbodies', + 'water', ]: - sample["image"].append(data) - elif layer in ["prior", "prior_no_osm_no_buildings"]: + sample['image'].append(data) + elif layer in ['prior', 'prior_no_osm_no_buildings']: if self.prior_as_input: - sample["image"].append(data) + sample['image'].append(data) else: - sample["mask"].append(data) - elif layer in ["lc"]: + sample['mask'].append(data) + elif layer in ['lc']: data = self.raw_enviroatlas_to_idx_map[data] - sample["mask"].append(data) + sample['mask'].append(data) else: - raise IndexError(f"query: {query} spans multiple tiles which is not valid") + raise IndexError(f'query: {query} spans multiple tiles which is not valid') - sample["image"] = np.concatenate(sample["image"], axis=0) - sample["mask"] = np.concatenate(sample["mask"], axis=0) + sample['image'] = np.concatenate(sample['image'], axis=0) + sample['mask'] = np.concatenate(sample['mask'], axis=0) - sample["image"] = torch.from_numpy(sample["image"]) - sample["mask"] = torch.from_numpy(sample["mask"]) + sample['image'] = torch.from_numpy(sample['image']) + sample['mask'] = torch.from_numpy(sample['mask']) if self.transforms is not None: sample = self.transforms(sample) @@ -414,7 +414,7 @@ def _verify(self) -> None: """Verify the integrity of the dataset.""" def exists(filename: str) -> bool: - return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename)) + return os.path.exists(os.path.join(self.root, 'enviroatlas_lotp', filename)) # Check if the extracted files already exist if all(map(exists, self._files)): @@ -462,53 +462,53 @@ def plot( Raises: ValueError: if the NAIP layer isn't included in ``self.layers`` """ - if "naip" not in self.layers or "lc" not in self.layers: + if 
'naip' not in self.layers or 'lc' not in self.layers: raise ValueError("The 'naip' and 'lc' layers must be included for plotting") image_layers = [] mask_layers = [] for layer in self.layers: if layer in [ - "naip", - "buildings", - "roads", - "waterways", - "waterbodies", - "water", + 'naip', + 'buildings', + 'roads', + 'waterways', + 'waterbodies', + 'water', ]: image_layers.append(layer) - elif layer in ["prior", "prior_no_osm_no_buildings"]: + elif layer in ['prior', 'prior_no_osm_no_buildings']: if self.prior_as_input: image_layers.append(layer) else: mask_layers.append(layer) - elif layer in ["lc"]: + elif layer in ['lc']: mask_layers.append(layer) - naip_index = image_layers.index("naip") - lc_index = mask_layers.index("lc") + naip_index = image_layers.index('naip') + lc_index = mask_layers.index('lc') image = np.rollaxis( - sample["image"][naip_index : naip_index + 3, :, :].numpy(), 0, 3 + sample['image'][naip_index : naip_index + 3, :, :].numpy(), 0, 3 ) - mask = sample["mask"][lc_index].numpy() + mask = sample['mask'][lc_index].numpy() num_panels = 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - predictions = sample["prediction"].numpy() + predictions = sample['prediction'].numpy() num_panels += 1 fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow( - mask, vmin=0, vmax=10, cmap=self.highres_cmap, interpolation="none" + mask, vmin=0, vmax=10, cmap=self.highres_cmap, interpolation='none' ) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if showing_predictions: axs[2].imshow( @@ -516,11 +516,11 @@ def plot( vmin=0, vmax=10, cmap=self.highres_cmap, - interpolation="none", + interpolation='none', ) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 2238f6943b6..bb2da326073 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -51,24 +51,24 @@ class Esri2020(RasterDataset): """ is_image = False - filename_glob = "*_20200101-20210101.*" + filename_glob = '*_20200101-20210101.*' filename_regex = r"""^ (?P[0-9][0-9][A-Z]) _(?P\d{8}) -(?P\d{8}) """ - zipfile = "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip" - md5 = "4932855fcd00735a34b74b1f87db3df0" + zipfile = 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip' + md5 = '4932855fcd00735a34b74b1f87db3df0' url = ( - "https://ai4edataeuwest.blob.core.windows.net/io-lulc/" - "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip" + 'https://ai4edataeuwest.blob.core.windows.net/io-lulc/' + 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip' ) def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -150,29 +150,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = sample["prediction"].squeeze() + 
prediction = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4)) if showing_predictions: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(prediction) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow(mask) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 7dd9bf02235..1980050f0ba 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -55,33 +55,33 @@ class ETCI2021(NonGeoDataset): the ETCI competition. """ - bands = ["VV", "VH"] - masks = ["flood", "water_body"] + bands = ['VV', 'VH'] + masks = ['flood', 'water_body'] metadata = { - "train": { - "filename": "train.zip", - "md5": "1e95792fe0f6e3c9000abdeab2a8ab0f", - "directory": "train", - "url": "https://drive.google.com/file/d/14HqNW5uWLS92n7KrxKgDwUTsSEST6LCr", + 'train': { + 'filename': 'train.zip', + 'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f', + 'directory': 'train', + 'url': 'https://drive.google.com/file/d/14HqNW5uWLS92n7KrxKgDwUTsSEST6LCr', }, - "val": { - "filename": "val_with_ref_labels.zip", - "md5": "fd18cecb318efc69f8319f90c3771bdf", - "directory": "test", - "url": "https://drive.google.com/file/d/19sriKPHCZLfJn_Jmk3Z_0b3VaCBVRVyn", + 'val': { + 'filename': 'val_with_ref_labels.zip', + 'md5': 'fd18cecb318efc69f8319f90c3771bdf', + 'directory': 'test', + 'url': 'https://drive.google.com/file/d/19sriKPHCZLfJn_Jmk3Z_0b3VaCBVRVyn', }, - "test": { - "filename": "test_without_ref_labels.zip", - "md5": "da9fa69e1498bd49d5c766338c6dac3d", - "directory": "test_internal", - "url": "https://drive.google.com/file/d/1rpMVluASnSHBfm2FhpPDio0GyCPOqg7E", + 'test': { + 'filename': 'test_without_ref_labels.zip', + 'md5': 'da9fa69e1498bd49d5c766338c6dac3d', + 'directory': 'test_internal', + 'url': 'https://drive.google.com/file/d/1rpMVluASnSHBfm2FhpPDio0GyCPOqg7E', }, } def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -125,18 +125,18 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - vv = self._load_image(files["vv"]) - vh = self._load_image(files["vh"]) - water_mask = self._load_target(files["water_mask"]) + vv = self._load_image(files['vv']) + vh = self._load_image(files['vh']) + water_mask = self._load_target(files['water_mask']) - if self.split != "test": - flood_mask = self._load_target(files["flood_mask"]) + if self.split != 'test': + flood_mask = self._load_target(files['flood_mask']) mask = torch.stack(tensors=[water_mask, flood_mask], dim=0) else: mask = water_mask.unsqueeze(0) image = torch.cat(tensors=[vv, vh], dim=0) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -163,20 +163,20 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: water body mask, flood mask (train/val only) """ files = [] - directory = self.metadata[split]["directory"] - folders = sorted(glob.glob(os.path.join(root, directory, "*"))) - folders = 
[os.path.join(folder, "tiles") for folder in folders] + directory = self.metadata[split]['directory'] + folders = sorted(glob.glob(os.path.join(root, directory, '*'))) + folders = [os.path.join(folder, 'tiles') for folder in folders] for folder in folders: - vvs = sorted(glob.glob(os.path.join(folder, "vv", "*.png"))) - vhs = [vv.replace("vv", "vh") for vv in vvs] + vvs = sorted(glob.glob(os.path.join(folder, 'vv', '*.png'))) + vhs = [vv.replace('vv', 'vh') for vv in vvs] water_masks = [ - vv.replace("_vv.png", ".png").replace("vv", "water_body_label") + vv.replace('_vv.png', '.png').replace('vv', 'water_body_label') for vv in vvs ] - if split != "test": + if split != 'test': flood_masks = [ - vv.replace("_vv.png", ".png").replace("vv", "flood_label") + vv.replace('_vv.png', '.png').replace('vv', 'flood_label') for vv in vvs ] @@ -203,7 +203,7 @@ def _load_image(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -220,7 +220,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = torch.clamp(tensor, min=0, max=1) tensor = tensor.to(torch.long) @@ -232,7 +232,7 @@ def _check_integrity(self) -> bool: Returns: True if the dataset directories and split files are found, else False """ - directory = self.metadata[self.split]["directory"] + directory = self.metadata[self.split]['directory'] dirpath = os.path.join(self.root, directory) if not os.path.exists(dirpath): return False @@ -241,14 +241,14 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return download_and_extract_archive( - self.metadata[self.split]["url"], + self.metadata[self.split]['url'], self.root, - filename=self.metadata[self.split]["filename"], - md5=self.metadata[self.split]["md5"] if self.checksum else None, + filename=self.metadata[self.split]['filename'], + md5=self.metadata[self.split]['md5'] if self.checksum else None, ) def plot( @@ -267,12 +267,12 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - vv = np.rollaxis(sample["image"][:3].numpy(), 0, 3) - vh = np.rollaxis(sample["image"][3:].numpy(), 0, 3) - mask = sample["mask"].squeeze(0) + vv = np.rollaxis(sample['image'][:3].numpy(), 0, 3) + vh = np.rollaxis(sample['image'][3:].numpy(), 0, 3) + mask = sample['mask'].squeeze(0) showing_flood_mask = mask.shape[0] == 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample num_panels = 3 if showing_flood_mask: water_mask = mask[0].numpy() @@ -282,34 +282,34 @@ def plot( water_mask = mask.numpy() if showing_predictions: - predictions = sample["prediction"].numpy() + predictions = sample['prediction'].numpy() num_panels += 1 fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 3)) axs[0].imshow(vv) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(vh) - axs[1].axis("off") + axs[1].axis('off') axs[2].imshow(water_mask) - axs[2].axis("off") + 
axs[2].axis('off') if show_titles: - axs[0].set_title("VV") - axs[1].set_title("VH") - axs[2].set_title("Water mask") + axs[0].set_title('VV') + axs[1].set_title('VH') + axs[2].set_title('Water mask') idx = 0 if showing_flood_mask: axs[3 + idx].imshow(flood_mask) - axs[3 + idx].axis("off") + axs[3 + idx].axis('off') if show_titles: - axs[3 + idx].set_title("Flood mask") + axs[3 + idx].set_title('Flood mask') idx += 1 if showing_predictions: axs[3 + idx].imshow(predictions) - axs[3 + idx].axis("off") + axs[3 + idx].axis('off') if show_titles: - axs[3 + idx].set_title("Predictions") + axs[3 + idx].set_title('Predictions') idx += 1 if suptitle is not None: diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 15913aeed14..c7570aca8e5 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -47,43 +47,43 @@ class EUDEM(RasterDataset): """ is_image = False - filename_glob = "eu_dem_v11_*.TIF" - zipfile_glob = "eu_dem_v11_*[A-Z0-9].zip" - filename_regex = "(?P[eudem_v11]{10})_(?P[A-Z0-9]{6})" + filename_glob = 'eu_dem_v11_*.TIF' + zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip' + filename_regex = '(?P[eudem_v11]{10})_(?P[A-Z0-9]{6})' md5s = { - "eu_dem_v11_E00N20.zip": "96edc7e11bc299b994e848050d6be591", - "eu_dem_v11_E10N00.zip": "e14be147ac83eddf655f4833d55c1571", - "eu_dem_v11_E10N10.zip": "2eb5187e4d827245b33768404529c709", - "eu_dem_v11_E10N20.zip": "1afc162eb131841aed0d00b692b870a8", - "eu_dem_v11_E20N10.zip": "77b040791b9fb7de271b3f47130b4e0c", - "eu_dem_v11_E20N20.zip": "89b965abdcb1dbd479c61117f55230c8", - "eu_dem_v11_E20N30.zip": "f5cb1b05813ae8ffc9e70f0ad56cc372", - "eu_dem_v11_E20N40.zip": "81be551ff646802d7d820385de7476e9", - "eu_dem_v11_E20N50.zip": "bbc351713ea3eb7e9eb6794acb9e4bc8", - "eu_dem_v11_E30N10.zip": "68fb95aac33a025c4f35571f32f237ff", - "eu_dem_v11_E30N20.zip": "da8ad029f9cc1ec9234ea3e7629fe18d", - "eu_dem_v11_E30N30.zip": "de27c78d0176e45aec5c9e462a95749c", - "eu_dem_v11_E30N40.zip": "4c00e58b624adfc4a5748c922e77ee40", - "eu_dem_v11_E30N50.zip": "4a21a88f4d2047b8995d1101df0b3a77", - "eu_dem_v11_E40N10.zip": "32fdf4572581eddc305a21c5d2f4bc81", - "eu_dem_v11_E40N20.zip": "71b027f29258493dd751cfd63f08578f", - "eu_dem_v11_E40N30.zip": "c6c21289882c1f74fc4649d255302c64", - "eu_dem_v11_E40N40.zip": "9f26e6e47f4160ef8ea5200e8cf90a45", - "eu_dem_v11_E40N50.zip": "a8c3c1c026cdd1537b8a3822c15834d9", - "eu_dem_v11_E50N10.zip": "9584273c7708b8e935f2bac3e30c19c6", - "eu_dem_v11_E50N20.zip": "8efdea43e7b6819861935d5a768a55f2", - "eu_dem_v11_E50N30.zip": "e39e58df1c13ac35eb0b29fb651f313c", - "eu_dem_v11_E50N40.zip": "d84395ab52ad254d930db17398fffc50", - "eu_dem_v11_E50N50.zip": "6abe852f4a20962db0e355ffc0d695a4", - "eu_dem_v11_E60N10.zip": "b6a3b8a39a4efc01c7e2cd8418672559", - "eu_dem_v11_E60N20.zip": "71dc3c55ab5c90628ce2149dbd60f090", - "eu_dem_v11_E70N20.zip": "5342465ad60cf7d28a586c9585179c35", + 'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591', + 'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571', + 'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709', + 'eu_dem_v11_E10N20.zip': '1afc162eb131841aed0d00b692b870a8', + 'eu_dem_v11_E20N10.zip': '77b040791b9fb7de271b3f47130b4e0c', + 'eu_dem_v11_E20N20.zip': '89b965abdcb1dbd479c61117f55230c8', + 'eu_dem_v11_E20N30.zip': 'f5cb1b05813ae8ffc9e70f0ad56cc372', + 'eu_dem_v11_E20N40.zip': '81be551ff646802d7d820385de7476e9', + 'eu_dem_v11_E20N50.zip': 'bbc351713ea3eb7e9eb6794acb9e4bc8', + 'eu_dem_v11_E30N10.zip': '68fb95aac33a025c4f35571f32f237ff', + 'eu_dem_v11_E30N20.zip': 
'da8ad029f9cc1ec9234ea3e7629fe18d', + 'eu_dem_v11_E30N30.zip': 'de27c78d0176e45aec5c9e462a95749c', + 'eu_dem_v11_E30N40.zip': '4c00e58b624adfc4a5748c922e77ee40', + 'eu_dem_v11_E30N50.zip': '4a21a88f4d2047b8995d1101df0b3a77', + 'eu_dem_v11_E40N10.zip': '32fdf4572581eddc305a21c5d2f4bc81', + 'eu_dem_v11_E40N20.zip': '71b027f29258493dd751cfd63f08578f', + 'eu_dem_v11_E40N30.zip': 'c6c21289882c1f74fc4649d255302c64', + 'eu_dem_v11_E40N40.zip': '9f26e6e47f4160ef8ea5200e8cf90a45', + 'eu_dem_v11_E40N50.zip': 'a8c3c1c026cdd1537b8a3822c15834d9', + 'eu_dem_v11_E50N10.zip': '9584273c7708b8e935f2bac3e30c19c6', + 'eu_dem_v11_E50N20.zip': '8efdea43e7b6819861935d5a768a55f2', + 'eu_dem_v11_E50N30.zip': 'e39e58df1c13ac35eb0b29fb651f313c', + 'eu_dem_v11_E50N40.zip': 'd84395ab52ad254d930db17398fffc50', + 'eu_dem_v11_E50N50.zip': '6abe852f4a20962db0e355ffc0d695a4', + 'eu_dem_v11_E60N10.zip': 'b6a3b8a39a4efc01c7e2cd8418672559', + 'eu_dem_v11_E60N20.zip': '71dc3c55ab5c90628ce2149dbd60f090', + 'eu_dem_v11_E70N20.zip': '5342465ad60cf7d28a586c9585179c35', } def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -130,7 +130,7 @@ def _verify(self) -> None: for zipfile in glob.iglob(pathname): filename = os.path.basename(zipfile) if self.checksum and not check_integrity(zipfile, self.md5s[filename]): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(zipfile) return @@ -152,29 +152,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) if showing_predictions: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(pred) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow(mask) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index 140e6f310c6..49b2d28f1e7 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -40,16 +40,16 @@ class EuroCrops(VectorDataset): .. versionadded:: 0.6 """ - base_url = "https://zenodo.org/records/8229128/files/" + base_url = 'https://zenodo.org/records/8229128/files/' - hcat_fname = "HCAT2.csv" - hcat_md5 = "b323e8de3d8d507bd0550968925b6906" + hcat_fname = 'HCAT2.csv' + hcat_md5 = 'b323e8de3d8d507bd0550968925b6906' # Name of the column containing HCAT code in CSV file. - hcat_code_column = "HCAT2_code" + hcat_code_column = 'HCAT2_code' - label_name = "EC_hcat_c" + label_name = 'EC_hcat_c' - filename_glob = "*_EC*.shp" + filename_glob = '*_EC*.shp' # Override variables to automatically extract timestamp. filename_regex = r""" @@ -61,26 +61,26 @@ class EuroCrops(VectorDataset): (?PEC(?:21)?) \.shp$ """ - date_format = "%Y" + date_format = '%Y' # Filename and md5 of files in this dataset on zenodo. 
zenodo_files = [ - ("AT_2021.zip", "490241df2e3d62812e572049fc0c36c5"), - ("BE_VLG_2021.zip", "ac4b9e12ad39b1cba47fdff1a786c2d7"), - ("DE_LS_2021.zip", "6d94e663a3ff7988b32cb36ea24a724f"), - ("DE_NRW_2021.zip", "a5af4e520cc433b9014cf8389c8f4c1f"), - ("DK_2019.zip", "d296478680edc3173422b379ace323d8"), - ("EE_2021.zip", "a7596f6691ad778a912d5a07e7ca6e41"), - ("ES_NA_2020.zip", "023f3b397d0f6f7a020508ed8320d543"), - ("FR_2018.zip", "282304734f156fb4df93a60b30e54c29"), - ("HR_2020.zip", "8bfe2b0cbd580737adcf7335682a1ea5"), - ("LT_2021.zip", "c7597214b90505877ee0cfa1232ac45f"), - ("LV_2021.zip", "b7253f96c8699d98ca503787f577ce26"), - ("NL_2020.zip", "823da32d28695b8b016740449391c0db"), - ("PT.zip", "3dba9c89c559b34d57acd286505bcb66"), - ("SE_2021.zip", "cab164c1c400fce56f7f1873bc966858"), - ("SI_2021.zip", "6b2dde6ba9d09c3ef8145ea520576228"), - ("SK_2021.zip", "c7762b4073869673edc08502e7b22f01"), + ('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'), + ('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'), + ('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'), + ('DE_NRW_2021.zip', 'a5af4e520cc433b9014cf8389c8f4c1f'), + ('DK_2019.zip', 'd296478680edc3173422b379ace323d8'), + ('EE_2021.zip', 'a7596f6691ad778a912d5a07e7ca6e41'), + ('ES_NA_2020.zip', '023f3b397d0f6f7a020508ed8320d543'), + ('FR_2018.zip', '282304734f156fb4df93a60b30e54c29'), + ('HR_2020.zip', '8bfe2b0cbd580737adcf7335682a1ea5'), + ('LT_2021.zip', 'c7597214b90505877ee0cfa1232ac45f'), + ('LV_2021.zip', 'b7253f96c8699d98ca503787f577ce26'), + ('NL_2020.zip', '823da32d28695b8b016740449391c0db'), + ('PT.zip', '3dba9c89c559b34d57acd286505bcb66'), + ('SE_2021.zip', 'cab164c1c400fce56f7f1873bc966858'), + ('SI_2021.zip', '6b2dde6ba9d09c3ef8145ea520576228'), + ('SK_2021.zip', 'c7762b4073869673edc08502e7b22f01'), # Year is unknown for Romania portion (ny = no year). # We skip since it is inconsistent with the rest of the data. # ("RO_ny.zip", "648e1504097765b4b7f825decc838882"), @@ -88,7 +88,7 @@ class EuroCrops(VectorDataset): def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS = CRS.from_epsg(4326), res: float = 0.00001, classes: list[str] | None = None, @@ -157,7 +157,7 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return assert isinstance(self.paths, str) download_url( @@ -194,7 +194,7 @@ def _load_class_map(self, classes: list[str] | None) -> None: for idx, hcat_code in enumerate(classes): self.class_map[hcat_code] = idx + 1 - def get_label(self, feature: "fiona.model.Feature") -> int: + def get_label(self, feature: 'fiona.model.Feature') -> int: """Get label value to use for rendering a feature. Args: @@ -206,19 +206,19 @@ def get_label(self, feature: "fiona.model.Feature") -> int: # Convert the HCAT code of this feature to its index per self.class_map. # We go up the class hierarchy until there is a match. # (Parent code is computed by replacing rightmost non-0 character with 0.) 
- hcat_code = feature["properties"][self.label_name] + hcat_code = feature['properties'][self.label_name] while True: if hcat_code in self.class_map: return self.class_map[hcat_code] hcat_code_list = list(hcat_code) - if all(c == "0" for c in hcat_code_list): + if all(c == '0' for c in hcat_code_list): break for i in range(len(hcat_code_list) - 1, -1, -1): - if hcat_code_list[i] == "0": + if hcat_code_list[i] == '0': continue - hcat_code_list[i] = "0" + hcat_code_list[i] = '0' break - hcat_code = "".join(hcat_code_list) + hcat_code = ''.join(hcat_code_list) return 0 def plot( @@ -237,36 +237,36 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_prediction = "prediction" in sample + showing_prediction = 'prediction' in sample if showing_prediction: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4)) - def apply_cmap(arr: "np.typing.NDArray[Any]") -> "np.typing.NDArray[np.float_]": + def apply_cmap(arr: 'np.typing.NDArray[Any]') -> 'np.typing.NDArray[np.float_]': # Color 0 as black, while applying default color map for the class indices. - cmap = plt.get_cmap("viridis") - im: "np.typing.NDArray[np.float_]" = cmap(arr / len(self.class_map)) + cmap = plt.get_cmap('viridis') + im: 'np.typing.NDArray[np.float_]' = cmap(arr / len(self.class_map)) im[arr == 0] = 0 return im if showing_prediction: - axs[0].imshow(apply_cmap(mask), interpolation="none") - axs[0].axis("off") - axs[1].imshow(apply_cmap(pred), interpolation="none") - axs[1].axis("off") + axs[0].imshow(apply_cmap(mask), interpolation='none') + axs[0].axis('off') + axs[1].imshow(apply_cmap(pred), interpolation='none') + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: - axs.imshow(apply_cmap(mask), interpolation="none") - axs.axis("off") + axs.imshow(apply_cmap(mask), interpolation='none') + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 7eff68cd7bb..f09e1cc7b9b 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -60,52 +60,52 @@ class EuroSAT(NonGeoClassificationDataset): * https://ieeexplore.ieee.org/document/8519248 """ - url = "https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip" # noqa: E501 - filename = "EuroSATallBands.zip" - md5 = "5ac12b3b2557aa56e1826e981e8e200e" + url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' # noqa: E501 + filename = 'EuroSATallBands.zip' + md5 = '5ac12b3b2557aa56e1826e981e8e200e' # For some reason the class directories are actually nested in this directory base_dir = os.path.join( - "ds", "images", "remote_sensing", "otherDatasets", "sentinel_2", "tif" + 'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif' ) - splits = ["train", "val", "test"] + splits = ['train', 'val', 'test'] split_urls = { - "train": "https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt", # noqa: E501 - "val": "https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt", # noqa: E501 - "test": 
"https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt", # noqa: E501 + 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', # noqa: E501 + 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', # noqa: E501 + 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', # noqa: E501 } split_md5s = { - "train": "908f142e73d6acdf3f482c5e80d851b1", - "val": "95de90f2aa998f70a3b2416bfe0687b4", - "test": "7ae5ab94471417b6e315763121e67c5f", + 'train': '908f142e73d6acdf3f482c5e80d851b1', + 'val': '95de90f2aa998f70a3b2416bfe0687b4', + 'test': '7ae5ab94471417b6e315763121e67c5f', } all_band_names = ( - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B10', + 'B11', + 'B12', ) - rgb_bands = ("B04", "B03", "B02") + rgb_bands = ('B04', 'B03', 'B02') - BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + BAND_SETS = {'all': all_band_names, 'rgb': rgb_bands} def __init__( self, - root: str = "data", - split: str = "train", - bands: Sequence[str] = BAND_SETS["all"], + root: str = 'data', + split: str = 'train', + bands: Sequence[str] = BAND_SETS['all'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -133,7 +133,7 @@ def __init__( self.download = download self.checksum = checksum - assert split in ["train", "val", "test"] + assert split in ['train', 'val', 'test'] self._validate_bands(bands) self.bands = bands @@ -144,9 +144,9 @@ def __init__( self._verify() valid_fns = set() - with open(os.path.join(self.root, f"eurosat-{split}.txt")) as f: + with open(os.path.join(self.root, f'eurosat-{split}.txt')) as f: for fn in f: - valid_fns.add(fn.strip().replace(".jpg", ".tif")) + valid_fns.add(fn.strip().replace('.jpg', '.tif')) def is_in_split(x: str) -> bool: return os.path.basename(x) in valid_fns @@ -169,7 +169,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image, label = self._load_image(index) image = torch.index_select(image, dim=0, index=self.band_indices).float() - sample = {"image": image, "label": label} + sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -219,7 +219,7 @@ def _download(self) -> None: download_url( self.split_urls[split], self.root, - filename=f"eurosat-{split}.txt", + filename=f'eurosat-{split}.txt', md5=self.split_md5s[split] if self.checksum else None, ) @@ -273,25 +273,25 @@ def plot( else: raise RGBBandsMissingError() - image = np.take(sample["image"].numpy(), indices=rgb_indices, axis=0) + image = np.take(sample['image'].numpy(), indices=rgb_indices, axis=0) image = np.rollaxis(image, 0, 3) image = np.clip(image / 3000, 0, 1) - label = cast(int, sample["label"].item()) + label = cast(int, sample['label'].item()) label_class = self.classes[label] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(int, sample["prediction"].item()) + prediction = cast(int, sample['prediction'].item()) prediction_class = self.classes[prediction] fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label_class}" + title = f'Label: {label_class}' if showing_predictions: - title += f"\nPrediction: 
{prediction_class}" + title += f'\nPrediction: {prediction_class}' ax.set_title(title) if suptitle is not None: @@ -310,17 +310,17 @@ class EuroSAT100(EuroSAT): .. versionadded:: 0.5 """ - url = "https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip" # noqa: E501 - filename = "EuroSAT100.zip" - md5 = "c21c649ba747e86eda813407ef17d596" + url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' # noqa: E501 + filename = 'EuroSAT100.zip' + md5 = 'c21c649ba747e86eda813407ef17d596' split_urls = { - "train": "https://hf.co/datasets/torchgeo/eurosat/raw/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt", # noqa: E501 - "val": "https://hf.co/datasets/torchgeo/eurosat/raw/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt", # noqa: E501 - "test": "https://hf.co/datasets/torchgeo/eurosat/raw/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt", # noqa: E501 + 'train': 'https://hf.co/datasets/torchgeo/eurosat/raw/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', # noqa: E501 + 'val': 'https://hf.co/datasets/torchgeo/eurosat/raw/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', # noqa: E501 + 'test': 'https://hf.co/datasets/torchgeo/eurosat/raw/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', # noqa: E501 } split_md5s = { - "train": "033d0c23e3a75e3fa79618b0e35fe1c7", - "val": "3e3f8b3c344182b8d126c4cc88f3f215", - "test": "f908f151b950f270ad18e61153579794", + 'train': '033d0c23e3a75e3fa79618b0e35fe1c7', + 'val': '3e3f8b3c344182b8d126c4cc88f3f215', + 'test': 'f908f151b950f270ad18e61153579794', } diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index f9de21faba7..d019e4d90bd 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -32,20 +32,20 @@ def parse_pascal_voc(path: str) -> dict[str, Any]: """ et = parse(path) element = et.getroot() - source = cast(Element, element.find("source")) - filename = cast(Element, source.find("filename")).text + source = cast(Element, element.find('source')) + filename = cast(Element, source.find('filename')).text labels, points = [], [] - objects = cast(Element, element.find("objects")) - for obj in objects.findall("object"): - elm_points = cast(Element, obj.find("points")) - lis_points = elm_points.findall("point") + objects = cast(Element, element.find('objects')) + for obj in objects.findall('object'): + elm_points = cast(Element, obj.find('points')) + lis_points = elm_points.findall('point') str_points = [] for point in lis_points: text = cast(str, point.text) - str_points.append(text.split(",")) + str_points.append(text.split(',')) tup_points = [(float(p1), float(p2)) for p1, p2 in str_points] - possibleresult = cast(Element, obj.find("possibleresult")) - name = cast(Element, possibleresult.find("name")) + possibleresult = cast(Element, obj.find('possibleresult')) + name = cast(Element, possibleresult.find('name')) label = name.text labels.append(label) points.append(tup_points) @@ -119,118 +119,118 @@ class FAIR1M(NonGeoDataset): """ classes = { - "Passenger Ship": {"id": 0, "category": "Ship"}, - "Motorboat": {"id": 1, "category": "Ship"}, - "Fishing Boat": {"id": 2, "category": "Ship"}, - "Tugboat": {"id": 3, "category": "Ship"}, - "other-ship": {"id": 4, "category": "Ship"}, - "Engineering Ship": {"id": 5, "category": "Ship"}, - "Liquid Cargo Ship": {"id": 6, "category": "Ship"}, - "Dry Cargo Ship": {"id": 7, "category": "Ship"}, - "Warship": {"id": 8, 
"category": "Ship"}, - "Small Car": {"id": 9, "category": "Vehicle"}, - "Bus": {"id": 10, "category": "Vehicle"}, - "Cargo Truck": {"id": 11, "category": "Vehicle"}, - "Dump Truck": {"id": 12, "category": "Vehicle"}, - "other-vehicle": {"id": 13, "category": "Vehicle"}, - "Van": {"id": 14, "category": "Vehicle"}, - "Trailer": {"id": 15, "category": "Vehicle"}, - "Tractor": {"id": 16, "category": "Vehicle"}, - "Excavator": {"id": 17, "category": "Vehicle"}, - "Truck Tractor": {"id": 18, "category": "Vehicle"}, - "Boeing737": {"id": 19, "category": "Airplane"}, - "Boeing747": {"id": 20, "category": "Airplane"}, - "Boeing777": {"id": 21, "category": "Airplane"}, - "Boeing787": {"id": 22, "category": "Airplane"}, - "ARJ21": {"id": 23, "category": "Airplane"}, - "C919": {"id": 24, "category": "Airplane"}, - "A220": {"id": 25, "category": "Airplane"}, - "A321": {"id": 26, "category": "Airplane"}, - "A330": {"id": 27, "category": "Airplane"}, - "A350": {"id": 28, "category": "Airplane"}, - "other-airplane": {"id": 29, "category": "Airplane"}, - "Baseball Field": {"id": 30, "category": "Court"}, - "Basketball Court": {"id": 31, "category": "Court"}, - "Football Field": {"id": 32, "category": "Court"}, - "Tennis Court": {"id": 33, "category": "Court"}, - "Roundabout": {"id": 34, "category": "Road"}, - "Intersection": {"id": 35, "category": "Road"}, - "Bridge": {"id": 36, "category": "Road"}, + 'Passenger Ship': {'id': 0, 'category': 'Ship'}, + 'Motorboat': {'id': 1, 'category': 'Ship'}, + 'Fishing Boat': {'id': 2, 'category': 'Ship'}, + 'Tugboat': {'id': 3, 'category': 'Ship'}, + 'other-ship': {'id': 4, 'category': 'Ship'}, + 'Engineering Ship': {'id': 5, 'category': 'Ship'}, + 'Liquid Cargo Ship': {'id': 6, 'category': 'Ship'}, + 'Dry Cargo Ship': {'id': 7, 'category': 'Ship'}, + 'Warship': {'id': 8, 'category': 'Ship'}, + 'Small Car': {'id': 9, 'category': 'Vehicle'}, + 'Bus': {'id': 10, 'category': 'Vehicle'}, + 'Cargo Truck': {'id': 11, 'category': 'Vehicle'}, + 'Dump Truck': {'id': 12, 'category': 'Vehicle'}, + 'other-vehicle': {'id': 13, 'category': 'Vehicle'}, + 'Van': {'id': 14, 'category': 'Vehicle'}, + 'Trailer': {'id': 15, 'category': 'Vehicle'}, + 'Tractor': {'id': 16, 'category': 'Vehicle'}, + 'Excavator': {'id': 17, 'category': 'Vehicle'}, + 'Truck Tractor': {'id': 18, 'category': 'Vehicle'}, + 'Boeing737': {'id': 19, 'category': 'Airplane'}, + 'Boeing747': {'id': 20, 'category': 'Airplane'}, + 'Boeing777': {'id': 21, 'category': 'Airplane'}, + 'Boeing787': {'id': 22, 'category': 'Airplane'}, + 'ARJ21': {'id': 23, 'category': 'Airplane'}, + 'C919': {'id': 24, 'category': 'Airplane'}, + 'A220': {'id': 25, 'category': 'Airplane'}, + 'A321': {'id': 26, 'category': 'Airplane'}, + 'A330': {'id': 27, 'category': 'Airplane'}, + 'A350': {'id': 28, 'category': 'Airplane'}, + 'other-airplane': {'id': 29, 'category': 'Airplane'}, + 'Baseball Field': {'id': 30, 'category': 'Court'}, + 'Basketball Court': {'id': 31, 'category': 'Court'}, + 'Football Field': {'id': 32, 'category': 'Court'}, + 'Tennis Court': {'id': 33, 'category': 'Court'}, + 'Roundabout': {'id': 34, 'category': 'Road'}, + 'Intersection': {'id': 35, 'category': 'Road'}, + 'Bridge': {'id': 36, 'category': 'Road'}, } filename_glob = { - "train": os.path.join("train", "**", "images", "*.tif"), - "val": os.path.join("validation", "images", "*.tif"), - "test": os.path.join("test", "images", "*.tif"), + 'train': os.path.join('train', '**', 'images', '*.tif'), + 'val': os.path.join('validation', 'images', '*.tif'), + 'test': 
os.path.join('test', 'images', '*.tif'), } directories = { - "train": ( - os.path.join("train", "part1", "images"), - os.path.join("train", "part1", "labelXml"), - os.path.join("train", "part2", "images"), - os.path.join("train", "part2", "labelXml"), + 'train': ( + os.path.join('train', 'part1', 'images'), + os.path.join('train', 'part1', 'labelXml'), + os.path.join('train', 'part2', 'images'), + os.path.join('train', 'part2', 'labelXml'), ), - "val": ( - os.path.join("validation", "images"), - os.path.join("validation", "labelXml"), + 'val': ( + os.path.join('validation', 'images'), + os.path.join('validation', 'labelXml'), ), - "test": (os.path.join("test", "images")), + 'test': (os.path.join('test', 'images')), } paths = { - "train": ( - os.path.join("train", "part1", "images.zip"), - os.path.join("train", "part1", "labelXml.zip"), - os.path.join("train", "part2", "images.zip"), - os.path.join("train", "part2", "labelXmls.zip"), + 'train': ( + os.path.join('train', 'part1', 'images.zip'), + os.path.join('train', 'part1', 'labelXml.zip'), + os.path.join('train', 'part2', 'images.zip'), + os.path.join('train', 'part2', 'labelXmls.zip'), ), - "val": ( - os.path.join("validation", "images.zip"), - os.path.join("validation", "labelXmls.zip"), + 'val': ( + os.path.join('validation', 'images.zip'), + os.path.join('validation', 'labelXmls.zip'), ), - "test": ( - os.path.join("test", "images0.zip"), - os.path.join("test", "images1.zip"), - os.path.join("test", "images2.zip"), + 'test': ( + os.path.join('test', 'images0.zip'), + os.path.join('test', 'images1.zip'), + os.path.join('test', 'images2.zip'), ), } urls = { - "train": ( - "https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf", - "https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u", - "https://drive.google.com/file/d/1cx4MRfpmh68SnGAYetNlDy68w0NgKucJ", - "https://drive.google.com/file/d/1RFVjadTHA_bsB7BJwSZoQbiyM7KIDEUI", + 'train': ( + 'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf', + 'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u', + 'https://drive.google.com/file/d/1cx4MRfpmh68SnGAYetNlDy68w0NgKucJ', + 'https://drive.google.com/file/d/1RFVjadTHA_bsB7BJwSZoQbiyM7KIDEUI', ), - "val": ( - "https://drive.google.com/file/d/1lSSHOD02B6_sUmr2b-R1iqhgWRQRw-S9", - "https://drive.google.com/file/d/1sTTna1C5n3Senpfo-73PdiNilnja1AV4", + 'val': ( + 'https://drive.google.com/file/d/1lSSHOD02B6_sUmr2b-R1iqhgWRQRw-S9', + 'https://drive.google.com/file/d/1sTTna1C5n3Senpfo-73PdiNilnja1AV4', ), - "test": ( - "https://drive.google.com/file/d/1HtOOVfK9qetDBjE7MM0dK_u5u7n4gdw3", - "https://drive.google.com/file/d/1iXKCPmmJtRYcyuWCQC35bk97NmyAsasq", - "https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0", + 'test': ( + 'https://drive.google.com/file/d/1HtOOVfK9qetDBjE7MM0dK_u5u7n4gdw3', + 'https://drive.google.com/file/d/1iXKCPmmJtRYcyuWCQC35bk97NmyAsasq', + 'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0', ), } md5s = { - "train": ( - "a460fe6b1b5b276bf856ce9ac72d6568", - "80f833ff355f91445c92a0c0c1fa7414", - "ad237e61dba304fcef23cd14aa6c4280", - "5c5948e68cd0f991a0d73f10956a3b05", + 'train': ( + 'a460fe6b1b5b276bf856ce9ac72d6568', + '80f833ff355f91445c92a0c0c1fa7414', + 'ad237e61dba304fcef23cd14aa6c4280', + '5c5948e68cd0f991a0d73f10956a3b05', ), - "val": ("dce782be65405aa381821b5f4d9eac94", "700b516a21edc9eae66ca315b72a09a1"), - "test": ( - "fb8ccb274f3075d50ac9f7803fbafd3d", - "dc9bbbdee000e97f02276aa61b03e585", - "700b516a21edc9eae66ca315b72a09a1", + 
'val': ('dce782be65405aa381821b5f4d9eac94', '700b516a21edc9eae66ca315b72a09a1'), + 'test': ( + 'fb8ccb274f3075d50ac9f7803fbafd3d', + 'dc9bbbdee000e97f02276aa61b03e585', + '700b516a21edc9eae66ca315b72a09a1', ), } - image_root: str = "images" - label_root: str = "labelXml" + image_root: str = 'images' + label_root: str = 'labelXml' def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -275,14 +275,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: path = self.files[index] image = self._load_image(path) - sample = {"image": image} + sample = {'image': image} - if self.split != "test": + if self.split != 'test': label_path = path.replace(self.image_root, self.label_root) - label_path = label_path.replace(".tif", ".xml") + label_path = label_path.replace('.tif', '.xml') voc = parse_pascal_voc(label_path) - boxes, labels = self._load_target(voc["points"], voc["labels"]) - sample = {"image": image, "boxes": boxes, "label": labels} + boxes, labels = self._load_target(voc['points'], voc['labels']) + sample = {'image': image, 'boxes': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) @@ -307,7 +307,7 @@ def _load_image(self, path: str) -> Tensor: the image """ with Image.open(path) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -325,7 +325,7 @@ def _load_target( Returns: the target bounding boxes and labels """ - labels_list = [self.classes[label]["id"] for label in labels] + labels_list = [self.classes[label]['id'] for label in labels] boxes = torch.tensor(points).to(torch.float) labels_tensor = torch.tensor(labels_list) return boxes, labels_tensor @@ -347,7 +347,7 @@ def _verify(self) -> None: filepath = os.path.join(self.root, path) if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') exists.append(True) extract_archive(filepath) else: @@ -397,10 +397,10 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = sample["image"].permute((1, 2, 0)).numpy() + image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - if "prediction_boxes" in sample: + if 'prediction_boxes' in sample: ncols += 1 fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) @@ -408,31 +408,31 @@ def plot( axs = [axs] axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') - if "boxes" in sample: + if 'boxes' in sample: polygons = [ - patches.Polygon(points, color="r", fill=False) - for points in sample["boxes"].numpy() + patches.Polygon(points, color='r', fill=False) + for points in sample['boxes'].numpy() ] for polygon in polygons: axs[0].add_patch(polygon) if show_titles: - axs[0].set_title("Ground Truth") + axs[0].set_title('Ground Truth') if ncols > 1: axs[1].imshow(image) - axs[1].axis("off") + axs[1].axis('off') polygons = [ - patches.Polygon(points, color="r", fill=False) - for points in sample["prediction_boxes"].numpy() + patches.Polygon(points, color='r', fill=False) + for points in sample['prediction_boxes'].numpy() ] for polygon in polygons: axs[0].add_patch(polygon) if show_titles: - 
axs[1].set_title("Predictions") + axs[1].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py index e10bbae8aad..40c64ed6f71 100644 --- a/torchgeo/datasets/fire_risk.py +++ b/torchgeo/datasets/fire_risk.py @@ -50,25 +50,25 @@ class FireRisk(NonGeoClassificationDataset): .. versionadded:: 0.5 """ - url = "https://drive.google.com/file/d/1J5GrJJPLWkpuptfY_kgqkiDtcSNP88OP" - md5 = "a77b9a100d51167992ae8c51d26198a6" - filename = "FireRisk.zip" - directory = "FireRisk" - splits = ["train", "val"] + url = 'https://drive.google.com/file/d/1J5GrJJPLWkpuptfY_kgqkiDtcSNP88OP' + md5 = 'a77b9a100d51167992ae8c51d26198a6' + filename = 'FireRisk.zip' + directory = 'FireRisk' + splits = ['train', 'val'] classes = [ - "High", - "Low", - "Moderate", - "Non-burnable", - "Very_High", - "Very_Low", - "Water", + 'High', + 'Low', + 'Moderate', + 'Non-burnable', + 'Very_High', + 'Very_Low', + 'Water', ] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -149,22 +149,22 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = sample["image"].permute((1, 2, 0)).numpy() - label = cast(int, sample["label"].item()) + image = sample['image'].permute((1, 2, 0)).numpy() + label = cast(int, sample['label'].item()) label_class = self.classes[label] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(int, sample["prediction"].item()) + prediction = cast(int, sample['prediction'].item()) prediction_class = self.classes[prediction] fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label_class}" + title = f'Label: {label_class}' if showing_predictions: - title += f"\nPrediction: {prediction_class}" + title += f'\nPrediction: {prediction_class}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index b20ab52963f..e25edf01d7d 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -37,22 +37,22 @@ def parse_pascal_voc(path: str) -> dict[str, Any]: """ et = ElementTree.parse(path) element = et.getroot() - filename = element.find("filename").text # type: ignore[union-attr] + filename = element.find('filename').text # type: ignore[union-attr] labels, bboxes = [], [] - for obj in element.findall("object"): - bndbox = obj.find("bndbox") + for obj in element.findall('object'): + bndbox = obj.find('bndbox') bbox = [ - int(bndbox.find("xmin").text), # type: ignore[union-attr, arg-type] - int(bndbox.find("ymin").text), # type: ignore[union-attr, arg-type] - int(bndbox.find("xmax").text), # type: ignore[union-attr, arg-type] - int(bndbox.find("ymax").text), # type: ignore[union-attr, arg-type] + int(bndbox.find('xmin').text), # type: ignore[union-attr, arg-type] + int(bndbox.find('ymin').text), # type: ignore[union-attr, arg-type] + int(bndbox.find('xmax').text), # type: ignore[union-attr, arg-type] + int(bndbox.find('ymax').text), # type: ignore[union-attr, arg-type] ] - label_var = obj.find("damage") + label_var = obj.find('damage') if label_var is not None: label = label_var.text else: - label = "other" + label = 'other' bboxes.append(bbox) labels.append(label) 
return dict(filename=filename, bboxes=bboxes, labels=labels) @@ -100,17 +100,17 @@ class ForestDamage(NonGeoDataset): .. versionadded:: 0.3 """ - classes = ["other", "H", "LD", "HD"] + classes = ['other', 'H', 'LD', 'HD'] url = ( - "https://lilablobssc.blob.core.windows.net/larch-casebearer/" - "Data_Set_Larch_Casebearer.zip" + 'https://lilablobssc.blob.core.windows.net/larch-casebearer/' + 'Data_Set_Larch_Casebearer.zip' ) - data_dir = "Data_Set_Larch_Casebearer" - md5 = "907815bcc739bff89496fac8f8ce63d7" + data_dir = 'Data_Set_Larch_Casebearer' + md5 = '907815bcc739bff89496fac8f8ce63d7' def __init__( self, - root: str = "data", + root: str = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -148,12 +148,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - parsed = parse_pascal_voc(files["annotation"]) - image = self._load_image(files["image"]) + parsed = parse_pascal_voc(files['annotation']) + image = self._load_image(files['image']) - boxes, labels = self._load_target(parsed["bboxes"], parsed["labels"]) + boxes, labels = self._load_target(parsed['bboxes'], parsed['labels']) - sample = {"image": image, "boxes": boxes, "label": labels} + sample = {'image': image, 'boxes': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) @@ -178,10 +178,10 @@ def _load_files(self, root: str) -> list[dict[str, str]]: list of dicts containing paths for each pair of image, annotation """ images = sorted( - glob.glob(os.path.join(root, self.data_dir, "**", "Images", "*.JPG")) + glob.glob(os.path.join(root, self.data_dir, '**', 'Images', '*.JPG')) ) annotations = sorted( - glob.glob(os.path.join(root, self.data_dir, "**", "Annotations", "*.xml")) + glob.glob(os.path.join(root, self.data_dir, '**', 'Annotations', '*.xml')) ) files = [ @@ -201,7 +201,7 @@ def _load_image(self, path: str) -> Tensor: the image """ with Image.open(path) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor: Tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -233,10 +233,10 @@ def _verify(self) -> None: if os.path.isdir(filepath): return - filepath = os.path.join(self.root, self.data_dir + ".zip") + filepath = os.path.join(self.root, self.data_dir + '.zip') if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, self.md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(filepath) return @@ -252,7 +252,7 @@ def _download(self) -> None: download_and_extract_archive( self.url, self.root, - filename=self.data_dir + ".zip", + filename=self.data_dir + '.zip', md5=self.md5 if self.checksum else None, ) @@ -272,10 +272,10 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = sample["image"].permute((1, 2, 0)).numpy() + image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - showing_predictions = "prediction_boxes" in sample + showing_predictions = 'prediction_boxes' in sample if showing_predictions: ncols += 1 @@ -284,7 +284,7 @@ def plot( axs = [axs] axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') bboxes = [ patches.Rectangle( @@ -292,20 +292,20 @@ def plot( bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=1, - edgecolor="r", - facecolor="none", + 
edgecolor='r', + facecolor='none', ) - for bbox in sample["boxes"].numpy() + for bbox in sample['boxes'].numpy() ] for bbox in bboxes: axs[0].add_patch(bbox) if show_titles: - axs[0].set_title("Ground Truth") + axs[0].set_title('Ground Truth') if showing_predictions: axs[1].imshow(image) - axs[1].axis("off") + axs[1].axis('off') pred_bboxes = [ patches.Rectangle( @@ -313,16 +313,16 @@ def plot( bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=1, - edgecolor="r", - facecolor="none", + edgecolor='r', + facecolor='none', ) - for bbox in sample["prediction_boxes"].numpy() + for bbox in sample['prediction_boxes'].numpy() ] for bbox in pred_bboxes: axs[1].add_patch(bbox) if show_titles: - axs[1].set_title("Predictions") + axs[1].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py index a34cc8ba685..fc6a4edd1c9 100644 --- a/torchgeo/datasets/gbif.py +++ b/torchgeo/datasets/gbif.py @@ -79,7 +79,7 @@ class GBIF(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = "data") -> None: + def __init__(self, root: str = 'data') -> None: """Initialize a new Dataset instance. Args: @@ -92,15 +92,15 @@ def __init__(self, root: str = "data") -> None: self.root = root - files = glob.glob(os.path.join(root, "**.csv")) + files = glob.glob(os.path.join(root, '**.csv')) if not files: raise DatasetNotFoundError(self) # Read tab-delimited CSV file data = pd.read_table( files[0], - engine="c", - usecols=["decimalLatitude", "decimalLongitude", "day", "month", "year"], + engine='c', + usecols=['decimalLatitude', 'decimalLongitude', 'day', 'month', 'year'], ) # Convert from pandas DataFrame to rtree Index @@ -133,9 +133,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not bboxes: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {"crs": self.crs, "bbox": bboxes} + sample = {'crs': self.crs, 'bbox': bboxes} return sample diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 590f6f7dcf0..29b89068554 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -92,7 +92,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): #: This expression should be specific enough that it will not pick up files from #: other datasets. It should not include a file extension, as the dataset may be in #: a different file format than what it was originally downloaded as. - filename_glob = "*" + filename_glob = '*' # NOTE: according to the Python docs: # @@ -135,7 +135,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ - def __and__(self, other: "GeoDataset") -> "IntersectionDataset": + def __and__(self, other: 'GeoDataset') -> 'IntersectionDataset': """Take the intersection of two :class:`GeoDataset`. Args: @@ -151,7 +151,7 @@ def __and__(self, other: "GeoDataset") -> "IntersectionDataset": """ return IntersectionDataset(self, other) - def __or__(self, other: "GeoDataset") -> "UnionDataset": + def __or__(self, other: 'GeoDataset') -> 'UnionDataset': """Take the union of two GeoDatasets. 
Args: @@ -247,7 +247,7 @@ def crs(self, new_crs: CRS) -> None: if new_crs == self.crs: return - print(f"Converting {self.__class__.__name__} CRS from {self.crs} to {new_crs}") + print(f'Converting {self.__class__.__name__} CRS from {self.crs} to {new_crs}') new_index = Index(interleaved=False, properties=Property(dimension=3)) project = pyproj.Transformer.from_crs( @@ -283,7 +283,7 @@ def res(self, new_res: float) -> None: if new_res == self.res: return - print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}") + print(f'Converting {self.__class__.__name__} res from {self.res} to {new_res}') self._res = new_res @property @@ -305,14 +305,14 @@ def files(self) -> list[str]: files: set[str] = set() for path in paths: if os.path.isdir(path): - pathname = os.path.join(path, "**", self.filename_glob) + pathname = os.path.join(path, '**', self.filename_glob) files |= set(glob.iglob(pathname, recursive=True)) elif os.path.isfile(path) or path_is_vsi(path): files.add(path) else: warnings.warn( f"Could not find any relevant files for provided path '{path}'. " - f"Path was ignored.", + f'Path was ignored.', UserWarning, ) @@ -336,13 +336,13 @@ class RasterDataset(GeoDataset): #: groups are searched for to find other files: #: #: * ``band``: replaced with requested band name - filename_regex = ".*" + filename_regex = '.*' #: Date format string used to parse date from filename. #: #: Not used if :attr:`filename_regex` does not contain a ``date`` group or #: ``start`` and ``stop`` groups. - date_format = "%Y%m%d" + date_format = '%Y%m%d' #: True if the dataset only contains model inputs (such as images). False if the #: dataset only contains ground truth model outputs (such as segmentation masks). @@ -386,7 +386,7 @@ def dtype(self) -> torch.dtype: def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, @@ -446,12 +446,12 @@ def __init__( else: mint: float = 0 maxt: float = sys.maxsize - if "date" in match.groupdict(): - date = match.group("date") + if 'date' in match.groupdict(): + date = match.group('date') mint, maxt = disambiguate_timestamp(date, self.date_format) - elif "start" in match.groupdict() and "stop" in match.groupdict(): - start = match.group("start") - stop = match.group("stop") + elif 'start' in match.groupdict() and 'stop' in match.groupdict(): + start = match.group('start') + stop = match.group('stop') mint, _ = disambiguate_timestamp(start, self.date_format) _, maxt = disambiguate_timestamp(stop, self.date_format) @@ -471,8 +471,8 @@ def __init__( ] else: msg = ( - f"{self.__class__.__name__} is missing an `all_bands` " - "attribute, so `bands` cannot be specified." + f'{self.__class__.__name__} is missing an `all_bands` ' + 'attribute, so `bands` cannot be specified.' 
) raise AssertionError(msg) @@ -496,7 +496,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) if self.separate_files: @@ -509,9 +509,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: directory = os.path.dirname(filepath) match = re.match(filename_regex, filename) if match: - if "band" in match.groupdict(): - start = match.start("band") - end = match.end("band") + if 'band' in match.groupdict(): + start = match.start('band') + end = match.end('band') filename = filename[:start] + band + filename[end:] filepath = os.path.join(directory, filename) band_filepaths.append(filepath) @@ -520,13 +520,13 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: else: data = self._merge_files(filepaths, query, self.band_indexes) - sample = {"crs": self.crs, "bbox": query} + sample = {'crs': self.crs, 'bbox': query} data = data.to(self.dtype) if self.is_image: - sample["image"] = data + sample['image'] = data else: - sample["mask"] = data + sample['mask'] = data if self.transforms is not None: sample = self.transforms(sample) @@ -601,12 +601,12 @@ class VectorDataset(GeoDataset): #: groups. The following groups are specifically searched for by the base class: #: #: * ``date``: used to calculate ``mint`` and ``maxt`` for ``index`` insertion - filename_regex = ".*" + filename_regex = '.*' #: Date format string used to parse date from filename. #: #: Not used if :attr:`filename_regex` does not contain a ``date`` group. - date_format = "%Y%m%d" + date_format = '%Y%m%d' @property def dtype(self) -> torch.dtype: @@ -623,7 +623,7 @@ def dtype(self) -> torch.dtype: def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float = 0.0001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -676,8 +676,8 @@ def __init__( else: mint: float = 0 maxt: float = sys.maxsize - if "date" in match.groupdict(): - date = match.group("date") + if 'date' in match.groupdict(): + date = match.group('date') mint, maxt = disambiguate_timestamp(date, self.date_format) coords = (minx, maxx, miny, maxy, mint, maxt) self.index.insert(i, coords, filepath) @@ -706,7 +706,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) shapes = [] @@ -724,7 +724,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: for feature in src.filter(bbox=(minx, miny, maxx, maxy)): # Warp geometries to requested CRS shape = fiona.transform.transform_geom( - src.crs, self.crs.to_dict(), feature["geometry"] + src.crs, self.crs.to_dict(), feature['geometry'] ) label = self.get_label(feature) shapes.append((shape, label)) @@ -748,14 +748,14 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: masks = array_to_tensor(masks) masks = masks.to(self.dtype) - sample = {"mask": masks, "crs": self.crs, "bbox": query} + sample = {'mask': masks, 'crs': self.crs, 'bbox': query} if self.transforms is not None: sample = self.transforms(sample) return sample - def get_label(self, feature: "fiona.model.Feature") -> int: + def get_label(self, feature: 'fiona.model.Feature') -> int: """Get label value to use for rendering a feature. 
Args: @@ -767,7 +767,7 @@ def get_label(self, feature: "fiona.model.Feature") -> int: .. versionadded:: 0.6 """ if self.label_name: - return int(feature["properties"][self.label_name]) + return int(feature['properties'][self.label_name]) return 1 @@ -820,7 +820,7 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m def __init__( self, - root: str = "data", + root: str = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, loader: Callable[[str], Any] | None = pil_loader, is_valid_file: Callable[[str], bool] | None = None, @@ -859,7 +859,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ image, label = self._load_image(index) - sample = {"image": image, "label": label} + sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -884,7 +884,7 @@ def _load_image(self, index: int) -> tuple[Tensor, Tensor]: the image and class label """ img, label = ImageFolder.__getitem__(self, index) - array: "np.typing.NDArray[np.int_]" = np.array(img) + array: 'np.typing.NDArray[np.int_]' = np.array(img) tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -950,7 +950,7 @@ def __init__( for ds in self.datasets: if not isinstance(ds, GeoDataset): - raise ValueError("IntersectionDataset only supports GeoDatasets") + raise ValueError('IntersectionDataset only supports GeoDatasets') self.crs = dataset1.crs self.res = dataset1.res @@ -970,7 +970,7 @@ def _merge_dataset_indices(self) -> None: i += 1 if i == 0: - raise RuntimeError("Datasets have no spatiotemporal intersection") + raise RuntimeError('Datasets have no spatiotemporal intersection') def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image and metadata indexed by query. 
@@ -986,7 +986,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """ if not query.intersects(self.bounds): raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) # All datasets are guaranteed to have a valid query @@ -1107,7 +1107,7 @@ def __init__( for ds in self.datasets: if not isinstance(ds, GeoDataset): - raise ValueError("UnionDataset only supports GeoDatasets") + raise ValueError('UnionDataset only supports GeoDatasets') self.crs = dataset1.crs self.res = dataset1.res @@ -1138,7 +1138,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """ if not query.intersects(self.bounds): raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) # Not all datasets are guaranteed to have a valid query diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index fb7dd0459e4..cbb4b09cad3 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -61,34 +61,34 @@ class GID15(NonGeoDataset): * https://doi.org/10.1016/j.rse.2019.111322 """ - url = "https://drive.google.com/file/d/1zbkCEXPEKEV6gq19OKmIbaT8bXXfWW6u" - md5 = "615682bf659c3ed981826c6122c10c83" - filename = "gid-15.zip" - directory = "GID" - splits = ["train", "val", "test"] + url = 'https://drive.google.com/file/d/1zbkCEXPEKEV6gq19OKmIbaT8bXXfWW6u' + md5 = '615682bf659c3ed981826c6122c10c83' + filename = 'gid-15.zip' + directory = 'GID' + splits = ['train', 'val', 'test'] classes = [ - "background", - "industrial_land", - "urban_residential", - "rural_residential", - "traffic_land", - "paddy_field", - "irrigated_land", - "dry_cropland", - "garden_plot", - "arbor_woodland", - "shrub_land", - "natural_grassland", - "artificial_grassland", - "river", - "lake", - "pond", + 'background', + 'industrial_land', + 'urban_residential', + 'rural_residential', + 'traffic_land', + 'paddy_field', + 'irrigated_land', + 'dry_cropland', + 'garden_plot', + 'arbor_woodland', + 'shrub_land', + 'natural_grassland', + 'artificial_grassland', + 'river', + 'lake', + 'pond', ] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -132,13 +132,13 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image = self._load_image(files["image"]) + image = self._load_image(files['image']) - if self.split != "test": - mask = self._load_target(files["mask"]) - sample = {"image": image, "mask": mask} + if self.split != 'test': + mask = self._load_target(files['mask']) + sample = {'image': image, 'mask': mask} else: - sample = {"image": image} + sample = {'image': image} if self.transforms is not None: sample = self.transforms(sample) @@ -163,12 +163,12 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: Returns: list of dicts containing paths for each pair of image, mask """ - image_root = os.path.join(root, "GID", "img_dir") - images = glob.glob(os.path.join(image_root, split, "*.tif")) + image_root = os.path.join(root, 'GID', 'img_dir') + images = glob.glob(os.path.join(image_root, split, '*.tif')) images = sorted(images) - if split != "test": + if split != 'test': masks = [ - image.replace("img_dir", "ann_dir").replace(".tif", "_15label.png") + 
image.replace('img_dir', 'ann_dir').replace('.tif', '_15label.png') for image in images ] files = [dict(image=image, mask=mask) for image, mask in zip(images, masks)] @@ -188,7 +188,7 @@ def _load_image(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)).float() @@ -205,7 +205,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) return tensor @@ -224,7 +224,7 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return download_and_extract_archive( @@ -246,36 +246,36 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure .. versionadded:: 0.2 """ - if self.split != "test": - image, mask = sample["image"], sample["mask"] + if self.split != 'test': + image, mask = sample['image'], sample['mask'] ncols = 2 else: - image = sample["image"] + image = sample['image'] ncols = 1 - if "prediction" in sample: + if 'prediction' in sample: ncols += 1 - pred = sample["prediction"] + pred = sample['prediction'] fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10)) - if self.split != "test": + if self.split != 'test': axs[0].imshow(image.permute(1, 2, 0)) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask) - axs[1].axis("off") - if "prediction" in sample: + axs[1].axis('off') + if 'prediction' in sample: axs[2].imshow(pred) - axs[2].axis("off") + axs[2].axis('off') else: - if "prediction" in sample: + if 'prediction' in sample: axs[0].imshow(image.permute(1, 2, 0)) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(pred) - axs[1].axis("off") + axs[1].axis('off') else: axs.imshow(image.permute(1, 2, 0)) - axs.axis("off") + axs.axis('off') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 96604b523b4..88db2a6277e 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -54,75 +54,75 @@ class GlobBiomass(RasterDataset): _(?P[a-z]{3}) """ - measurements = ["agb", "gsv"] + measurements = ['agb', 'gsv'] md5s = { - "N00E020_agb.zip": "bd83a3a4c143885d1962bde549413be6", - "N00E020_gsv.zip": "da5ddb88e369df2d781a0c6be008ae79", - "N00E060_agb.zip": "85eaca95b939086cc528e396b75bd097", - "N00E060_gsv.zip": "ec84174697c17ca4db2967374446ab30", - "N00E100_agb.zip": "c50c7c996615c1c6f19cb383ef11812a", - "N00E100_gsv.zip": "6e0ff834db822d3710ed40d00a200e8f", - "N00E140_agb.zip": "73f0b44b9e137789cefb711ef9aa281b", - "N00E140_gsv.zip": "43be3dd4563b63d12de006d240ba5edf", - "N00W020_agb.zip": "4fb979732f0a22cc7a2ca3667698084b", - "N00W020_gsv.zip": "ac5bbeedaa0f94a5e01c7a86751d6891", - "N00W060_agb.zip": "59da0b32b08fbbcd2dd76926a849562b", - "N00W060_gsv.zip": "5ca9598f621a7d10ab1d623ee5b44aa6", - "N00W100_agb.zip": "a819b75a39e8d4d37b15745c96ea1e35", - "N00W100_gsv.zip": "71aad3669d522f7190029ec33350831a", - "N00W180_agb.zip": 
"5a1d7486d8310fbaf4980a76e9ffcd78", - "N00W180_gsv.zip": "274be7dbb4e6d7563773cc302129a9c7", - "N40E020_agb.zip": "38bc7170f94734b365d614a566f872e7", - "N40E020_gsv.zip": "b52c1c777d68c331cc058a273530536e", - "N40E060_agb.zip": "1d94ad59f3f26664fefa4d7308b63f05", - "N40E060_gsv.zip": "3b68786b7641400077ef340a7ef748f4", - "N40E100_agb.zip": "3ccb436047c0db416fb237435645989c", - "N40E100_gsv.zip": "c44efe9e7ce2ae0f2e39b0db10f06c71", - "N40E140_agb.zip": "35ea51da229af1312ba4aaafc0dbd5d6", - "N40E140_gsv.zip": "8431828708c84263a4971a8779864f69", - "N40W020_agb.zip": "38345a1826719301ab1a0251b4835cc2", - "N40W020_gsv.zip": "5e136b7c2f921cd425cb5cc5669e7693", - "N40W060_agb.zip": "e3f54df1d188c0132ecf5aef3dc54ca6", - "N40W060_gsv.zip": "09093d78ffef0220cb459a88e61e3093", - "N40W100_agb.zip": "cc21ce8793e5594dc7a0b45f0d0f1466", - "N40W100_gsv.zip": "21be1398df88818d04dcce422e2010a6", - "N40W140_agb.zip": "64665f53fad7386abb1cf4a44a1c8b1a", - "N40W140_gsv.zip": "b59405219fc807cbe745789fbb6936a6", - "N40W180_agb.zip": "f83ef786da8333ee739e49be108994c1", - "N40W180_gsv.zip": "1f2eb8912b1a204eaeb2858b7e398baa", - "N80E020_agb.zip": "7f7aed44802890672bd908e28eda6f15", - "N80E020_gsv.zip": "6e285eec66306e56dc3a81adc0da2a27", - "N80E060_agb.zip": "55e7031e0207888f25f27efa9a0ab8f4", - "N80E060_gsv.zip": "8d14c7f61ad2aed527e124f9aacae30c", - "N80E100_agb.zip": "562eafd2813ff06e47284c48324bb1c7", - "N80E100_gsv.zip": "73067e0fac442c330ae2294996280042", - "N80E140_agb.zip": "1b51ce0df0dba925c5ef2149bebca115", - "N80E140_gsv.zip": "37ee3047d281fc34fa3a9e024a8317a1", - "N80W020_agb.zip": "60dde6adc0dfa219a34c976367f571c0", - "N80W020_gsv.zip": "b7be4e97bb4179710291ee8dee27f538", - "N80W060_agb.zip": "db7d35d0375851c4a181c3a8fa8b480e", - "N80W060_gsv.zip": "d36ffcf4622348382454c979baf53234", - "N80W100_agb.zip": "c0dbf53e635dabf9a4d7d1756adeda69", - "N80W100_gsv.zip": "abdeaf0d65da1216c326b6d0ce27d61b", - "N80W140_agb.zip": "7719c0efd23cd86215fea0285fd0ea4a", - "N80W140_gsv.zip": "499969bed381197ee9427a2e3f455a2e", - "N80W180_agb.zip": "e3a163d1944e1989a07225d262a01c6f", - "N80W180_gsv.zip": "5d39ec0368cfe63c40c66d61ae07f577", - "S40E140_agb.zip": "263eb077a984117b41cc7cfa0c32915b", - "S40E140_gsv.zip": "e0ffad85fbade4fb711cc5b3c7543898", - "S40W060_agb.zip": "2cbf6858c48f36add896db660826829b", - "S40W060_gsv.zip": "04dbfd4aca0bd2a2a7d8f563c8659252", - "S40W100_agb.zip": "ae89f021e7d9c2afea433878f77d1dd6", - "S40W100_gsv.zip": "b6aa3f276e1b51dade803a71df2acde6", + 'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6', + 'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79', + 'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097', + 'N00E060_gsv.zip': 'ec84174697c17ca4db2967374446ab30', + 'N00E100_agb.zip': 'c50c7c996615c1c6f19cb383ef11812a', + 'N00E100_gsv.zip': '6e0ff834db822d3710ed40d00a200e8f', + 'N00E140_agb.zip': '73f0b44b9e137789cefb711ef9aa281b', + 'N00E140_gsv.zip': '43be3dd4563b63d12de006d240ba5edf', + 'N00W020_agb.zip': '4fb979732f0a22cc7a2ca3667698084b', + 'N00W020_gsv.zip': 'ac5bbeedaa0f94a5e01c7a86751d6891', + 'N00W060_agb.zip': '59da0b32b08fbbcd2dd76926a849562b', + 'N00W060_gsv.zip': '5ca9598f621a7d10ab1d623ee5b44aa6', + 'N00W100_agb.zip': 'a819b75a39e8d4d37b15745c96ea1e35', + 'N00W100_gsv.zip': '71aad3669d522f7190029ec33350831a', + 'N00W180_agb.zip': '5a1d7486d8310fbaf4980a76e9ffcd78', + 'N00W180_gsv.zip': '274be7dbb4e6d7563773cc302129a9c7', + 'N40E020_agb.zip': '38bc7170f94734b365d614a566f872e7', + 'N40E020_gsv.zip': 'b52c1c777d68c331cc058a273530536e', + 'N40E060_agb.zip': 
'1d94ad59f3f26664fefa4d7308b63f05', + 'N40E060_gsv.zip': '3b68786b7641400077ef340a7ef748f4', + 'N40E100_agb.zip': '3ccb436047c0db416fb237435645989c', + 'N40E100_gsv.zip': 'c44efe9e7ce2ae0f2e39b0db10f06c71', + 'N40E140_agb.zip': '35ea51da229af1312ba4aaafc0dbd5d6', + 'N40E140_gsv.zip': '8431828708c84263a4971a8779864f69', + 'N40W020_agb.zip': '38345a1826719301ab1a0251b4835cc2', + 'N40W020_gsv.zip': '5e136b7c2f921cd425cb5cc5669e7693', + 'N40W060_agb.zip': 'e3f54df1d188c0132ecf5aef3dc54ca6', + 'N40W060_gsv.zip': '09093d78ffef0220cb459a88e61e3093', + 'N40W100_agb.zip': 'cc21ce8793e5594dc7a0b45f0d0f1466', + 'N40W100_gsv.zip': '21be1398df88818d04dcce422e2010a6', + 'N40W140_agb.zip': '64665f53fad7386abb1cf4a44a1c8b1a', + 'N40W140_gsv.zip': 'b59405219fc807cbe745789fbb6936a6', + 'N40W180_agb.zip': 'f83ef786da8333ee739e49be108994c1', + 'N40W180_gsv.zip': '1f2eb8912b1a204eaeb2858b7e398baa', + 'N80E020_agb.zip': '7f7aed44802890672bd908e28eda6f15', + 'N80E020_gsv.zip': '6e285eec66306e56dc3a81adc0da2a27', + 'N80E060_agb.zip': '55e7031e0207888f25f27efa9a0ab8f4', + 'N80E060_gsv.zip': '8d14c7f61ad2aed527e124f9aacae30c', + 'N80E100_agb.zip': '562eafd2813ff06e47284c48324bb1c7', + 'N80E100_gsv.zip': '73067e0fac442c330ae2294996280042', + 'N80E140_agb.zip': '1b51ce0df0dba925c5ef2149bebca115', + 'N80E140_gsv.zip': '37ee3047d281fc34fa3a9e024a8317a1', + 'N80W020_agb.zip': '60dde6adc0dfa219a34c976367f571c0', + 'N80W020_gsv.zip': 'b7be4e97bb4179710291ee8dee27f538', + 'N80W060_agb.zip': 'db7d35d0375851c4a181c3a8fa8b480e', + 'N80W060_gsv.zip': 'd36ffcf4622348382454c979baf53234', + 'N80W100_agb.zip': 'c0dbf53e635dabf9a4d7d1756adeda69', + 'N80W100_gsv.zip': 'abdeaf0d65da1216c326b6d0ce27d61b', + 'N80W140_agb.zip': '7719c0efd23cd86215fea0285fd0ea4a', + 'N80W140_gsv.zip': '499969bed381197ee9427a2e3f455a2e', + 'N80W180_agb.zip': 'e3a163d1944e1989a07225d262a01c6f', + 'N80W180_gsv.zip': '5d39ec0368cfe63c40c66d61ae07f577', + 'S40E140_agb.zip': '263eb077a984117b41cc7cfa0c32915b', + 'S40E140_gsv.zip': 'e0ffad85fbade4fb711cc5b3c7543898', + 'S40W060_agb.zip': '2cbf6858c48f36add896db660826829b', + 'S40W060_gsv.zip': '04dbfd4aca0bd2a2a7d8f563c8659252', + 'S40W100_agb.zip': 'ae89f021e7d9c2afea433878f77d1dd6', + 'S40W100_gsv.zip': 'b6aa3f276e1b51dade803a71df2acde6', } def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, - measurement: str = "agb", + measurement: str = 'agb', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, checksum: bool = False, @@ -151,14 +151,14 @@ def __init__( self.paths = paths self.checksum = checksum - assert isinstance(measurement, str), "Measurement argument must be a str." + assert isinstance(measurement, str), 'Measurement argument must be a str.' assert ( measurement in self.measurements - ), f"You have entered an invalid measurement, please choose one of {self.measurements}." + ), f'You have entered an invalid measurement, please choose one of {self.measurements}.' 
self.measurement = measurement - self.filename_glob = f"*0_{self.measurement}*.tif" - self.zipfile_glob = f"*0_{self.measurement}.zip" + self.filename_glob = f'*0_{self.measurement}*.tif' + self.zipfile_glob = f'*0_{self.measurement}.zip' self._verify() @@ -182,18 +182,18 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) - measurement_paths = [f for f in filepaths if "err" not in f] + measurement_paths = [f for f in filepaths if 'err' not in f] mask = self._merge_files(measurement_paths, query) - std_error_paths = [f for f in filepaths if "err" in f] + std_error_paths = [f for f in filepaths if 'err' in f] std_err_mask = self._merge_files(std_error_paths, query) mask = torch.cat((mask, std_err_mask), dim=0) - sample = {"mask": mask, "crs": self.crs, "bbox": query} + sample = {'mask': mask, 'crs': self.crs, 'bbox': query} if self.transforms is not None: sample = self.transforms(sample) @@ -213,7 +213,7 @@ def _verify(self) -> None: for zipfile in glob.iglob(pathname): filename = os.path.basename(zipfile) if self.checksum and not check_integrity(zipfile, self.md5s[filename]): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(zipfile) return @@ -235,13 +235,13 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - tensor = sample["mask"] + tensor = sample['mask'] mask = tensor[0, ...] error_mask = tensor[1, ...] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"][0, ...] + pred = sample['prediction'][0, ...] ncols = 3 else: ncols = 2 @@ -250,23 +250,23 @@ def plot( if showing_predictions: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(error_mask) - axs[1].axis("off") + axs[1].axis('off') axs[2].imshow(pred) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Uncertainty Mask") - axs[2].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Uncertainty Mask') + axs[2].set_title('Prediction') else: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(error_mask) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Uncertainty Mask") + axs[0].set_title('Mask') + axs[1].set_title('Uncertainty Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 4fdce326132..c17df135759 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -95,60 +95,60 @@ class IDTReeS(NonGeoDataset): """ classes = { - "ACPE": "Acer pensylvanicum L.", - "ACRU": "Acer rubrum L.", - "ACSA3": "Acer saccharum Marshall", - "AMLA": "Amelanchier laevis Wiegand", - "BETUL": "Betula sp.", - "CAGL8": "Carya glabra (Mill.) Sweet", - "CATO6": "Carya tomentosa (Lam.) Nutt.", - "FAGR": "Fagus grandifolia Ehrh.", - "GOLA": "Gordonia lasianthus (L.) Ellis", - "LITU": "Liriodendron tulipifera L.", - "LYLU3": "Lyonia lucida (Lam.) K. Koch", - "MAGNO": "Magnolia sp.", - "NYBI": "Nyssa biflora Walter", - "NYSY": "Nyssa sylvatica Marshall", - "OXYDE": "Oxydendrum sp.", - "PEPA37": "Persea palustris (Raf.) 
Sarg.", - "PIEL": "Pinus elliottii Engelm.", - "PIPA2": "Pinus palustris Mill.", - "PINUS": "Pinus sp.", - "PITA": "Pinus taeda L.", - "PRSE2": "Prunus serotina Ehrh.", - "QUAL": "Quercus alba L.", - "QUCO2": "Quercus coccinea", - "QUGE2": "Quercus geminata Small", - "QUHE2": "Quercus hemisphaerica W. Bartram ex Willd.", - "QULA2": "Quercus laevis Walter", - "QULA3": "Quercus laurifolia Michx.", - "QUMO4": "Quercus montana Willd.", - "QUNI": "Quercus nigra L.", - "QURU": "Quercus rubra L.", - "QUERC": "Quercus sp.", - "ROPS": "Robinia pseudoacacia L.", - "TSCA": "Tsuga canadensis (L.) Carriere", + 'ACPE': 'Acer pensylvanicum L.', + 'ACRU': 'Acer rubrum L.', + 'ACSA3': 'Acer saccharum Marshall', + 'AMLA': 'Amelanchier laevis Wiegand', + 'BETUL': 'Betula sp.', + 'CAGL8': 'Carya glabra (Mill.) Sweet', + 'CATO6': 'Carya tomentosa (Lam.) Nutt.', + 'FAGR': 'Fagus grandifolia Ehrh.', + 'GOLA': 'Gordonia lasianthus (L.) Ellis', + 'LITU': 'Liriodendron tulipifera L.', + 'LYLU3': 'Lyonia lucida (Lam.) K. Koch', + 'MAGNO': 'Magnolia sp.', + 'NYBI': 'Nyssa biflora Walter', + 'NYSY': 'Nyssa sylvatica Marshall', + 'OXYDE': 'Oxydendrum sp.', + 'PEPA37': 'Persea palustris (Raf.) Sarg.', + 'PIEL': 'Pinus elliottii Engelm.', + 'PIPA2': 'Pinus palustris Mill.', + 'PINUS': 'Pinus sp.', + 'PITA': 'Pinus taeda L.', + 'PRSE2': 'Prunus serotina Ehrh.', + 'QUAL': 'Quercus alba L.', + 'QUCO2': 'Quercus coccinea', + 'QUGE2': 'Quercus geminata Small', + 'QUHE2': 'Quercus hemisphaerica W. Bartram ex Willd.', + 'QULA2': 'Quercus laevis Walter', + 'QULA3': 'Quercus laurifolia Michx.', + 'QUMO4': 'Quercus montana Willd.', + 'QUNI': 'Quercus nigra L.', + 'QURU': 'Quercus rubra L.', + 'QUERC': 'Quercus sp.', + 'ROPS': 'Robinia pseudoacacia L.', + 'TSCA': 'Tsuga canadensis (L.) Carriere', } metadata = { - "train": { - "url": "https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1", # noqa: E501 - "md5": "5ddfa76240b4bb6b4a7861d1d31c299c", - "filename": "IDTREES_competition_train_v2.zip", + 'train': { + 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', # noqa: E501 + 'md5': '5ddfa76240b4bb6b4a7861d1d31c299c', + 'filename': 'IDTREES_competition_train_v2.zip', }, - "test": { - "url": "https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1", # noqa: E501 - "md5": "b108931c84a70f2a38a8234290131c9b", - "filename": "IDTREES_competition_test_v2.zip", + 'test': { + 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', # noqa: E501 + 'md5': 'b108931c84a70f2a38a8234290131c9b', + 'filename': 'IDTREES_competition_test_v2.zip', }, } - directories = {"train": ["train"], "test": ["task1", "task2"]} + directories = {'train': ['train'], 'test': ['task1', 'task2']} image_size = (200, 200) def __init__( self, - root: str = "data", - split: str = "train", - task: str = "task1", + root: str = 'data', + split: str = 'train', + task: str = 'task1', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -169,8 +169,8 @@ def __init__( ImportError: if laspy is not installed DatasetNotFoundError: If dataset is not found and *download* is False. 
""" - assert split in ["train", "test"] - assert task in ["task1", "task2"] + assert split in ['train', 'test'] + assert task in ['task1', 'task2'] self.root = root self.split = split self.task = task @@ -186,7 +186,7 @@ def __init__( import laspy # noqa: F401 except ImportError: raise ImportError( - "laspy is not installed and is required to use this dataset" + 'laspy is not installed and is required to use this dataset' ) self.images, self.geometries, self.labels = self._load(root) @@ -202,28 +202,28 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ path = self.images[index] image = self._load_image(path).to(torch.uint8) - hsi = self._load_image(path.replace("RGB", "HSI")) - chm = self._load_image(path.replace("RGB", "CHM")) - las = self._load_las(path.replace("RGB", "LAS").replace(".tif", ".las")) - sample = {"image": image, "hsi": hsi, "chm": chm, "las": las} - - if self.split == "test": - if self.task == "task2": - sample["boxes"] = self._load_boxes(path) - h, w = sample["image"].shape[1:] - sample["boxes"], _ = self._filter_boxes( - image_size=(h, w), min_size=1, boxes=sample["boxes"], labels=None + hsi = self._load_image(path.replace('RGB', 'HSI')) + chm = self._load_image(path.replace('RGB', 'CHM')) + las = self._load_las(path.replace('RGB', 'LAS').replace('.tif', '.las')) + sample = {'image': image, 'hsi': hsi, 'chm': chm, 'las': las} + + if self.split == 'test': + if self.task == 'task2': + sample['boxes'] = self._load_boxes(path) + h, w = sample['image'].shape[1:] + sample['boxes'], _ = self._filter_boxes( + image_size=(h, w), min_size=1, boxes=sample['boxes'], labels=None ) else: - sample["boxes"] = self._load_boxes(path) - sample["label"] = self._load_target(path) + sample['boxes'] = self._load_boxes(path) + sample['label'] = self._load_target(path) - h, w = sample["image"].shape[1:] - sample["boxes"], sample["label"] = self._filter_boxes( + h, w = sample['image'].shape[1:] + sample['boxes'], sample['label'] = self._filter_boxes( image_size=(h, w), min_size=1, - boxes=sample["boxes"], - labels=sample["label"], + boxes=sample['boxes'], + labels=sample['label'], ) if self.transforms is not None: @@ -265,7 +265,7 @@ def _load_las(self, path: str) -> Tensor: import laspy las = laspy.read(path) - array: "np.typing.NDArray[np.int_]" = np.stack([las.x, las.y, las.z], axis=0) + array: 'np.typing.NDArray[np.int_]' = np.stack([las.x, las.y, las.z], axis=0) tensor = torch.from_numpy(array) return tensor @@ -284,10 +284,10 @@ def _load_boxes(self, path: str) -> Tensor: # Find object ids and geometries # The train set geometry->image mapping is contained # in the train/Field/itc_rsFile.csv file - if self.split == "train": - indices = self.labels["rsFile"] == base_path - ids = self.labels[indices]["id"].tolist() - geoms = [geometries[i]["geometry"]["coordinates"][0][:4] for i in ids] + if self.split == 'train': + indices = self.labels['rsFile'] == base_path + ids = self.labels[indices]['id'].tolist() + geoms = [geometries[i]['geometry']['coordinates'][0][:4] for i in ids] # The test set has no mapping csv. The mapping is inside of the geometry # properties i.e. 
geom["property"]["plotID"] contains the RGB image filename # Return all geometries with the matching RGB image filename of the sample @@ -295,9 +295,9 @@ def _load_boxes(self, path: str) -> Tensor: ids = [ k for k, v in geometries.items() - if v["properties"]["plotID"] == base_path + if v['properties']['plotID'] == base_path ] - geoms = [geometries[i]["geometry"]["coordinates"][0][:4] for i in ids] + geoms = [geometries[i]['geometry']['coordinates'][0][:4] for i in ids] # Convert to pixel coords boxes = [] @@ -324,10 +324,10 @@ def _load_target(self, path: str) -> Tensor: """ # Find indices for objects in the image base_path = os.path.basename(path) - indices = self.labels["rsFile"] == base_path + indices = self.labels['rsFile'] == base_path # Load object labels - classes = self.labels[indices]["taxonID"].tolist() + classes = self.labels[indices]['taxonID'].tolist() labels = [self.class2idx[c] for c in classes] tensor = torch.tensor(labels) return tensor @@ -343,20 +343,20 @@ def _load( Returns: the image path, geometries, and labels """ - if self.split == "train": + if self.split == 'train': directory = os.path.join(root, self.directories[self.split][0]) labels: pd.DataFrame = self._load_labels(directory) geoms = self._load_geometries(directory) else: directory = os.path.join(root, self.task) - if self.task == "task1": + if self.task == 'task1': geoms = None labels = None else: geoms = self._load_geometries(directory) labels = None - images = glob.glob(os.path.join(directory, "RemoteSensing", "RGB", "*.tif")) + images = glob.glob(os.path.join(directory, 'RemoteSensing', 'RGB', '*.tif')) return images, geoms, labels @@ -369,13 +369,13 @@ def _load_labels(self, directory: str) -> Any: Returns: a pandas DataFrame containing the labels for each image """ - path_mapping = os.path.join(directory, "Field", "itc_rsFile.csv") - path_labels = os.path.join(directory, "Field", "train_data.csv") + path_mapping = os.path.join(directory, 'Field', 'itc_rsFile.csv') + path_labels = os.path.join(directory, 'Field', 'train_data.csv') df_mapping = pd.read_csv(path_mapping) df_labels = pd.read_csv(path_labels) - df_mapping = df_mapping.set_index("indvdID", drop=True) - df_labels = df_labels.set_index("indvdID", drop=True) - df = df_labels.join(df_mapping, on="indvdID") + df_mapping = df_mapping.set_index('indvdID', drop=True) + df_labels = df_labels.set_index('indvdID', drop=True) + df = df_labels.join(df_mapping, on='indvdID') df = df.drop_duplicates() df.reset_index() return df @@ -389,7 +389,7 @@ def _load_geometries(self, directory: str) -> dict[int, dict[str, Any]]: Returns: a dict containing the geometries for each object """ - filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp")) + filepaths = glob.glob(os.path.join(directory, 'ITC', '*.shp')) i = 0 features: dict[int, dict[str, Any]] = {} @@ -397,8 +397,8 @@ def _load_geometries(self, directory: str) -> dict[int, dict[str, Any]]: with fiona.open(path) as src: for feature in src: # The train set has a unique id for each geometry in the properties - if self.split == "train": - features[feature["properties"]["id"]] = feature + if self.split == 'train': + features[feature['properties']['id']] = feature # The test set has no unique id so create a dummy id else: features[i] = feature @@ -444,9 +444,9 @@ def _filter_boxes( def _verify(self) -> None: """Verify the integrity of the dataset.""" - url = self.metadata[self.split]["url"] - md5 = self.metadata[self.split]["md5"] - filename = self.metadata[self.split]["filename"] + url = 
self.metadata[self.split]['url'] + md5 = self.metadata[self.split]['md5'] + filename = self.metadata[self.split]['filename'] directories = self.directories[self.split] # Check if the files already exist @@ -499,58 +499,58 @@ def normalize(x: Tensor) -> Tensor: ncols = 3 - hsi = normalize(sample["hsi"][hsi_indices, :, :]).permute((1, 2, 0)).numpy() - chm = normalize(sample["chm"]).permute((1, 2, 0)).numpy() + hsi = normalize(sample['hsi'][hsi_indices, :, :]).permute((1, 2, 0)).numpy() + chm = normalize(sample['chm']).permute((1, 2, 0)).numpy() - if "boxes" in sample and len(sample["boxes"]): + if 'boxes' in sample and len(sample['boxes']): labels = ( - [self.idx2class[int(i)] for i in sample["label"]] - if "label" in sample + [self.idx2class[int(i)] for i in sample['label']] + if 'label' in sample else None ) image = draw_bounding_boxes( - image=sample["image"], boxes=sample["boxes"], labels=labels + image=sample['image'], boxes=sample['boxes'], labels=labels ) image = image.permute((1, 2, 0)).numpy() else: - image = sample["image"].permute((1, 2, 0)).numpy() + image = sample['image'].permute((1, 2, 0)).numpy() - if "prediction_boxes" in sample and len(sample["prediction_boxes"]): + if 'prediction_boxes' in sample and len(sample['prediction_boxes']): ncols += 1 labels = ( - [self.idx2class[int(i)] for i in sample["prediction_label"]] - if "prediction_label" in sample + [self.idx2class[int(i)] for i in sample['prediction_label']] + if 'prediction_label' in sample else None ) preds = draw_bounding_boxes( - image=sample["image"], boxes=sample["prediction_boxes"], labels=labels + image=sample['image'], boxes=sample['prediction_boxes'], labels=labels ) preds = preds.permute((1, 2, 0)).numpy() fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(hsi) - axs[1].axis("off") + axs[1].axis('off') axs[2].imshow(chm) - axs[2].axis("off") + axs[2].axis('off') if ncols > 3: axs[3].imshow(preds) - axs[3].axis("off") + axs[3].axis('off') if show_titles: - axs[0].set_title("Ground Truth") - axs[1].set_title("Hyperspectral False Color Image") - axs[2].set_title("Canopy Height Model") + axs[0].set_title('Ground Truth') + axs[1].set_title('Hyperspectral False Color Image') + axs[2].set_title('Canopy Height Model') if ncols > 3: - axs[3].set_title("Predictions") + axs[3].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) return fig - def plot_las(self, index: int) -> "pyvista.Plotter": # type: ignore[name-defined] # noqa: F821 + def plot_las(self, index: int) -> 'pyvista.Plotter': # type: ignore[name-defined] # noqa: F821 """Plot a sample point cloud at the index. 
Args: @@ -569,22 +569,22 @@ def plot_las(self, index: int) -> "pyvista.Plotter": # type: ignore[name-define import pyvista # noqa: F401 except ImportError: raise ImportError( - "pyvista is not installed and is required to plot point clouds" + 'pyvista is not installed and is required to plot point clouds' ) import laspy path = self.images[index] - path = path.replace("RGB", "LAS").replace(".tif", ".las") + path = path.replace('RGB', 'LAS').replace('.tif', '.las') las = laspy.read(path) - points: "np.typing.NDArray[np.int_]" = np.stack( + points: 'np.typing.NDArray[np.int_]' = np.stack( [las.x, las.y, las.z], axis=0 ).transpose((1, 0)) point_cloud = pyvista.PolyData(points) # Some point cloud files have no color->points mapping - if hasattr(las, "red"): + if hasattr(las, 'red'): colors = np.stack([las.red, las.green, las.blue], axis=0) colors = colors.transpose((1, 0)) / np.iinfo(np.uint16).max - point_cloud["colors"] = colors + point_cloud['colors'] = colors return point_cloud diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py index 6838a5cdf4b..b8d48b6e1dd 100644 --- a/torchgeo/datasets/inaturalist.py +++ b/torchgeo/datasets/inaturalist.py @@ -33,7 +33,7 @@ class INaturalist(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = "data") -> None: + def __init__(self, root: str = 'data') -> None: """Initialize a new Dataset instance. Args: @@ -46,15 +46,15 @@ def __init__(self, root: str = "data") -> None: self.root = root - files = glob.glob(os.path.join(root, "**.csv")) + files = glob.glob(os.path.join(root, '**.csv')) if not files: raise DatasetNotFoundError(self) # Read CSV file data = pd.read_csv( files[0], - engine="c", - usecols=["observed_on", "time_observed_at", "latitude", "longitude"], + engine='c', + usecols=['observed_on', 'time_observed_at', 'latitude', 'longitude'], ) # Dataset contains many possible timestamps: @@ -76,9 +76,9 @@ def __init__(self, root: str = "data") -> None: continue if not pd.isna(time): - mint, maxt = disambiguate_timestamp(time, "%Y-%m-%d %H:%M:%S %z") + mint, maxt = disambiguate_timestamp(time, '%Y-%m-%d %H:%M:%S %z') elif not pd.isna(date): - mint, maxt = disambiguate_timestamp(date, "%Y-%m-%d") + mint, maxt = disambiguate_timestamp(date, '%Y-%m-%d') else: mint, maxt = 0, sys.maxsize @@ -103,9 +103,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not bboxes: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {"crs": self.crs, "bbox": bboxes} + sample = {'crs': self.crs, 'bbox': bboxes} return sample diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index b42a005e5b4..76c9dfd3678 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -57,14 +57,14 @@ class InriaAerialImageLabeling(NonGeoDataset): Added support for a *val* split. """ - directory = "AerialImageDataset" - filename = "NEW2-AerialImageDataset.zip" - md5 = "4b1acfe84ae9961edc1a6049f940380f" + directory = 'AerialImageDataset' + filename = 'NEW2-AerialImageDataset.zip' + md5 = '4b1acfe84ae9961edc1a6049f940380f' def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, checksum: bool = False, ) -> None: @@ -82,7 +82,7 @@ def __init__( DatasetNotFoundError: If dataset is not found. 
""" self.root = root - assert split in {"train", "val", "test"} + assert split in {'train', 'val', 'test'} self.split = split self.transforms = transforms self.checksum = checksum @@ -100,28 +100,28 @@ def _load_files(self, root: str) -> list[dict[str, str]]: list of dicts containing paths for each pair of image and label """ files = [] - split = "train" if self.split in ["train", "val"] else "test" + split = 'train' if self.split in ['train', 'val'] else 'test' root_dir = os.path.join(root, self.directory, split) - pattern = re.compile(r"([A-Za-z]+)(\d+)") + pattern = re.compile(r'([A-Za-z]+)(\d+)') - images = glob.glob(os.path.join(root_dir, "images", "*.tif")) + images = glob.glob(os.path.join(root_dir, 'images', '*.tif')) images = sorted(images) - if split == "train": - labels = glob.glob(os.path.join(root_dir, "gt", "*.tif")) + if split == 'train': + labels = glob.glob(os.path.join(root_dir, 'gt', '*.tif')) labels = sorted(labels) for img, lbl in zip(images, labels): if match := pattern.search(img): idx = int(match.group(2)) # For validation, use the first 5 images of every location - if self.split == "train" and idx > 5: - files.append({"image": img, "label": lbl}) - elif self.split == "val" and idx < 6: - files.append({"image": img, "label": lbl}) + if self.split == 'train' and idx > 5: + files.append({'image': img, 'label': lbl}) + elif self.split == 'val' and idx < 6: + files.append({'image': img, 'label': lbl}) else: for img in images: - files.append({"image": img}) + files.append({'image': img}) return files @@ -172,11 +172,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - img = self._load_image(files["image"]) - sample = {"image": img} - if files.get("label"): - mask = self._load_target(files["label"]) - sample["mask"] = mask + img = self._load_image(files['image']) + sample = {'image': img} + if files.get('label'): + mask = self._load_target(files['label']) + sample['mask'] = mask if self.transforms is not None: sample = self.transforms(sample) @@ -193,8 +193,8 @@ def _verify(self) -> None: if not os.path.isfile(archive_path): raise DatasetNotFoundError(self) if not check_integrity(archive_path, md5_hash): - raise RuntimeError("Dataset corrupted") - print("Extracting...") + raise RuntimeError('Dataset corrupted') + print('Extracting...') extract_archive(archive_path) def plot( @@ -213,40 +213,40 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = np.rollaxis(sample["image"][:3].numpy(), 0, 3) + image = np.rollaxis(sample['image'][:3].numpy(), 0, 3) image = percentile_normalization(image, axis=(0, 1)) ncols = 1 - show_mask = "mask" in sample - show_predictions = "prediction" in sample + show_mask = 'mask' in sample + show_predictions = 'prediction' in sample if show_mask: - mask = sample["mask"].numpy() + mask = sample['mask'].numpy() ncols += 1 if show_predictions: - prediction = sample["prediction"].numpy() + prediction = sample['prediction'].numpy() ncols += 1 fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) if not isinstance(axs, np.ndarray): axs = [axs] axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') if show_titles: - axs[0].set_title("Image") + axs[0].set_title('Image') if show_mask: - axs[1].imshow(mask, interpolation="none") - axs[1].axis("off") + axs[1].imshow(mask, interpolation='none') + axs[1].axis('off') if show_titles: - axs[1].set_title("Label") + axs[1].set_title('Label') if show_predictions: - axs[2].imshow(prediction, 
interpolation="none") - axs[2].axis("off") + axs[2].imshow(prediction, interpolation='none') + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 7e12fe67ddf..10de6341798 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -63,24 +63,24 @@ class L7Irish(RasterDataset): .. versionadded:: 0.5 """ # noqa: E501 - url = "https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz" # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' # noqa: E501 md5s = { - "austral": "0a34770b992a62abeb88819feb192436", - "boreal": "b7cfdd689a3c2fd2a8d572e1c10ed082", - "mid_latitude_north": "c40abe5ad2487f8ab021cfb954982faa", - "mid_latitude_south": "37abab7f6ebe3d6cf6a3332144145427", - "polar_north": "49d9e616bd715057db9acb1c4d234d45", - "polar_south": "c1503db1cf46d5c37b579190f989e7ec", - "subtropical_north": "a6010de4c50167260de35beead9d6a65", - "subtropical_south": "c37d439df2f05bd7cfe87cf6ff61a690", - "tropical": "d7931419c70f3520a17361d96f1a4810", + 'austral': '0a34770b992a62abeb88819feb192436', + 'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082', + 'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa', + 'mid_latitude_south': '37abab7f6ebe3d6cf6a3332144145427', + 'polar_north': '49d9e616bd715057db9acb1c4d234d45', + 'polar_south': 'c1503db1cf46d5c37b579190f989e7ec', + 'subtropical_north': 'a6010de4c50167260de35beead9d6a65', + 'subtropical_south': 'c37d439df2f05bd7cfe87cf6ff61a690', + 'tropical': 'd7931419c70f3520a17361d96f1a4810', } - classes = ["Fill", "Cloud Shadow", "Clear", "Thin Cloud", "Cloud"] + classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] # https://landsat.usgs.gov/cloud-validation/cca_irish_2015/L7_Irish_Cloud_Validation_Masks.xml - filename_glob = "L71*.TIF" + filename_glob = 'L71*.TIF' filename_regex = r""" ^L71 (?P\d{3}) @@ -89,15 +89,15 @@ class L7Irish(RasterDataset): (?P\d{8}) \.TIF$ """ - date_format = "%Y%m%d" + date_format = '%Y%m%d' separate_files = False - rgb_bands = ["B30", "B20", "B10"] - all_bands = ["B10", "B20", "B30", "B40", "B50", "B61", "B62", "B70", "B80"] + rgb_bands = ['B30', 'B20', 'B10'] + all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80'] def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = CRS.from_epsg(3857), res: float | None = None, bands: Sequence[str] = all_bands, @@ -142,7 +142,7 @@ def _verify(self) -> None: # Check if the tar.gz files have already been downloaded assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "*.tar.gz") + pathname = os.path.join(self.paths, '*.tar.gz') if glob.glob(pathname): self._extract() return @@ -165,7 +165,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "*.tar.gz") + pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) @@ -186,16 +186,16 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) image = self._merge_files(filepaths, query, self.band_indexes) 
mask_filepaths = [] for filepath in filepaths: - path, row = os.path.basename(os.path.dirname(filepath)).split("_")[:2] + path, row = os.path.basename(os.path.dirname(filepath)).split('_')[:2] mask_filepath = filepath.replace( - os.path.basename(filepath), f"L7_{path}_{row}_newmask2015.TIF" + os.path.basename(filepath), f'L7_{path}_{row}_newmask2015.TIF' ) mask_filepaths.append(mask_filepath) @@ -206,10 +206,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask[mask == k] = v sample = { - "crs": self.crs, - "bbox": query, - "image": image.float(), - "mask": mask.long(), + 'crs': self.crs, + 'bbox': query, + 'image': image.float(), + 'mask': mask.long(), } if self.transforms is not None: @@ -243,34 +243,34 @@ def plot( else: raise RGBBandsMissingError() - image = sample["image"][rgb_indices].permute(1, 2, 0) + image = sample['image'][rgb_indices].permute(1, 2, 0) # Stretch to the full range image = (image - image.min()) / (image.max() - image.min()) - mask = sample["mask"].numpy().astype("uint8").squeeze() + mask = sample['mask'].numpy().astype('uint8').squeeze() num_panels = 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - predictions = sample["prediction"].numpy().astype("uint8").squeeze() + predictions = sample['prediction'].numpy().astype('uint8').squeeze() num_panels += 1 - kwargs = {"cmap": "gray", "vmin": 0, "vmax": 4, "interpolation": "none"} + kwargs = {'cmap': 'gray', 'vmin': 0, 'vmax': 4, 'interpolation': 'none'} fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask, **kwargs) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if showing_predictions: axs[2].imshow(predictions, **kwargs) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index f4fa8047911..fe2be46c1ef 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -62,23 +62,23 @@ class L8Biome(RasterDataset): .. 
versionadded:: 0.5 """ # noqa: E501 - url = "https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz" # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' # noqa: E501 md5s = { - "barren": "0eb691822d03dabd4f5ea8aadd0b41c3", - "forest": "4a5645596f6bb8cea44677f746ec676e", - "grass_crops": "a69ed5d6cb227c5783f026b9303cdd3c", - "shrubland": "19df1d0a604faf6aab46d6a7a5e6da6a", - "snow_ice": "af8b189996cf3f578e40ee12e1f8d0c9", - "urban": "5450195ed95ee225934b9827bea1e8b0", - "water": "a81153415eb662c9e6812c2a8e38c743", - "wetlands": "1f86cc354631ca9a50ce54b7cab3f557", + 'barren': '0eb691822d03dabd4f5ea8aadd0b41c3', + 'forest': '4a5645596f6bb8cea44677f746ec676e', + 'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c', + 'shrubland': '19df1d0a604faf6aab46d6a7a5e6da6a', + 'snow_ice': 'af8b189996cf3f578e40ee12e1f8d0c9', + 'urban': '5450195ed95ee225934b9827bea1e8b0', + 'water': 'a81153415eb662c9e6812c2a8e38c743', + 'wetlands': '1f86cc354631ca9a50ce54b7cab3f557', } - classes = ["Fill", "Cloud Shadow", "Clear", "Thin Cloud", "Cloud"] + classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] # https://gisgeography.com/landsat-file-naming-convention/ - filename_glob = "LC8*.TIF" + filename_glob = 'LC8*.TIF' filename_regex = r""" ^LC8 (?P\d{3}) @@ -88,11 +88,11 @@ class L8Biome(RasterDataset): (?P\d{2}) \.TIF$ """ - date_format = "%Y%j" + date_format = '%Y%j' separate_files = False - rgb_bands = ["B4", "B3", "B2"] - all_bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9", "B10", "B11"] + rgb_bands = ['B4', 'B3', 'B2'] + all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11'] def __init__( self, @@ -141,7 +141,7 @@ def _verify(self) -> None: # Check if the tar.gz files have already been downloaded assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "*.tar.gz") + pathname = os.path.join(self.paths, '*.tar.gz') if glob.glob(pathname): self._extract() return @@ -164,7 +164,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "*.tar.gz") + pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) @@ -185,14 +185,14 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) image = self._merge_files(filepaths, query, self.band_indexes) mask_filepaths = [] for filepath in filepaths: - mask_filepath = filepath.replace(".TIF", "_fixedmask.TIF") + mask_filepath = filepath.replace('.TIF', '_fixedmask.TIF') mask_filepaths.append(mask_filepath) mask = self._merge_files(mask_filepaths, query) @@ -202,10 +202,10 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask[mask == k] = v sample = { - "crs": self.crs, - "bbox": query, - "image": image.float(), - "mask": mask.long(), + 'crs': self.crs, + 'bbox': query, + 'image': image.float(), + 'mask': mask.long(), } if self.transforms is not None: @@ -239,34 +239,34 @@ def plot( else: raise RGBBandsMissingError() - image = sample["image"][rgb_indices].permute(1, 2, 0) + image = sample['image'][rgb_indices].permute(1, 2, 0) # Stretch to the full range image = (image - image.min()) / (image.max() - image.min()) - mask = 
sample["mask"].numpy().astype("uint8").squeeze() + mask = sample['mask'].numpy().astype('uint8').squeeze() num_panels = 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - predictions = sample["prediction"].numpy().astype("uint8").squeeze() + predictions = sample['prediction'].numpy().astype('uint8').squeeze() num_panels += 1 - kwargs = {"cmap": "gray", "vmin": 0, "vmax": 4, "interpolation": "none"} + kwargs = {'cmap': 'gray', 'vmin': 0, 'vmax': 4, 'interpolation': 'none'} fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask, **kwargs) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if showing_predictions: axs[2].imshow(predictions, **kwargs) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index b35a78280e1..33740dd4cec 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -66,10 +66,10 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): .. versionadded:: 0.5 """ - url = "https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip" - filename = "landcover.ai.v1.zip" - md5 = "3268c89070e8734b4e91d531c0617e03" - classes = ["Background", "Building", "Woodland", "Water", "Road"] + url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip' + filename = 'landcover.ai.v1.zip' + md5 = '3268c89070e8734b4e91d531c0617e03' + classes = ['Background', 'Building', 'Woodland', 'Water', 'Road'] cmap = { 0: (0, 0, 0, 0), 1: (97, 74, 74, 255), @@ -79,7 +79,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): } def __init__( - self, root: str = "data", download: bool = False, checksum: bool = False + self, root: str = 'data', download: bool = False, checksum: bool = False ) -> None: """Initialize a new LandCover.ai dataset instance. 
@@ -166,31 +166,31 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = np.rollaxis(sample["image"].numpy().astype("uint8").squeeze(), 0, 3) - mask = sample["mask"].numpy().astype("uint8").squeeze() + image = np.rollaxis(sample['image'].numpy().astype('uint8').squeeze(), 0, 3) + mask = sample['mask'].numpy().astype('uint8').squeeze() num_panels = 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - predictions = sample["prediction"].numpy() + predictions = sample['prediction'].numpy() num_panels += 1 fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) axs[0].imshow(image) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation="none") - axs[1].axis("off") + axs[0].axis('off') + axs[1].imshow(mask, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation='none') + axs[1].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if showing_predictions: axs[2].imshow( - predictions, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation="none" + predictions, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation='none' ) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) @@ -205,12 +205,12 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset): .. versionadded:: 0.5 """ - filename_glob = os.path.join("images", "*.tif") - filename_regex = ".*tif" + filename_glob = os.path.join('images', '*.tif') + filename_regex = '.*tif' def __init__( self, - root: str = "data", + root: str = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -240,8 +240,8 @@ def __init__( def _verify_data(self) -> bool: """Verify if the images and masks are present.""" - img_query = os.path.join(self.root, "images", "*.tif") - mask_query = os.path.join(self.root, "masks", "*.tif") + img_query = os.path.join(self.root, 'images', '*.tif') + mask_query = os.path.join(self.root, 'masks', '*.tif') images = glob.glob(img_query) masks = glob.glob(mask_query) return len(images) > 0 and len(images) == len(masks) @@ -260,20 +260,20 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """ hits = self.index.intersection(tuple(query), objects=True) img_filepaths = cast(list[str], [hit.object for hit in hits]) - mask_filepaths = [path.replace("images", "masks") for path in img_filepaths] + mask_filepaths = [path.replace('images', 'masks') for path in img_filepaths] if not img_filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) img = self._merge_files(img_filepaths, query, self.band_indexes) mask = self._merge_files(mask_filepaths, query, self.band_indexes) sample = { - "crs": self.crs, - "bbox": query, - "image": img.float(), - "mask": mask.long(), + 'crs': self.crs, + 'bbox': query, + 'image': img.float(), + 'mask': mask.long(), } if self.transforms is not None: @@ -295,12 +295,12 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset): the train/val/test split """ - sha256 = "15ee4ca9e3fd187957addfa8f0d74ac31bc928a966f76926e11b3c33ea76daa1" + sha256 = '15ee4ca9e3fd187957addfa8f0d74ac31bc928a966f76926e11b3c33ea76daa1' def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + 
split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -319,13 +319,13 @@ def __init__( AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert split in ["train", "val", "test"] + assert split in ['train', 'val', 'test'] super().__init__(root, download, checksum) self.transforms = transforms self.split = split - with open(os.path.join(self.root, split + ".txt")) as f: + with open(os.path.join(self.root, split + '.txt')) as f: self.ids = f.readlines() def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -338,7 +338,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ id_ = self.ids[index].rstrip() - sample = {"image": self._load_image(id_), "mask": self._load_target(id_)} + sample = {'image': self._load_image(id_), 'mask': self._load_target(id_)} if self.transforms is not None: sample = self.transforms(sample) @@ -363,9 +363,9 @@ def _load_image(self, id_: str) -> Tensor: Returns: the image """ - filename = os.path.join(self.root, "output", id_ + ".jpg") + filename = os.path.join(self.root, 'output', id_ + '.jpg') with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img) + array: 'np.typing.NDArray[np.int_]' = np.array(img) tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -381,16 +381,16 @@ def _load_target(self, id_: str) -> Tensor: Returns: the target mask """ - filename = os.path.join(self.root, "output", id_ + "_m.png") + filename = os.path.join(self.root, 'output', id_ + '_m.png') with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('L')) tensor = torch.from_numpy(array).long() return tensor def _verify_data(self) -> bool: """Verify if the images and masks are present.""" - img_query = os.path.join(self.root, "output", "*_*.jpg") - mask_query = os.path.join(self.root, "output", "*_*_m.png") + img_query = os.path.join(self.root, 'output', '*_*.jpg') + mask_query = os.path.join(self.root, 'output', '*_*_m.png') images = glob.glob(img_query) masks = glob.glob(mask_query) return len(images) > 0 and len(images) == len(masks) @@ -407,7 +407,7 @@ def _extract(self) -> None: # Always check the sha256 of this file before executing # to avoid malicious code injection with working_dir(self.root): - with open("split.py") as f: - split = f.read().encode("utf-8") + with open('split.py') as f: + split = f.read().encode('utf-8') assert hashlib.sha256(split).hexdigest() == self.sha256 exec(split) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 503580cfd17..8b7d269c893 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -59,7 +59,7 @@ def default_bands(self) -> list[str]: def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, @@ -120,7 +120,7 @@ def plot( else: raise RGBBandsMissingError() - image = sample["image"][rgb_indices].permute(1, 2, 0).float() + image = sample['image'][rgb_indices].permute(1, 2, 0).float() # Stretch to the full range image = (image - image.min()) / (image.max() - image.min()) @@ -128,10 +128,10 @@ def plot( fig, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.imshow(image) - ax.axis("off") 
+ ax.axis('off') if show_titles: - ax.set_title("Image") + ax.set_title('Image') if suptitle is not None: plt.suptitle(suptitle) @@ -142,73 +142,73 @@ def plot( class Landsat1(Landsat): """Landsat 1 Multispectral Scanner (MSS).""" - filename_glob = "LM01_*_{}.*" + filename_glob = 'LM01_*_{}.*' - default_bands = ["B4", "B5", "B6", "B7"] - rgb_bands = ["B6", "B5", "B4"] + default_bands = ['B4', 'B5', 'B6', 'B7'] + rgb_bands = ['B6', 'B5', 'B4'] class Landsat2(Landsat1): """Landsat 2 Multispectral Scanner (MSS).""" - filename_glob = "LM02_*_{}.*" + filename_glob = 'LM02_*_{}.*' class Landsat3(Landsat1): """Landsat 3 Multispectral Scanner (MSS).""" - filename_glob = "LM03_*_{}.*" + filename_glob = 'LM03_*_{}.*' class Landsat4MSS(Landsat): """Landsat 4 Multispectral Scanner (MSS).""" - filename_glob = "LM04_*_{}.*" + filename_glob = 'LM04_*_{}.*' - default_bands = ["B1", "B2", "B3", "B4"] - rgb_bands = ["B3", "B2", "B1"] + default_bands = ['B1', 'B2', 'B3', 'B4'] + rgb_bands = ['B3', 'B2', 'B1'] class Landsat4TM(Landsat): """Landsat 4 Thematic Mapper (TM).""" - filename_glob = "LT04_*_{}.*" + filename_glob = 'LT04_*_{}.*' - default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] - rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] + default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] + rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] class Landsat5MSS(Landsat4MSS): """Landsat 4 Multispectral Scanner (MSS).""" - filename_glob = "LM04_*_{}.*" + filename_glob = 'LM04_*_{}.*' class Landsat5TM(Landsat4TM): """Landsat 5 Thematic Mapper (TM).""" - filename_glob = "LT05_*_{}.*" + filename_glob = 'LT05_*_{}.*' class Landsat7(Landsat): """Landsat 7 Enhanced Thematic Mapper Plus (ETM+).""" - filename_glob = "LE07_*_{}.*" + filename_glob = 'LE07_*_{}.*' - default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] - rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] + default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] + rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] class Landsat8(Landsat): """Landsat 8 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS).""" - filename_glob = "LC08_*_{}.*" + filename_glob = 'LC08_*_{}.*' - default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] - rgb_bands = ["SR_B4", "SR_B3", "SR_B2"] + default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] + rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2'] class Landsat9(Landsat8): """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" # noqa: E501 - filename_glob = "LC09_*_{}.*" + filename_glob = 'LC09_*_{}.*' diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index b2f1891c069..051c5733c0d 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -30,12 +30,12 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC): """ splits: list[str] | dict[str, dict[str, str]] - directories = ["A", "B", "label"] + directories = ['A', 'B', 'label'] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -79,10 +79,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image1 = self._load_image(files["image1"]) - image2 = self._load_image(files["image2"]) - mask = self._load_target(files["mask"]) - sample = {"image1": image1, 
"image2": image2, "mask": mask} + image1 = self._load_image(files['image1']) + image2 = self._load_image(files['image2']) + mask = self._load_target(files['mask']) + sample = {'image1': image1, 'image2': image2, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -108,7 +108,7 @@ def _load_image(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -125,7 +125,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = torch.clamp(tensor, min=0, max=1) tensor = tensor.to(torch.long) @@ -151,34 +151,34 @@ def plot( """ ncols = 3 - image1 = sample["image1"].permute(1, 2, 0).numpy() + image1 = sample['image1'].permute(1, 2, 0).numpy() image1 = percentile_normalization(image1, axis=(0, 1)) - image2 = sample["image2"].permute(1, 2, 0).numpy() + image2 = sample['image2'].permute(1, 2, 0).numpy() image2 = percentile_normalization(image2, axis=(0, 1)) - if "prediction" in sample: + if 'prediction' in sample: ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) axs[0].imshow(image1) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(image2) - axs[1].axis("off") - axs[2].imshow(sample["mask"], cmap="gray", interpolation="none") - axs[2].axis("off") + axs[1].axis('off') + axs[2].imshow(sample['mask'], cmap='gray', interpolation='none') + axs[2].axis('off') - if "prediction" in sample: - axs[3].imshow(sample["prediction"], cmap="gray", interpolation="none") - axs[3].axis("off") + if 'prediction' in sample: + axs[3].imshow(sample['prediction'], cmap='gray', interpolation='none') + axs[3].axis('off') if show_titles: - axs[3].set_title("Prediction") + axs[3].set_title('Prediction') if show_titles: - axs[0].set_title("Image 1") - axs[1].set_title("Image 2") - axs[2].set_title("Mask") + axs[0].set_title('Image 1') + axs[1].set_title('Image 2') + axs[2].set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) @@ -241,20 +241,20 @@ class LEVIRCD(LEVIRCDBase): """ splits = { - "train": { - "url": "https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-", - "filename": "train.zip", - "md5": "a638e71f480628652dea78d8544307e4", + 'train': { + 'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-', + 'filename': 'train.zip', + 'md5': 'a638e71f480628652dea78d8544307e4', }, - "val": { - "url": "https://drive.google.com/file/d/1BqSt4ueO7XAyQ_84mUjswUSJt13ZBuzG", - "filename": "val.zip", - "md5": "f7b857978524f9aa8c3bf7f94e3047a4", + 'val': { + 'url': 'https://drive.google.com/file/d/1BqSt4ueO7XAyQ_84mUjswUSJt13ZBuzG', + 'filename': 'val.zip', + 'md5': 'f7b857978524f9aa8c3bf7f94e3047a4', }, - "test": { - "url": "https://drive.google.com/file/d/1jj3qJD_grJlgIhUWO09zibRGJe0R4Tn0", - "filename": "test.zip", - "md5": "07d5dd89e46f5c1359e2eca746989ed9", + 'test': { + 'url': 'https://drive.google.com/file/d/1jj3qJD_grJlgIhUWO09zibRGJe0R4Tn0', + 'filename': 'test.zip', + 'md5': '07d5dd89e46f5c1359e2eca746989ed9', }, } @@ -268,9 +268,9 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: 
Returns: list of dicts containing paths for each pair of image1, image2, mask """ - images1 = sorted(glob.glob(os.path.join(root, "A", f"{split}*.png"))) - images2 = sorted(glob.glob(os.path.join(root, "B", f"{split}*.png"))) - masks = sorted(glob.glob(os.path.join(root, "label", f"{split}*.png"))) + images1 = sorted(glob.glob(os.path.join(root, 'A', f'{split}*.png'))) + images2 = sorted(glob.glob(os.path.join(root, 'B', f'{split}*.png'))) + masks = sorted(glob.glob(os.path.join(root, 'label', f'{split}*.png'))) files = [] for image1, image2, mask in zip(images1, images2, masks): @@ -293,15 +293,15 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return for split in self.splits: download_and_extract_archive( - self.splits[split]["url"], + self.splits[split]['url'], self.root, - filename=self.splits[split]["filename"], - md5=self.splits[split]["md5"] if self.checksum else None, + filename=self.splits[split]['filename'], + md5=self.splits[split]['md5'] if self.checksum else None, ) @@ -335,11 +335,11 @@ class LEVIRCDPlus(LEVIRCDBase): .. versionchanged:: 0.6 """ - url = "https://drive.google.com/file/d/1JamSsxiytXdzAIk6VDVWfc-OsX-81U81" - md5 = "1adf156f628aa32fb2e8fe6cada16c04" - filename = "LEVIR-CD+.zip" - directory = "LEVIR-CD+" - splits = ["train", "test"] + url = 'https://drive.google.com/file/d/1JamSsxiytXdzAIk6VDVWfc-OsX-81U81' + md5 = '1adf156f628aa32fb2e8fe6cada16c04' + filename = 'LEVIR-CD+.zip' + directory = 'LEVIR-CD+' + splits = ['train', 'test'] def _load_files(self, root: str, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. @@ -352,12 +352,12 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: list of dicts containing paths for each pair of image1, image2, mask """ files = [] - images = glob.glob(os.path.join(root, self.directory, split, "A", "*.png")) + images = glob.glob(os.path.join(root, self.directory, split, 'A', '*.png')) images = sorted(os.path.basename(image) for image in images) for image in images: - image1 = os.path.join(root, self.directory, split, "A", image) - image2 = os.path.join(root, self.directory, split, "B", image) - mask = os.path.join(root, self.directory, split, "label", image) + image1 = os.path.join(root, self.directory, split, 'A', image) + image2 = os.path.join(root, self.directory, split, 'B', image) + mask = os.path.join(root, self.directory, split, 'label', image) files.append(dict(image1=image1, image2=image2, mask=mask)) return files @@ -376,7 +376,7 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return download_and_extract_archive( diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 4e9b150251a..1583e460b01 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -56,43 +56,43 @@ class LoveDA(NonGeoDataset): .. 
versionadded:: 0.2 """ - scenes = ["urban", "rural"] - splits = ["train", "val", "test"] + scenes = ['urban', 'rural'] + splits = ['train', 'val', 'test'] info_dict = { - "train": { - "url": "https://zenodo.org/record/5706578/files/Train.zip?download=1", - "filename": "Train.zip", - "md5": "de2b196043ed9b4af1690b3f9a7d558f", + 'train': { + 'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1', + 'filename': 'Train.zip', + 'md5': 'de2b196043ed9b4af1690b3f9a7d558f', }, - "val": { - "url": "https://zenodo.org/record/5706578/files/Val.zip?download=1", - "filename": "Val.zip", - "md5": "84cae2577468ff0b5386758bb386d31d", + 'val': { + 'url': 'https://zenodo.org/record/5706578/files/Val.zip?download=1', + 'filename': 'Val.zip', + 'md5': '84cae2577468ff0b5386758bb386d31d', }, - "test": { - "url": "https://zenodo.org/record/5706578/files/Test.zip?download=1", - "filename": "Test.zip", - "md5": "a489be0090465e01fb067795d24e6b47", + 'test': { + 'url': 'https://zenodo.org/record/5706578/files/Test.zip?download=1', + 'filename': 'Test.zip', + 'md5': 'a489be0090465e01fb067795d24e6b47', }, } classes = [ - "background", - "building", - "road", - "water", - "barren", - "forest", - "agriculture", - "no-data", + 'background', + 'building', + 'road', + 'water', + 'barren', + 'forest', + 'agriculture', + 'no-data', ] def __init__( self, - root: str = "data", - split: str = "train", - scene: list[str] = ["urban", "rural"], + root: str = 'data', + split: str = 'train', + scene: list[str] = ['urban', 'rural'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -124,9 +124,9 @@ def __init__( self.transforms = transforms self.checksum = checksum - self.url = self.info_dict[self.split]["url"] - self.filename = self.info_dict[self.split]["filename"] - self.md5 = self.info_dict[self.split]["md5"] + self.url = self.info_dict[self.split]['url'] + self.filename = self.info_dict[self.split]['filename'] + self.md5 = self.info_dict[self.split]['md5'] self.directory = os.path.join(self.root, split.capitalize()) self.scene_paths = [ @@ -152,13 +152,13 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: and mask of dimension 1024x1024 """ files = self.files[index] - image = self._load_image(files["image"]) + image = self._load_image(files['image']) - if self.split != "test": - mask = self._load_target(files["mask"]) - sample = {"image": image, "mask": mask} + if self.split != 'test': + mask = self._load_target(files['mask']) + sample = {'image': image, 'mask': mask} else: - sample = {"image": image} + sample = {'image': image} if self.transforms is not None: sample = self.transforms(sample) @@ -184,12 +184,12 @@ def _load_files(self, scene_paths: list[str], split: str) -> list[dict[str, str] images = [] for s in scene_paths: - images.extend(glob.glob(os.path.join(s, "images_png", "*.png"))) + images.extend(glob.glob(os.path.join(s, 'images_png', '*.png'))) images = sorted(images) - if self.split != "test": - masks = [image.replace("images_png", "masks_png") for image in images] + if self.split != 'test': + masks = [image.replace('images_png', 'masks_png') for image in images] files = [dict(image=image, mask=mask) for image, mask in zip(images, masks)] else: files = [dict(image=image) for image in images] @@ -207,7 +207,7 @@ def _load_image(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 
'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -224,7 +224,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) return tensor @@ -244,7 +244,7 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return download_and_extract_archive( @@ -264,23 +264,23 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure Returns: a matplotlib Figure with the rendered sample """ - if self.split != "test": - image, mask = sample["image"], sample["mask"] + if self.split != 'test': + image, mask = sample['image'], sample['mask'] ncols = 2 else: - image = sample["image"] + image = sample['image'] ncols = 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10)) - if self.split != "test": + if self.split != 'test': axs[0].imshow(image.permute(1, 2, 0)) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask) - axs[1].axis("off") + axs[1].axis('off') else: axs.imshow(image.permute(1, 2, 0)) - axs.axis("off") + axs.axis('off') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index 849d66485d8..cd1f8c5c9df 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -53,41 +53,41 @@ class MapInWild(NonGeoDataset): .. 
versionadded:: 0.5 """ - url = "https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/" # noqa: E501 + url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' # noqa: E501 modality_urls = { - "esa_wc": {"esa_wc/ESA_WC.zip"}, - "viirs": {"viirs/VIIRS.zip"}, - "mask": {"mask/mask.zip"}, - "s1": {"s1/s1_part1.zip", "s1/s1_part2.zip"}, - "s2_temporal_subset": { - "s2_temporal_subset/s2_temporal_subset_part1.zip", - "s2_temporal_subset/s2_temporal_subset_part2.zip", + 'esa_wc': {'esa_wc/ESA_WC.zip'}, + 'viirs': {'viirs/VIIRS.zip'}, + 'mask': {'mask/mask.zip'}, + 's1': {'s1/s1_part1.zip', 's1/s1_part2.zip'}, + 's2_temporal_subset': { + 's2_temporal_subset/s2_temporal_subset_part1.zip', + 's2_temporal_subset/s2_temporal_subset_part2.zip', }, - "s2_autumn": {"s2_autumn/s2_autumn_part1.zip", "s2_autumn/s2_autumn_part2.zip"}, - "s2_spring": {"s2_spring/s2_spring_part1.zip", "s2_spring/s2_spring_part2.zip"}, - "s2_summer": {"s2_summer/s2_summer_part1.zip", "s2_summer/s2_summer_part2.zip"}, - "s2_winter": {"s2_winter/s2_winter_part1.zip", "s2_winter/s2_winter_part2.zip"}, - "split_IDs": {"split_IDs/split_IDs.csv"}, + 's2_autumn': {'s2_autumn/s2_autumn_part1.zip', 's2_autumn/s2_autumn_part2.zip'}, + 's2_spring': {'s2_spring/s2_spring_part1.zip', 's2_spring/s2_spring_part2.zip'}, + 's2_summer': {'s2_summer/s2_summer_part1.zip', 's2_summer/s2_summer_part2.zip'}, + 's2_winter': {'s2_winter/s2_winter_part1.zip', 's2_winter/s2_winter_part2.zip'}, + 'split_IDs': {'split_IDs/split_IDs.csv'}, } md5s = { - "ESA_WC.zip": "72b2ee578fe10f0df85bdb7f19311c92", - "VIIRS.zip": "4eff014bae127fe536f8a5f17d89ecb4", - "mask.zip": "87c83a23a73998ad60d448d240b66225", - "s1_part1.zip": "d8a911f5c76b50eb0760b8f0047e4674", - "s1_part2.zip": "a30369d17c62d2af5aa52a4189590e3c", - "s2_temporal_subset_part1.zip": "78c2d05514458a036fe133f1e2f11d2a", - "s2_temporal_subset_part2.zip": "076cd3bd00eb5b7f5d80c9e0a0de0275", - "s2_autumn_part1.zip": "6ee7d1ac44b5107e3663636269aecf68", - "s2_autumn_part2.zip": "4fc5e1d5c772421dba553722433ac3b9", - "s2_spring_part1.zip": "2a89687d8fafa7fc7f5e641bfa97d472", - "s2_spring_part2.zip": "5845dcae0ab3cdc174b7c41edd4283a9", - "s2_summer_part1.zip": "73ca8291d3f4fb7533636220a816bb71", - "s2_summer_part2.zip": "5b5816bbd32987619bf72cde5cacd032", - "s2_winter_part1.zip": "ca958f7cd98e37cb59d6f3877573ee6d", - "s2_winter_part2.zip": "e7aacb0806d6d619b6abc408e6b09fdc", - "split_IDs.csv": "cb5c6c073702acee23544e1e6fe5856f", + 'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92', + 'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4', + 'mask.zip': '87c83a23a73998ad60d448d240b66225', + 's1_part1.zip': 'd8a911f5c76b50eb0760b8f0047e4674', + 's1_part2.zip': 'a30369d17c62d2af5aa52a4189590e3c', + 's2_temporal_subset_part1.zip': '78c2d05514458a036fe133f1e2f11d2a', + 's2_temporal_subset_part2.zip': '076cd3bd00eb5b7f5d80c9e0a0de0275', + 's2_autumn_part1.zip': '6ee7d1ac44b5107e3663636269aecf68', + 's2_autumn_part2.zip': '4fc5e1d5c772421dba553722433ac3b9', + 's2_spring_part1.zip': '2a89687d8fafa7fc7f5e641bfa97d472', + 's2_spring_part2.zip': '5845dcae0ab3cdc174b7c41edd4283a9', + 's2_summer_part1.zip': '73ca8291d3f4fb7533636220a816bb71', + 's2_summer_part2.zip': '5b5816bbd32987619bf72cde5cacd032', + 's2_winter_part1.zip': 'ca958f7cd98e37cb59d6f3877573ee6d', + 's2_winter_part2.zip': 'e7aacb0806d6d619b6abc408e6b09fdc', + 'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f', } mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)} @@ -108,9 +108,9 @@ 
class MapInWild(NonGeoDataset): def __init__( self, - root: str = "data", - modality: list[str] = ["mask", "esa_wc", "viirs", "s2_summer"], - split: str = "train", + root: str = 'data', + modality: list[str] = ['mask', 'esa_wc', 'viirs', 's2_summer'], + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -131,7 +131,7 @@ def __init__( AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert split in ["train", "validation", "test"] + assert split in ['train', 'validation', 'test'] self.checksum = checksum self.root = root @@ -139,7 +139,7 @@ def __init__( self.modality = modality self.download = download - modality.append("split_IDs") + modality.append('split_IDs') for mode in modality: for modality_link in self.modality_urls[mode]: modality_url = os.path.join(self.url, modality_link) @@ -156,15 +156,15 @@ def __init__( self._merge_parts(mode) # Masks will be loaded seperately in the :meth:`__getitem__` - if "mask" in self.modality: - self.modality.remove("mask") + if 'mask' in self.modality: + self.modality.remove('mask') # Split IDs has been downloaded and is not needed in the list - if "split_IDs" in self.modality: - self.modality.remove("split_IDs") + if 'split_IDs' in self.modality: + self.modality.remove('split_IDs') - if os.path.exists(os.path.join(self.root, "split_IDs.csv")): - split_dataframe = pd.read_csv(os.path.join(self.root, "split_IDs.csv")) + if os.path.exists(os.path.join(self.root, 'split_IDs.csv')): + split_dataframe = pd.read_csv(os.path.join(self.root, 'split_IDs.csv')) self.ids = split_dataframe[split].dropna().values.tolist() self.ids = list(map(int, self.ids)) @@ -180,17 +180,17 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: list_modalities = [] id = self.ids[index] - mask = self._load_raster(id, "mask") + mask = self._load_raster(id, 'mask') mask[mask != 0] = 1 for mode in self.modality: - mode = mode.upper() if mode in ["esa_wc", "viirs"] else mode + mode = mode.upper() if mode in ['esa_wc', 'viirs'] else mode data = self._load_raster(id, mode) list_modalities.append(data) image = torch.cat(list_modalities, dim=0) - sample: dict[str, Tensor] = {"image": image, "mask": mask} + sample: dict[str, Tensor] = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -215,9 +215,9 @@ def _load_raster(self, filename: int, source: str) -> Tensor: Returns: the raster image or target """ - with rasterio.open(os.path.join(self.root, source, f"{filename}.tif")) as f: + with rasterio.open(os.path.join(self.root, source, f'{filename}.tif')) as f: raw_array = f.read() - array: "np.typing.NDArray[np.int_]" = np.stack(raw_array, axis=0) + array: 'np.typing.NDArray[np.int_]' = np.stack(raw_array, axis=0) if array.dtype == np.uint16: array = array.astype(np.int32) tensor = torch.from_numpy(array).float() @@ -230,11 +230,11 @@ def _verify(self, url: str, md5: str | None = None) -> None: url: url to the file md5: md5 of the file to be verified """ - modality_folder_name = url.split("/")[-1] - mod_fold_no_ext = modality_folder_name.split(".")[0] + modality_folder_name = url.split('/')[-1] + mod_fold_no_ext = modality_folder_name.split('.')[0] modality_path = os.path.join(self.root, mod_fold_no_ext) split_path = os.path.join(self.root, modality_folder_name) - if mod_fold_no_ext == "split_IDs": + if mod_fold_no_ext == 'split_IDs': modality_path = split_path # Check 
if the files already exist @@ -242,10 +242,10 @@ def _verify(self, url: str, md5: str | None = None) -> None: return # Check if the zip files have already been downloaded, if so, extract - filepath = os.path.join(self.root, url.split("/")[-1]) - if os.path.isfile(filepath) and filepath.endswith(".zip"): + filepath = os.path.join(self.root, url.split('/')[-1]) + if os.path.isfile(filepath) and filepath.endswith('.zip'): if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') self._extract(url) return @@ -255,7 +255,7 @@ def _verify(self, url: str, md5: str | None = None) -> None: # Download the dataset self._download(url, md5) - if not url.endswith(".csv"): + if not url.endswith('.csv'): self._extract(url) def _download(self, url: str, md5: str | None) -> None: @@ -295,8 +295,8 @@ def _merge_parts(self, modality: str) -> None: # List of source folders source_folders = [ - os.path.join(self.root, modality + "_part1"), - os.path.join(self.root, modality + "_part2"), + os.path.join(self.root, modality + '_part1'), + os.path.join(self.root, modality + '_part2'), ] # Move files from each source folder to the new 'modality' folder @@ -309,7 +309,7 @@ def _merge_parts(self, modality: str) -> None: def _convert_to_color( self, arr_2d: Tensor, cmap: dict[int, tuple[int, int, int]] - ) -> "np.typing.NDArray[np.uint8]": + ) -> 'np.typing.NDArray[np.uint8]': """Numeric labels to RGB-color encoding. Args: @@ -342,22 +342,22 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - modality_channels = defaultdict(lambda: 10, {"viirs": 1, "esa_wc": 1, "s1": 2}) + modality_channels = defaultdict(lambda: 10, {'viirs': 1, 'esa_wc': 1, 's1': 2}) start_idx = 0 split_images = {} for modality in self.modality: end_idx = start_idx + modality_channels[modality] # Start + n of channels - split_images[modality] = sample["image"][start_idx:end_idx, :, :] # Slicing + split_images[modality] = sample['image'][start_idx:end_idx, :, :] # Slicing start_idx = end_idx # Update the iterator # Prepare the mask - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() color_mask = self._convert_to_color(mask, cmap=self.mask_cmap) num_subplots = len(split_images) + 1 # +1 for color_mask - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: num_subplots += 1 @@ -368,35 +368,35 @@ def plot( ax = axs[i] img = np.transpose(image, (1, 2, 0)).squeeze() # Apply transformations based on modality type - if modality.startswith("s2"): + if modality.startswith('s2'): img = img[:, :, [4, 3, 2]] - if modality == "esa_wc": + if modality == 'esa_wc': img = self._convert_to_color(torch.as_tensor(img), cmap=self.wc_cmap) - if modality == "s1": + if modality == 's1': img = img[:, :, 0] - if not "esa_wc": + if not 'esa_wc': img = percentile_normalization(img) ax.imshow(img) if show_titles: ax.set_title(modality) - ax.axis("off") + ax.axis('off') # Plot color_mask in its own axis axs[len(split_images)].imshow(color_mask) if show_titles: - axs[len(split_images)].set_title("Annotation") - axs[len(split_images)].axis("off") + axs[len(split_images)].set_title('Annotation') + axs[len(split_images)].axis('off') # If available, plot predictions in a new axis if showing_predictions: - prediction = sample["prediction"].squeeze() + prediction = sample['prediction'].squeeze() color_predictions = self._convert_to_color(prediction, cmap=self.mask_cmap) - 
axs[-1].imshow(color_predictions, vmin=0, vmax=1, interpolation="none") + axs[-1].imshow(color_predictions, vmin=0, vmax=1, interpolation='none') if show_titles: - axs[-1].set_title("Prediction") - axs[-1].axis("off") + axs[-1].set_title('Prediction') + axs[-1].axis('off') plt.tight_layout() return fig diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index a8ff9f2457c..11262e806ef 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -49,150 +49,150 @@ class MillionAID(NonGeoDataset): """ multi_label_categories = [ - "agriculture_land", - "airport_area", - "apartment", - "apron", - "arable_land", - "bare_land", - "baseball_field", - "basketball_court", - "beach", - "bridge", - "cemetery", - "church", - "commercial_area", - "commercial_land", - "dam", - "desert", - "detached_house", - "dry_field", - "factory_area", - "forest", - "golf_course", - "grassland", - "greenhouse", - "ground_track_field", - "helipad", - "highway_area", - "ice_land", - "industrial_land", - "intersection", - "island", - "lake", - "leisure_land", - "meadow", - "mine", - "mining_area", - "mobile_home_park", - "oil_field", - "orchard", - "paddy_field", - "parking_lot", - "pier", - "port_area", - "power_station", - "public_service_land", - "quarry", - "railway", - "railway_area", - "religious_land", - "residential_land", - "river", - "road", - "rock_land", - "roundabout", - "runway", - "solar_power_plant", - "sparse_shrub_land", - "special_land", - "sports_land", - "stadium", - "storage_tank", - "substation", - "swimming_pool", - "tennis_court", - "terraced_field", - "train_station", - "transportation_land", - "unutilized_land", - "viaduct", - "wastewater_plant", - "water_area", - "wind_turbine", - "woodland", - "works", + 'agriculture_land', + 'airport_area', + 'apartment', + 'apron', + 'arable_land', + 'bare_land', + 'baseball_field', + 'basketball_court', + 'beach', + 'bridge', + 'cemetery', + 'church', + 'commercial_area', + 'commercial_land', + 'dam', + 'desert', + 'detached_house', + 'dry_field', + 'factory_area', + 'forest', + 'golf_course', + 'grassland', + 'greenhouse', + 'ground_track_field', + 'helipad', + 'highway_area', + 'ice_land', + 'industrial_land', + 'intersection', + 'island', + 'lake', + 'leisure_land', + 'meadow', + 'mine', + 'mining_area', + 'mobile_home_park', + 'oil_field', + 'orchard', + 'paddy_field', + 'parking_lot', + 'pier', + 'port_area', + 'power_station', + 'public_service_land', + 'quarry', + 'railway', + 'railway_area', + 'religious_land', + 'residential_land', + 'river', + 'road', + 'rock_land', + 'roundabout', + 'runway', + 'solar_power_plant', + 'sparse_shrub_land', + 'special_land', + 'sports_land', + 'stadium', + 'storage_tank', + 'substation', + 'swimming_pool', + 'tennis_court', + 'terraced_field', + 'train_station', + 'transportation_land', + 'unutilized_land', + 'viaduct', + 'wastewater_plant', + 'water_area', + 'wind_turbine', + 'woodland', + 'works', ] multi_class_categories = [ - "apartment", - "apron", - "bare_land", - "baseball_field", - "bapsketball_court", - "beach", - "bridge", - "cemetery", - "church", - "commercial_area", - "dam", - "desert", - "detached_house", - "dry_field", - "forest", - "golf_course", - "greenhouse", - "ground_track_field", - "helipad", - "ice_land", - "intersection", - "island", - "lake", - "meadow", - "mine", - "mobile_home_park", - "oil_field", - "orchard", - "paddy_field", - "parking_lot", - "pier", - "quarry", - "railway", - "river", - "road", - "rock_land", - "roundabout", - 
"runway", - "solar_power_plant", - "sparse_shrub_land", - "stadium", - "storage_tank", - "substation", - "swimming_pool", - "tennis_court", - "terraced_field", - "train_station", - "viaduct", - "wastewater_plant", - "wind_turbine", - "works", + 'apartment', + 'apron', + 'bare_land', + 'baseball_field', + 'bapsketball_court', + 'beach', + 'bridge', + 'cemetery', + 'church', + 'commercial_area', + 'dam', + 'desert', + 'detached_house', + 'dry_field', + 'forest', + 'golf_course', + 'greenhouse', + 'ground_track_field', + 'helipad', + 'ice_land', + 'intersection', + 'island', + 'lake', + 'meadow', + 'mine', + 'mobile_home_park', + 'oil_field', + 'orchard', + 'paddy_field', + 'parking_lot', + 'pier', + 'quarry', + 'railway', + 'river', + 'road', + 'rock_land', + 'roundabout', + 'runway', + 'solar_power_plant', + 'sparse_shrub_land', + 'stadium', + 'storage_tank', + 'substation', + 'swimming_pool', + 'tennis_court', + 'terraced_field', + 'train_station', + 'viaduct', + 'wastewater_plant', + 'wind_turbine', + 'works', ] md5s = { - "train": "1b40503cafa9b0601653ca36cd788852", - "test": "51a63ee3eeb1351889eacff349a983d8", + 'train': '1b40503cafa9b0601653ca36cd788852', + 'test': '51a63ee3eeb1351889eacff349a983d8', } - filenames = {"train": "train.zip", "test": "test.zip"} + filenames = {'train': 'train.zip', 'test': 'test.zip'} - tasks = ["multi-class", "multi-label"] - splits = ["train", "test"] + tasks = ['multi-class', 'multi-label'] + splits = ['train', 'test'] def __init__( self, - root: str = "data", - task: str = "multi-class", - split: str = "train", + root: str = 'data', + task: str = 'multi-class', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -221,7 +221,7 @@ def __init__( self.files = self._load_files(self.root) - self.classes = sorted({cls for f in self.files for cls in f["label"]}) + self.classes = sorted({cls for f in self.files for cls in f['label']}) self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} def __len__(self) -> int: @@ -242,10 +242,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image = self._load_image(files["image"]) - cls_label = [self.class_to_idx[label] for label in files["label"]] + image = self._load_image(files['image']) + cls_label = [self.class_to_idx[label] for label in files['label']] label = torch.tensor(cls_label, dtype=torch.long) - sample = {"image": image, "label": label} + sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -262,18 +262,18 @@ def _load_files(self, root: str) -> list[dict[str, Any]]: list of dicts containing paths for each pair of image, and list of labels """ imgs_no_subcat = list( - glob.glob(os.path.join(root, self.split, "*", "*", "*.jpg")) + glob.glob(os.path.join(root, self.split, '*', '*', '*.jpg')) ) imgs_subcat = list( - glob.glob(os.path.join(root, self.split, "*", "*", "*", "*.jpg")) + glob.glob(os.path.join(root, self.split, '*', '*', '*', '*.jpg')) ) scenes = [p.split(os.sep)[-3] for p in imgs_no_subcat] + [ p.split(os.sep)[-4] for p in imgs_subcat ] - subcategories = ["Missing" for p in imgs_no_subcat] + [ + subcategories = ['Missing' for p in imgs_no_subcat] + [ p.split(os.sep)[-3] for p in imgs_subcat ] @@ -281,9 +281,9 @@ def _load_files(self, root: str) -> list[dict[str, Any]]: p.split(os.sep)[-2] for p in imgs_subcat ] - if self.task == "multi-label": + if self.task == 
'multi-label': labels = [ - [sc, sub, c] if sub != "Missing" else [sc, c] + [sc, sub, c] if sub != 'Missing' else [sc, c] for sc, sub, c in zip(scenes, subcategories, classes) ] else: @@ -305,7 +305,7 @@ def _load_image(self, path: str) -> Tensor: the image """ with Image.open(path) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor: Tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -321,10 +321,10 @@ def _verify(self) -> None: if os.path.isdir(filepath): return - filepath = os.path.join(self.root, self.split + ".zip") + filepath = os.path.join(self.root, self.split + '.zip') if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, self.md5s[self.split]): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(filepath) return @@ -347,22 +347,22 @@ def plot( a matplotlib Figure with the rendered sample """ - image = np.rollaxis(sample["image"].numpy(), 0, 3) - labels = [self.classes[cast(int, label)] for label in sample["label"]] + image = np.rollaxis(sample['image'].numpy(), 0, 3) + labels = [self.classes[cast(int, label)] for label in sample['label']] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: prediction_labels = [ - self.classes[cast(int, label)] for label in sample["prediction"] + self.classes[cast(int, label)] for label in sample['prediction'] ] fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {labels}" + title = f'Label: {labels}' if showing_predictions: - title += f"\nPrediction: {prediction_labels}" + title += f'\nPrediction: {prediction_labels}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index 19f12a29f99..d8185782367 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -32,7 +32,7 @@ class NAIP(RasterDataset): # https://www.nrcs.usda.gov/Internet/FSE_DOCUMENTS/nrcs141p2_015644.pdf # https://planetarycomputer.microsoft.com/dataset/naip#Storage-Documentation - filename_glob = "m_*.*" + filename_glob = 'm_*.*' filename_regex = r""" ^m _(?P\d+) @@ -45,8 +45,8 @@ class NAIP(RasterDataset): """ # Plotting - all_bands = ["R", "G", "B", "NIR"] - rgb_bands = ["R", "G", "B"] + all_bands = ['R', 'G', 'B', 'NIR'] + rgb_bands = ['R', 'G', 'B'] def plot( self, @@ -68,14 +68,14 @@ def plot( Method now takes a sample dict, not a Tensor. Additionally, possible to show subplot titles and/or use a custom suptitle. """ - image = sample["image"][0:3, :, :].permute(1, 2, 0) + image = sample['image'][0:3, :, :].permute(1, 2, 0) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - ax.set_title("Image") + ax.set_title('Image') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 4ebd0f6e39e..b31d581d077 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -57,15 +57,15 @@ class NASAMarineDebris(NonGeoDataset): .. 
versionadded:: 0.2 """ - collection_ids = ["nasa_marine_debris_source", "nasa_marine_debris_labels"] - directories = ["nasa_marine_debris_source", "nasa_marine_debris_labels"] - filenames = ["nasa_marine_debris_source.tar.gz", "nasa_marine_debris_labels.tar.gz"] - md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"] - class_label = "marine_debris" + collection_ids = ['nasa_marine_debris_source', 'nasa_marine_debris_labels'] + directories = ['nasa_marine_debris_source', 'nasa_marine_debris_labels'] + filenames = ['nasa_marine_debris_source.tar.gz', 'nasa_marine_debris_labels.tar.gz'] + md5s = ['fe8698d1e68b3f24f0b86b04419a797d', 'd8084f5a72778349e07ac90ec1e1d990'] + class_label = 'marine_debris' def __init__( self, - root: str = "data", + root: str = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, api_key: str | None = None, @@ -104,15 +104,15 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and labels at that index """ - image = self._load_image(self.files[index]["image"]) - boxes = self._load_target(self.files[index]["target"]) - sample = {"image": image, "boxes": boxes} + image = self._load_image(self.files[index]['image']) + boxes = self._load_target(self.files[index]['target']) + sample = {'image': image, 'boxes': boxes} # Filter invalid boxes - w_check = (sample["boxes"][:, 2] - sample["boxes"][:, 0]) > 0 - h_check = (sample["boxes"][:, 3] - sample["boxes"][:, 1]) > 0 + w_check = (sample['boxes'][:, 2] - sample['boxes'][:, 0]) > 0 + h_check = (sample['boxes'][:, 3] - sample['boxes'][:, 1]) > 0 indices = w_check & h_check - sample["boxes"] = sample["boxes"][indices] + sample['boxes'] = sample['boxes'][indices] if self.transforms is not None: sample = self.transforms(sample) @@ -165,18 +165,18 @@ def _load_files(self) -> list[dict[str, str]]: image_root = os.path.join(self.root, self.directories[0]) target_root = os.path.join(self.root, self.directories[1]) image_folders = sorted( - f for f in os.listdir(image_root) if not f.endswith("json") + f for f in os.listdir(image_root) if not f.endswith('json') ) files = [] for folder in image_folders: files.append( { - "image": os.path.join(image_root, folder, "image_geotiff.tif"), - "target": os.path.join( + 'image': os.path.join(image_root, folder, 'image_geotiff.tif'), + 'target': os.path.join( target_root, - folder.replace("source", "labels"), - "pixel_bounds.npy", + folder.replace('source', 'labels'), + 'pixel_bounds.npy', ), } ) @@ -198,7 +198,7 @@ def _verify(self) -> None: filepath = os.path.join(self.root, filename) if os.path.exists(filepath): if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError("Dataset checksum mismatch.") + raise RuntimeError('Dataset checksum mismatch.') exists.append(True) extract_archive(filepath) else: @@ -217,7 +217,7 @@ def _verify(self) -> None: for filename, md5 in zip(self.filenames, self.md5s): filepath = os.path.join(self.root, filename) if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError("Dataset checksum mismatch.") + raise RuntimeError('Dataset checksum mismatch.') extract_archive(filepath) def plot( @@ -238,34 +238,34 @@ def plot( """ ncols = 1 - sample["image"] = sample["image"].byte() - image = sample["image"] - if "boxes" in sample and len(sample["boxes"]): - image = draw_bounding_boxes(image=sample["image"], boxes=sample["boxes"]) + sample['image'] = sample['image'].byte() + image = sample['image'] + if 'boxes' in sample and 
len(sample['boxes']): + image = draw_bounding_boxes(image=sample['image'], boxes=sample['boxes']) image = image.permute((1, 2, 0)).numpy() - if "prediction_boxes" in sample and len(sample["prediction_boxes"]): + if 'prediction_boxes' in sample and len(sample['prediction_boxes']): ncols += 1 preds = draw_bounding_boxes( - image=sample["image"], boxes=sample["prediction_boxes"] + image=sample['image'], boxes=sample['prediction_boxes'] ) preds = preds.permute((1, 2, 0)).numpy() fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) if ncols < 2: axs.imshow(image) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Ground Truth") + axs.set_title('Ground Truth') else: axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(preds) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Ground Truth") - axs[1].set_title("Predictions") + axs[0].set_title('Ground Truth') + axs[1].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py index 48d33137dab..7f015eae33c 100644 --- a/torchgeo/datasets/nccm.py +++ b/torchgeo/datasets/nccm.py @@ -51,25 +51,25 @@ class NCCM(RasterDataset): .. versionadded:: 0.6 """ - filename_regex = r"CDL(?P\d{4})_clip" - filename_glob = "CDL*.*" + filename_regex = r'CDL(?P\d{4})_clip' + filename_glob = 'CDL*.*' - date_format = "%Y" + date_format = '%Y' is_image = False urls = { - 2019: "https://figshare.com/ndownloader/files/25070540", - 2018: "https://figshare.com/ndownloader/files/25070624", - 2017: "https://figshare.com/ndownloader/files/25070582", + 2019: 'https://figshare.com/ndownloader/files/25070540', + 2018: 'https://figshare.com/ndownloader/files/25070624', + 2017: 'https://figshare.com/ndownloader/files/25070582', } md5s = { - 2019: "0d062bbd42e483fdc8239d22dba7020f", - 2018: "b3bb4894478d10786aa798fb11693ec1", - 2017: "d047fbe4a85341fa6248fd7e0badab6c", + 2019: '0d062bbd42e483fdc8239d22dba7020f', + 2018: 'b3bb4894478d10786aa798fb11693ec1', + 2017: 'd047fbe4a85341fa6248fd7e0badab6c', } fnames = { - 2019: "CDL2019_clip.tif", - 2018: "CDL2018_clip1.tif", - 2017: "CDL2017_clip.tif", + 2019: 'CDL2019_clip.tif', + 2018: 'CDL2018_clip1.tif', + 2017: 'CDL2017_clip.tif', } cmap = { @@ -82,7 +82,7 @@ class NCCM(RasterDataset): def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2019], @@ -110,8 +110,8 @@ def __init__( DatasetNotFoundError: If dataset is not found and *download* is False. """ assert set(years) <= self.md5s.keys(), ( - "NCCM data product only exists for the following years: " - f"{list(self.md5s.keys())}." + 'NCCM data product only exists for the following years: ' + f'{list(self.md5s.keys())}.' 
) self.paths = paths self.years = years @@ -140,7 +140,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ sample = super().__getitem__(query) - sample["mask"] = self.ordinal_map[sample["mask"]] + sample['mask'] = self.ordinal_map[sample['mask']] return sample def _verify(self) -> None: @@ -182,29 +182,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots( nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False ) - axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none") - axs[0, 0].axis("off") + axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation='none') + axs[0, 0].axis('off') if show_titles: - axs[0, 0].set_title("Mask") + axs[0, 0].set_title('Mask') if showing_predictions: - axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none") - axs[0, 1].axis("off") + axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation='none') + axs[0, 1].axis('off') if show_titles: - axs[0, 1].set_title("Prediction") + axs[0, 1].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index 7cb03b55d0f..e1e106d0308 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -67,22 +67,22 @@ class NLCD(RasterDataset): .. versionadded:: 0.5 """ # noqa: E501 - filename_glob = "nlcd_*_land_cover_l48_*.img" + filename_glob = 'nlcd_*_land_cover_l48_*.img' filename_regex = ( - r"nlcd_(?P\d{4})_land_cover_l48_(?P\d{8})\.img" + r'nlcd_(?P\d{4})_land_cover_l48_(?P\d{8})\.img' ) - zipfile_glob = "nlcd_*_land_cover_l48_*.zip" - date_format = "%Y" + zipfile_glob = 'nlcd_*_land_cover_l48_*.zip' + date_format = '%Y' is_image = False - url = "https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip" + url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip' md5s = { - 2001: "538166a4d783204764e3df3b221fc4cd", - 2006: "67454e7874a00294adb9442374d0c309", - 2011: "ea524c835d173658eeb6fa3c8e6b917b", - 2016: "452726f6e3bd3f70d8ca2476723d238a", - 2019: "82851c3f8105763b01c83b4a9e6f3961", + 2001: '538166a4d783204764e3df3b221fc4cd', + 2006: '67454e7874a00294adb9442374d0c309', + 2011: 'ea524c835d173658eeb6fa3c8e6b917b', + 2016: '452726f6e3bd3f70d8ca2476723d238a', + 2019: '82851c3f8105763b01c83b4a9e6f3961', } cmap = { @@ -107,7 +107,7 @@ class NLCD(RasterDataset): def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2019], @@ -139,13 +139,13 @@ def __init__( DatasetNotFoundError: If dataset is not found and *download* is False. """ assert set(years) <= self.md5s.keys(), ( - "NLCD data product only exists for the following years: " - f"{list(self.md5s.keys())}." + 'NLCD data product only exists for the following years: ' + f'{list(self.md5s.keys())}.' ) assert ( set(classes) <= self.cmap.keys() - ), f"Only the following classes are valid: {list(self.cmap.keys())}." - assert 0 in classes, "Classes must include the background class: 0" + ), f'Only the following classes are valid: {list(self.cmap.keys())}.' 
+ assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths self.years = years @@ -177,7 +177,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: IndexError: if query is not found in the index """ sample = super().__getitem__(query) - sample["mask"] = self.ordinal_map[sample["mask"]] + sample['mask'] = self.ordinal_map[sample['mask']] return sample def _verify(self) -> None: @@ -189,9 +189,9 @@ def _verify(self) -> None: # Check if the zip files have already been downloaded exists = [] for year in self.years: - zipfile_year = self.zipfile_glob.replace("*", str(year), 1) + zipfile_year = self.zipfile_glob.replace('*', str(year), 1) assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "**", zipfile_year) + pathname = os.path.join(self.paths, '**', zipfile_year) if glob.glob(pathname, recursive=True): exists.append(True) self._extract() @@ -221,9 +221,9 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" for year in self.years: - zipfile_name = self.zipfile_glob.replace("*", str(year), 1) + zipfile_name = self.zipfile_glob.replace('*', str(year), 1) assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, "**", zipfile_name) + pathname = os.path.join(self.paths, '**', zipfile_name) extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) def plot( @@ -242,29 +242,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots( nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False ) - axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none") - axs[0, 0].axis("off") + axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation='none') + axs[0, 0].axis('off') if show_titles: - axs[0, 0].set_title("Mask") + axs[0, 0].set_title('Mask') if showing_predictions: - axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none") - axs[0, 1].axis("off") + axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation='none') + axs[0, 1].axis('off') if show_titles: - axs[0, 1].set_title("Prediction") + axs[0, 1].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index e2529d81e9a..b414fca6ba4 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -60,153 +60,153 @@ class OpenBuildings(VectorDataset): """ md5s = { - "025_buildings.csv.gz": "41db2572bfd08628d01475a2ee1a2f17", - "04f_buildings.csv.gz": "3232c1c6d45c1543260b77e5689fc8b1", - "05b_buildings.csv.gz": "4fc57c63bbbf9a21a3902da7adc3a670", - "093_buildings.csv.gz": "00fce146dadf0b30255e750c4c5ac2de", - "095_buildings.csv.gz": "f5765b0936f7ccbd0b4abed60d994f08", - "0c3_buildings.csv.gz": "013b130fe872387e0cff842399b423de", - "0c3_buildings.csv": "a697ad2433e9a9f6001de25b4664651a", - "0c5_buildings.csv.gz": "16ca283e9344e9da8b47acaf03c1c6e4", - "0c7_buildings.csv.gz": "b3774930006497a80c8a2fbf33056610", - "0d1_buildings.csv.gz": "41e652218ca5964d297d9cd1d84b831c", - "0d7_buildings.csv.gz": "d365fe47d10b0756dd54ceca24598d8e", - "0d9_buildings.csv.gz": "3ebd47fa4f86857266e9a7346d6aa163", - "0db_buildings.csv.gz": "368213e9caa7ee229ef9403b0ca8c80d", - "0dd_buildings.csv.gz": 
"8f5fcefff262fdfd82800092d2e9d841", - "0df_buildings.csv.gz": "cbb5f63b10daa25568bdde8d9f66f8a4", - "0e1_buildings.csv.gz": "a9b9bf1e541b62c8a34d2f6f2ae71e1c", - "0e3_buildings.csv.gz": "3d9c2ffc11c02aec2bd008699f9c4bd1", - "0e5_buildings.csv.gz": "1e1b2bf63dfc520e62e4b68db23fe64c", - "0e7_buildings.csv.gz": "c96797588c90e66268367cb56b4b9af8", - "0e9_buildings.csv.gz": "c53bb7bbc8140034d1be2c49ff49af68", - "0eb_buildings.csv.gz": "407c771f614a15d69d78f1e25decf694", - "0ed_buildings.csv.gz": "bddd10992d291677019d7106ce1f4fac", - "0ef_buildings.csv.gz": "d1b91936e7ac06c661878ef9eb5dba7b", - "0f1_buildings.csv.gz": "9d86eb10d2d8766e1385b6c52c11d5e2", - "0f9_buildings.csv.gz": "1c6775131214b26f4a27b4c42d6e9fca", - "0fb_buildings.csv.gz": "d39528cb4e0cbff589ca89dc86d9b5db", - "0fd_buildings.csv.gz": "304fe4a60e950c900697d975098f7536", - "0ff_buildings.csv.gz": "266ca7ed1ad0251b3999b0e2e9b54648", - "103_buildings.csv.gz": "8d3cafab5f1e02b2a0a6180eb34d1cac", - "105_buildings.csv.gz": "dd61cc74239aa9a1b30f10859122807b", - "107_buildings.csv.gz": "823c05984f859a1bf17af8ce78bf2892", - "109_buildings.csv.gz": "cfdee0e807168cd1c183d9c01535369b", - "10b_buildings.csv.gz": "d8ecaf406abd864b641ba34985f3042e", - "10d_buildings.csv.gz": "af584a542a17942ff7e94653322dba87", - "10f_buildings.csv.gz": "3d5369e15c4d1f59fb38cf61f4e6290b", - "111_buildings.csv.gz": "47504e43d1b67101bed5d924225328dc", - "113_buildings.csv.gz": "3f991c831569f91f34eaa8fc3882b2fd", - "117_buildings.csv.gz": "a4145fa6e458480e30c807f80ae5cd65", - "119_buildings.csv.gz": "5661b7ac23f266542c7e0d962a8cae58", - "11b_buildings.csv.gz": "41b6d036610d0bddac069ec72e68710e", - "11d_buildings.csv.gz": "1ef75e9d176dd8d6bfa6012d36b1d25c", - "11f_buildings.csv.gz": "f004873d1ef3933c1716ab6409565b7d", - "121_buildings.csv.gz": "0c7e7a9043ed069fbdefdcfcfc437482", - "123_buildings.csv.gz": "c46bd53b67025c3de11657220cce0aec", - "125_buildings.csv.gz": "33253ae1a82656f4eedca9bd86f981a3", - "127_buildings.csv.gz": "2f827f8fc93485572178e9ad0c65e22d", - "129_buildings.csv.gz": "74f98346990a1d1e41241ce8f4bb201a", - "12f_buildings.csv.gz": "b1b0777296df2bfef512df0945ca3e14", - "131_buildings.csv.gz": "8362825b10c9396ecbb85c49cd210bc6", - "137_buildings.csv.gz": "96da7389df820405b0010db4a6c98c61", - "139_buildings.csv.gz": "c41e26fc6f3565c3d7c66ab977dc8159", - "13b_buildings.csv.gz": "981d4ccb0f41a103bdad8ef949eb4ffe", - "13d_buildings.csv.gz": "d15585d06ee74b0095842dd887197035", - "141_buildings.csv.gz": "ae0bf17778d45119c74e50e06a04020d", - "143_buildings.csv.gz": "9699809e57eb097dfaf9d484f1d9c5fa", - "145_buildings.csv.gz": "81e74e0165ea358278ce18507dddfdb0", - "147_buildings.csv.gz": "39edad15fa16c432f5d460f0a8166032", - "149_buildings.csv.gz": "94bf8f8fa221744fb1d57c7d4065e69e", - "14f_buildings.csv.gz": "ca8410be89b5cf868c2a67861712e4ea", - "15b_buildings.csv.gz": "8c0071c0ae20a60e8dd4d7aa6aac5a99", - "15d_buildings.csv.gz": "35f044a323556adda5f31e8fc9307c85", - "161_buildings.csv.gz": "ba08b70a26f07b5e2cd4eafd9d6f826b", - "163_buildings.csv.gz": "2bec83a2504b531cd1cb0311fcb6c952", - "165_buildings.csv.gz": "48f934733dd3054164f9b09abee63312", - "167_buildings.csv.gz": "bba8657024d80d44e475759b65adc969", - "169_buildings.csv.gz": "13e142e48597ee7a8b0b812e226dfa72", - "16b_buildings.csv.gz": "9c62351d6cc8eaf761ab89d4586d26d6", - "16d_buildings.csv.gz": "a33c23da3f603c8c3eacc5e6a47aaf66", - "16f_buildings.csv.gz": "4850dd7c8f0fb628ba5864ea9f47647b", - "171_buildings.csv.gz": "4217f1b025db869c8bed1014704c2a79", - "173_buildings.csv.gz": 
"5a5f3f07e261a9dc58c6180b69130e4a", - "175_buildings.csv.gz": "5bbf7a7c8f57d28e024ddf8f4039b575", - "177_buildings.csv.gz": "76cd4b17d68d62e1f088f229b65f8acf", - "179_buildings.csv.gz": "a5a1c6609483336ddff91b2385e70eb9", - "17b_buildings.csv.gz": "a47c1145a3b0bcdaba18c153b7b92b87", - "17d_buildings.csv.gz": "3226d0abf396f44c1a436be83538dfd8", - "17f_buildings.csv.gz": "3e18d4fc5837ee89274d30f2126b92b2", - "181_buildings.csv.gz": "c87639d7f6d6a85a3fa6b06910b0e145", - "183_buildings.csv.gz": "e94438ebf19b3b25035954d23a0e90cf", - "185_buildings.csv.gz": "8de8d1d50c16c575f85b96dee474cb56", - "189_buildings.csv.gz": "da94cd495a99496fd687bbb4a1715c90", - "18b_buildings.csv.gz": "9ab353335fe6ff694e834889be2b305d", - "18d_buildings.csv.gz": "e37e0f868ce96f7d14f7bf1a301da1d3", - "18f_buildings.csv.gz": "e9000b9ef9bb0f838088e96becfc95a1", - "191_buildings.csv.gz": "c00bb4d6b2b12615d576c06fe545cbfa", - "193_buildings.csv.gz": "d48d4c03ef053f6987b3e6e9e78a8b03", - "195_buildings.csv.gz": "d93ab833e74480f07a5ccf227067db5a", - "197_buildings.csv.gz": "8667e040f9863e43924aafe6071fabc7", - "199_buildings.csv.gz": "04ba65a4caf16cc1e0d5c4e1322c5885", - "19b_buildings.csv.gz": "e49412e3e1bccceb0bdb4df5201288f4", - "19d_buildings.csv.gz": "92b5fb4e96529d90e99c788e3e8696d4", - "19f_buildings.csv.gz": "c023f6c37d0026b56f530b841517a6cd", - "1a1_buildings.csv.gz": "471483b50c722af104af8a582e780c04", - "1a3_buildings.csv.gz": "0a453053f1ff53f9e165e16c7f97354a", - "1a5_buildings.csv.gz": "1f6a823e223d5f29c66aa728933de684", - "1a7_buildings.csv.gz": "6130b724501fa16e6d84e484c4091f1f", - "1a9_buildings.csv.gz": "73022e8e7b994e76a58cc763a057d542", - "1b9_buildings.csv.gz": "48dea4af9d12b755e75b76c68c47de6b", - "1bb_buildings.csv.gz": "dfb9ee4d3843d81722b70f7582c775a4", - "1bd_buildings.csv.gz": "fdea2898fc50ae25b6196048373d8244", - "1bf_buildings.csv.gz": "96ef27d6128d0bcdfa896fed6f27cdd0", - "1c1_buildings.csv.gz": "32e3667d939e7f95316eb75a6ffdb603", - "1c3_buildings.csv.gz": "ed94b543da1bbe3101ed66f7d7727d24", - "1c5_buildings.csv.gz": "ce527ab33e564f0cc1b63ae467932a18", - "1c7_buildings.csv.gz": "d5fb474466d6a11d3b08e3a011984ada", - "1dd_buildings.csv.gz": "9e7e50e3f95b3f2ceff6351b75ca1e75", - "1e5_buildings.csv.gz": "f95ea85fce47ce7edf5729086d43f922", - "1e7_buildings.csv.gz": "2bca5682c48134e69b738d70dfe7d516", - "1e9_buildings.csv.gz": "f049ad06dbbb200f524b4f50d1df8c2e", - "1eb_buildings.csv.gz": "6822d7f202b453ec3cc03fb8f04691ad", - "1ed_buildings.csv.gz": "9dfc560e2c3d135ebdcd46fa09c47169", - "1ef_buildings.csv.gz": "506e7772c35b09cfd3b6f8691dc2947d", - "1f1_buildings.csv.gz": "b74f2b585cfad3b881fe4f124080440a", - "1f3_buildings.csv.gz": "12896642315320e11ed9ed2d3f0e5995", - "1f5_buildings.csv.gz": "334aea21e532e178bf5c54d028158906", - "1f7_buildings.csv.gz": "0e8c3d2e005eb04c6852a8aa993f5a76", - "217_buildings.csv.gz": "296e9ba121fea752b865a48e5c0fe8a5", - "219_buildings.csv.gz": "1d19b6626d738f7706f75c2935aaaff4", - "21d_buildings.csv.gz": "28bfca1f8668f59db021d3a195994768", - "21f_buildings.csv.gz": "06325c8b0a8f6ed598b7dc6f0bb5adf2", - "221_buildings.csv.gz": "a354ffc1f7226d525c7cf53848975da1", - "223_buildings.csv.gz": "3bda1339d561b3bc749220877f1384d9", - "225_buildings.csv.gz": "8eb02ad77919d9e551138a14d3ad1bbc", - "227_buildings.csv.gz": "c07aceb7c81f83a653810befa0695b61", - "22f_buildings.csv.gz": "97d63e30e008ec4424f6b0641b75377c", - "231_buildings.csv.gz": "f4bc384ed74552ddcfe2e69107b91345", - "233_buildings.csv.gz": "081756e7bdcfdc2aee9114c4cfe62bd8", - "23b_buildings.csv.gz": 
"75776d3dcbc90cf3a596664747880134", - "23d_buildings.csv.gz": "e5d0b9b7b14601f58cfdb9ea170e9520", - "23f_buildings.csv.gz": "77f38466419b4d391be8e4f05207fdf5", - "3d1_buildings.csv.gz": "6659c97bd765250b0dee4b1b7ff583a9", - "3d5_buildings.csv.gz": "c27d8f6b2808549606f00bc04d8b42bc", - "3d7_buildings.csv.gz": "abdef2e68cc31c67dbb6e60c4c40483e", - "3d9_buildings.csv.gz": "4c06ae37d8e76626345a52a32f989de9", - "3db_buildings.csv.gz": "e83ca0115eaf4ec0a72aaf932b00442a", - "b5b_buildings.csv.gz": "5e5f59cb17b81137d89c4bab8107e837", + '025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17', + '04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1', + '05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670', + '093_buildings.csv.gz': '00fce146dadf0b30255e750c4c5ac2de', + '095_buildings.csv.gz': 'f5765b0936f7ccbd0b4abed60d994f08', + '0c3_buildings.csv.gz': '013b130fe872387e0cff842399b423de', + '0c3_buildings.csv': 'a697ad2433e9a9f6001de25b4664651a', + '0c5_buildings.csv.gz': '16ca283e9344e9da8b47acaf03c1c6e4', + '0c7_buildings.csv.gz': 'b3774930006497a80c8a2fbf33056610', + '0d1_buildings.csv.gz': '41e652218ca5964d297d9cd1d84b831c', + '0d7_buildings.csv.gz': 'd365fe47d10b0756dd54ceca24598d8e', + '0d9_buildings.csv.gz': '3ebd47fa4f86857266e9a7346d6aa163', + '0db_buildings.csv.gz': '368213e9caa7ee229ef9403b0ca8c80d', + '0dd_buildings.csv.gz': '8f5fcefff262fdfd82800092d2e9d841', + '0df_buildings.csv.gz': 'cbb5f63b10daa25568bdde8d9f66f8a4', + '0e1_buildings.csv.gz': 'a9b9bf1e541b62c8a34d2f6f2ae71e1c', + '0e3_buildings.csv.gz': '3d9c2ffc11c02aec2bd008699f9c4bd1', + '0e5_buildings.csv.gz': '1e1b2bf63dfc520e62e4b68db23fe64c', + '0e7_buildings.csv.gz': 'c96797588c90e66268367cb56b4b9af8', + '0e9_buildings.csv.gz': 'c53bb7bbc8140034d1be2c49ff49af68', + '0eb_buildings.csv.gz': '407c771f614a15d69d78f1e25decf694', + '0ed_buildings.csv.gz': 'bddd10992d291677019d7106ce1f4fac', + '0ef_buildings.csv.gz': 'd1b91936e7ac06c661878ef9eb5dba7b', + '0f1_buildings.csv.gz': '9d86eb10d2d8766e1385b6c52c11d5e2', + '0f9_buildings.csv.gz': '1c6775131214b26f4a27b4c42d6e9fca', + '0fb_buildings.csv.gz': 'd39528cb4e0cbff589ca89dc86d9b5db', + '0fd_buildings.csv.gz': '304fe4a60e950c900697d975098f7536', + '0ff_buildings.csv.gz': '266ca7ed1ad0251b3999b0e2e9b54648', + '103_buildings.csv.gz': '8d3cafab5f1e02b2a0a6180eb34d1cac', + '105_buildings.csv.gz': 'dd61cc74239aa9a1b30f10859122807b', + '107_buildings.csv.gz': '823c05984f859a1bf17af8ce78bf2892', + '109_buildings.csv.gz': 'cfdee0e807168cd1c183d9c01535369b', + '10b_buildings.csv.gz': 'd8ecaf406abd864b641ba34985f3042e', + '10d_buildings.csv.gz': 'af584a542a17942ff7e94653322dba87', + '10f_buildings.csv.gz': '3d5369e15c4d1f59fb38cf61f4e6290b', + '111_buildings.csv.gz': '47504e43d1b67101bed5d924225328dc', + '113_buildings.csv.gz': '3f991c831569f91f34eaa8fc3882b2fd', + '117_buildings.csv.gz': 'a4145fa6e458480e30c807f80ae5cd65', + '119_buildings.csv.gz': '5661b7ac23f266542c7e0d962a8cae58', + '11b_buildings.csv.gz': '41b6d036610d0bddac069ec72e68710e', + '11d_buildings.csv.gz': '1ef75e9d176dd8d6bfa6012d36b1d25c', + '11f_buildings.csv.gz': 'f004873d1ef3933c1716ab6409565b7d', + '121_buildings.csv.gz': '0c7e7a9043ed069fbdefdcfcfc437482', + '123_buildings.csv.gz': 'c46bd53b67025c3de11657220cce0aec', + '125_buildings.csv.gz': '33253ae1a82656f4eedca9bd86f981a3', + '127_buildings.csv.gz': '2f827f8fc93485572178e9ad0c65e22d', + '129_buildings.csv.gz': '74f98346990a1d1e41241ce8f4bb201a', + '12f_buildings.csv.gz': 'b1b0777296df2bfef512df0945ca3e14', + '131_buildings.csv.gz': 
'8362825b10c9396ecbb85c49cd210bc6', + '137_buildings.csv.gz': '96da7389df820405b0010db4a6c98c61', + '139_buildings.csv.gz': 'c41e26fc6f3565c3d7c66ab977dc8159', + '13b_buildings.csv.gz': '981d4ccb0f41a103bdad8ef949eb4ffe', + '13d_buildings.csv.gz': 'd15585d06ee74b0095842dd887197035', + '141_buildings.csv.gz': 'ae0bf17778d45119c74e50e06a04020d', + '143_buildings.csv.gz': '9699809e57eb097dfaf9d484f1d9c5fa', + '145_buildings.csv.gz': '81e74e0165ea358278ce18507dddfdb0', + '147_buildings.csv.gz': '39edad15fa16c432f5d460f0a8166032', + '149_buildings.csv.gz': '94bf8f8fa221744fb1d57c7d4065e69e', + '14f_buildings.csv.gz': 'ca8410be89b5cf868c2a67861712e4ea', + '15b_buildings.csv.gz': '8c0071c0ae20a60e8dd4d7aa6aac5a99', + '15d_buildings.csv.gz': '35f044a323556adda5f31e8fc9307c85', + '161_buildings.csv.gz': 'ba08b70a26f07b5e2cd4eafd9d6f826b', + '163_buildings.csv.gz': '2bec83a2504b531cd1cb0311fcb6c952', + '165_buildings.csv.gz': '48f934733dd3054164f9b09abee63312', + '167_buildings.csv.gz': 'bba8657024d80d44e475759b65adc969', + '169_buildings.csv.gz': '13e142e48597ee7a8b0b812e226dfa72', + '16b_buildings.csv.gz': '9c62351d6cc8eaf761ab89d4586d26d6', + '16d_buildings.csv.gz': 'a33c23da3f603c8c3eacc5e6a47aaf66', + '16f_buildings.csv.gz': '4850dd7c8f0fb628ba5864ea9f47647b', + '171_buildings.csv.gz': '4217f1b025db869c8bed1014704c2a79', + '173_buildings.csv.gz': '5a5f3f07e261a9dc58c6180b69130e4a', + '175_buildings.csv.gz': '5bbf7a7c8f57d28e024ddf8f4039b575', + '177_buildings.csv.gz': '76cd4b17d68d62e1f088f229b65f8acf', + '179_buildings.csv.gz': 'a5a1c6609483336ddff91b2385e70eb9', + '17b_buildings.csv.gz': 'a47c1145a3b0bcdaba18c153b7b92b87', + '17d_buildings.csv.gz': '3226d0abf396f44c1a436be83538dfd8', + '17f_buildings.csv.gz': '3e18d4fc5837ee89274d30f2126b92b2', + '181_buildings.csv.gz': 'c87639d7f6d6a85a3fa6b06910b0e145', + '183_buildings.csv.gz': 'e94438ebf19b3b25035954d23a0e90cf', + '185_buildings.csv.gz': '8de8d1d50c16c575f85b96dee474cb56', + '189_buildings.csv.gz': 'da94cd495a99496fd687bbb4a1715c90', + '18b_buildings.csv.gz': '9ab353335fe6ff694e834889be2b305d', + '18d_buildings.csv.gz': 'e37e0f868ce96f7d14f7bf1a301da1d3', + '18f_buildings.csv.gz': 'e9000b9ef9bb0f838088e96becfc95a1', + '191_buildings.csv.gz': 'c00bb4d6b2b12615d576c06fe545cbfa', + '193_buildings.csv.gz': 'd48d4c03ef053f6987b3e6e9e78a8b03', + '195_buildings.csv.gz': 'd93ab833e74480f07a5ccf227067db5a', + '197_buildings.csv.gz': '8667e040f9863e43924aafe6071fabc7', + '199_buildings.csv.gz': '04ba65a4caf16cc1e0d5c4e1322c5885', + '19b_buildings.csv.gz': 'e49412e3e1bccceb0bdb4df5201288f4', + '19d_buildings.csv.gz': '92b5fb4e96529d90e99c788e3e8696d4', + '19f_buildings.csv.gz': 'c023f6c37d0026b56f530b841517a6cd', + '1a1_buildings.csv.gz': '471483b50c722af104af8a582e780c04', + '1a3_buildings.csv.gz': '0a453053f1ff53f9e165e16c7f97354a', + '1a5_buildings.csv.gz': '1f6a823e223d5f29c66aa728933de684', + '1a7_buildings.csv.gz': '6130b724501fa16e6d84e484c4091f1f', + '1a9_buildings.csv.gz': '73022e8e7b994e76a58cc763a057d542', + '1b9_buildings.csv.gz': '48dea4af9d12b755e75b76c68c47de6b', + '1bb_buildings.csv.gz': 'dfb9ee4d3843d81722b70f7582c775a4', + '1bd_buildings.csv.gz': 'fdea2898fc50ae25b6196048373d8244', + '1bf_buildings.csv.gz': '96ef27d6128d0bcdfa896fed6f27cdd0', + '1c1_buildings.csv.gz': '32e3667d939e7f95316eb75a6ffdb603', + '1c3_buildings.csv.gz': 'ed94b543da1bbe3101ed66f7d7727d24', + '1c5_buildings.csv.gz': 'ce527ab33e564f0cc1b63ae467932a18', + '1c7_buildings.csv.gz': 'd5fb474466d6a11d3b08e3a011984ada', + '1dd_buildings.csv.gz': 
'9e7e50e3f95b3f2ceff6351b75ca1e75', + '1e5_buildings.csv.gz': 'f95ea85fce47ce7edf5729086d43f922', + '1e7_buildings.csv.gz': '2bca5682c48134e69b738d70dfe7d516', + '1e9_buildings.csv.gz': 'f049ad06dbbb200f524b4f50d1df8c2e', + '1eb_buildings.csv.gz': '6822d7f202b453ec3cc03fb8f04691ad', + '1ed_buildings.csv.gz': '9dfc560e2c3d135ebdcd46fa09c47169', + '1ef_buildings.csv.gz': '506e7772c35b09cfd3b6f8691dc2947d', + '1f1_buildings.csv.gz': 'b74f2b585cfad3b881fe4f124080440a', + '1f3_buildings.csv.gz': '12896642315320e11ed9ed2d3f0e5995', + '1f5_buildings.csv.gz': '334aea21e532e178bf5c54d028158906', + '1f7_buildings.csv.gz': '0e8c3d2e005eb04c6852a8aa993f5a76', + '217_buildings.csv.gz': '296e9ba121fea752b865a48e5c0fe8a5', + '219_buildings.csv.gz': '1d19b6626d738f7706f75c2935aaaff4', + '21d_buildings.csv.gz': '28bfca1f8668f59db021d3a195994768', + '21f_buildings.csv.gz': '06325c8b0a8f6ed598b7dc6f0bb5adf2', + '221_buildings.csv.gz': 'a354ffc1f7226d525c7cf53848975da1', + '223_buildings.csv.gz': '3bda1339d561b3bc749220877f1384d9', + '225_buildings.csv.gz': '8eb02ad77919d9e551138a14d3ad1bbc', + '227_buildings.csv.gz': 'c07aceb7c81f83a653810befa0695b61', + '22f_buildings.csv.gz': '97d63e30e008ec4424f6b0641b75377c', + '231_buildings.csv.gz': 'f4bc384ed74552ddcfe2e69107b91345', + '233_buildings.csv.gz': '081756e7bdcfdc2aee9114c4cfe62bd8', + '23b_buildings.csv.gz': '75776d3dcbc90cf3a596664747880134', + '23d_buildings.csv.gz': 'e5d0b9b7b14601f58cfdb9ea170e9520', + '23f_buildings.csv.gz': '77f38466419b4d391be8e4f05207fdf5', + '3d1_buildings.csv.gz': '6659c97bd765250b0dee4b1b7ff583a9', + '3d5_buildings.csv.gz': 'c27d8f6b2808549606f00bc04d8b42bc', + '3d7_buildings.csv.gz': 'abdef2e68cc31c67dbb6e60c4c40483e', + '3d9_buildings.csv.gz': '4c06ae37d8e76626345a52a32f989de9', + '3db_buildings.csv.gz': 'e83ca0115eaf4ec0a72aaf932b00442a', + 'b5b_buildings.csv.gz': '5e5f59cb17b81137d89c4bab8107e837', } - filename_glob = "*_buildings.csv" - zipfile_glob = "*_buildings.csv.gz" + filename_glob = '*_buildings.csv' + zipfile_glob = '*_buildings.csv.gz' - meta_data_url = "https://sites.research.google/open-buildings/tiles.geojson" - meta_data_filename = "tiles.geojson" + meta_data_url = 'https://sites.research.google/open-buildings/tiles.geojson' + meta_data_filename = 'tiles.geojson' def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float = 0.0001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -241,12 +241,12 @@ def __init__( self.index = Index(interleaved=False, properties=Property(dimension=3)) assert isinstance(self.paths, str) - with open(os.path.join(self.paths, "tiles.geojson")) as f: + with open(os.path.join(self.paths, 'tiles.geojson')) as f: data = json.load(f) - features = data["features"] + features = data['features'] features_filenames = [ - feature["properties"]["tile_url"].split("/")[-1] for feature in features + feature['properties']['tile_url'].split('/')[-1] for feature in features ] # get csv filename polygon_files = glob.glob(os.path.join(self.paths, self.zipfile_glob)) @@ -259,12 +259,12 @@ def __init__( ] i = 0 - source_crs = CRS.from_dict({"init": "epsg:4326"}) + source_crs = CRS.from_dict({'init': 'epsg:4326'}) for feature in matched_features: if crs is None: crs = CRS.from_dict(source_crs) - c = feature["geometry"]["coordinates"][0] + c = feature['geometry']['coordinates'][0] xs = [x[0] for x in c] ys = [x[1] for x in c] @@ -278,7 +278,7 @@ def __init__( coords = (minx, maxx, miny, maxy, mint, maxt) 
filepath = os.path.join( - self.paths, feature["properties"]["tile_url"].split("/")[-1] + self.paths, feature['properties']['tile_url'].split('/')[-1] ) self.index.insert(i, coords, filepath) i += 1 @@ -307,7 +307,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) shapes = self._filter_geometries(query, filepaths) @@ -326,7 +326,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: else: masks = torch.zeros(size=(1, round(height), round(width))) - sample = {"mask": masks, "crs": self.crs, "bbox": query} + sample = {'mask': masks, 'crs': self.crs, 'bbox': query} if self.transforms is not None: sample = self.transforms(sample) @@ -354,16 +354,16 @@ def _filter_geometries( [query.miny, query.maxy], ) df_query = ( - f"longitude >= {minx} & longitude <= {maxx} & " - f"latitude >= {miny} & latitude <= {maxy}" + f'longitude >= {minx} & longitude <= {maxx} & ' + f'latitude >= {miny} & latitude <= {maxy}' ) shapes = [] for f in filepaths: - csv_chunks = pd.read_csv(f, chunksize=200000, compression="gzip") + csv_chunks = pd.read_csv(f, chunksize=200000, compression='gzip') for chunk in csv_chunks: df = chunk.query(df_query) # Warp geometries to requested CRS - polygon_series = df["geometry"].map(self._wkt_fiona_geom_transform) + polygon_series = df['geometry'].map(self._wkt_fiona_geom_transform) shapes.extend(polygon_series.values.tolist()) return shapes @@ -382,7 +382,7 @@ def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: x = json.loads(x.replace("'", '"')) import fiona - if hasattr(fiona, "model"): + if hasattr(fiona, 'model'): import fiona.model geom = fiona.model.Geometry(**x) @@ -402,7 +402,7 @@ def _verify(self) -> None: for zipfile in glob.iglob(pathname): filename = os.path.basename(zipfile) if self.checksum and not check_integrity(zipfile, self.md5s[filename]): - raise RuntimeError(f"Dataset found, but corrupted: {filename}.") + raise RuntimeError(f'Dataset found, but corrupted: {filename}.') i += 1 if i != 0: @@ -426,11 +426,11 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].permute(1, 2, 0) + mask = sample['mask'].permute(1, 2, 0) - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].permute(1, 2, 0) + pred = sample['prediction'].permute(1, 2, 0) ncols = 2 else: ncols = 1 @@ -439,17 +439,17 @@ def plot( if showing_predictions: axs[0].imshow(mask) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(pred) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Mask") - axs[1].set_title("Prediction") + axs[0].set_title('Mask') + axs[1].set_title('Prediction') else: axs.imshow(mask) - axs.axis("off") + axs.axis('off') if show_titles: - axs.set_title("Mask") + axs.set_title('Mask') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index ab085f1b923..14b76aea161 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -51,56 +51,56 @@ class OSCD(NonGeoDataset): """ urls = { - "Onera Satellite Change Detection dataset - Images.zip": ( - "https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download" + 'Onera Satellite Change Detection dataset - Images.zip': ( + 'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download' ), - "Onera Satellite 
Change Detection dataset - Train Labels.zip": ( - "https://partage.mines-telecom.fr/index.php/s/2D6n03k58ygBSpu/download" + 'Onera Satellite Change Detection dataset - Train Labels.zip': ( + 'https://partage.mines-telecom.fr/index.php/s/2D6n03k58ygBSpu/download' ), - "Onera Satellite Change Detection dataset - Test Labels.zip": ( - "https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download" + 'Onera Satellite Change Detection dataset - Test Labels.zip': ( + 'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download' ), } md5s = { - "Onera Satellite Change Detection dataset - Images.zip": ( - "c50d4a2941da64e03a47ac4dec63d915" + 'Onera Satellite Change Detection dataset - Images.zip': ( + 'c50d4a2941da64e03a47ac4dec63d915' ), - "Onera Satellite Change Detection dataset - Train Labels.zip": ( - "4d2965af8170c705ebad3d6ee71b6990" + 'Onera Satellite Change Detection dataset - Train Labels.zip': ( + '4d2965af8170c705ebad3d6ee71b6990' ), - "Onera Satellite Change Detection dataset - Test Labels.zip": ( - "8177d437793c522653c442aa4e66c617" + 'Onera Satellite Change Detection dataset - Test Labels.zip': ( + '8177d437793c522653c442aa4e66c617' ), } - zipfile_glob = "*Onera*.zip" - filename_glob = "*Onera*" - splits = ["train", "test"] + zipfile_glob = '*Onera*.zip' + filename_glob = '*Onera*' + splits = ['train', 'test'] - colormap = ["blue"] + colormap = ['blue'] all_bands = ( - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B10', + 'B11', + 'B12', ) - rgb_bands = ("B04", "B03", "B02") + rgb_bands = ('B04', 'B03', 'B02') def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -146,10 +146,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image1 = self._load_image(files["images1"]) - image2 = self._load_image(files["images2"]) - mask = self._load_target(str(files["mask"])) - sample = {"image1": image1, "image2": image2, "mask": mask} + image1 = self._load_image(files['images1']) + image2 = self._load_image(files['images2']) + mask = self._load_target(str(files['mask'])) + sample = {'image1': image1, 'image2': image2, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -168,21 +168,21 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]: regions = [] labels_root = os.path.join( self.root, - f"Onera Satellite Change Detection dataset - {self.split.capitalize()} " - + "Labels", + f'Onera Satellite Change Detection dataset - {self.split.capitalize()} ' + + 'Labels', ) images_root = os.path.join( - self.root, "Onera Satellite Change Detection dataset - Images" + self.root, 'Onera Satellite Change Detection dataset - Images' ) - folders = glob.glob(os.path.join(labels_root, "*/")) + folders = glob.glob(os.path.join(labels_root, '*/')) for folder in folders: region = folder.split(os.sep)[-2] - mask = os.path.join(labels_root, region, "cm", "cm.png") + mask = os.path.join(labels_root, region, 'cm', 'cm.png') def get_image_paths(ind: int) -> list[str]: return sorted( glob.glob( - os.path.join(images_root, region, f"imgs_{ind}_rect", "*.tif") + os.path.join(images_root, region, f'imgs_{ind}_rect', '*.tif') ), key=sort_sentinel2_bands, ) @@ 
-191,7 +191,7 @@ def get_image_paths(ind: int) -> list[str]: images1 = [images1[i] for i in self.all_band_indices] images2 = [images2[i] for i in self.all_band_indices] - with open(os.path.join(images_root, region, "dates.txt")) as f: + with open(os.path.join(images_root, region, 'dates.txt')) as f: dates = tuple( line.split()[-1] for line in f.read().strip().splitlines() ) @@ -217,11 +217,11 @@ def _load_image(self, paths: Sequence[str]) -> Tensor: Returns: the image """ - images: list["np.typing.NDArray[np.int_]"] = [] + images: list['np.typing.NDArray[np.int_]'] = [] for path in paths: with Image.open(path) as img: images.append(np.array(img)) - array: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0).astype(np.int_) + array: 'np.typing.NDArray[np.int_]' = np.stack(images, axis=0).astype(np.int_) tensor = torch.from_numpy(array).float() return tensor @@ -236,7 +236,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = torch.clamp(tensor, min=0, max=1) tensor = tensor.to(torch.long) @@ -245,9 +245,9 @@ def _load_target(self, path: str) -> Tensor: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - pathname = os.path.join(self.root, "**", self.filename_glob) + pathname = os.path.join(self.root, '**', self.filename_glob) for fname in glob.iglob(pathname, recursive=True): - if not fname.endswith(".zip"): + if not fname.endswith('.zip'): return # Check if the zip files have already been downloaded @@ -308,32 +308,32 @@ def plot( except ValueError as e: raise RGBBandsMissingError() from e - def get_masked(img: Tensor) -> "np.typing.NDArray[np.uint8]": + def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]': rgb_img = img[rgb_indices].float().numpy() per02 = np.percentile(rgb_img, 2) per98 = np.percentile(rgb_img, 98) rgb_img = (np.clip((rgb_img - per02) / (per98 - per02), 0, 1) * 255).astype( np.uint8 ) - array: "np.typing.NDArray[np.uint8]" = draw_semantic_segmentation_masks( + array: 'np.typing.NDArray[np.uint8]' = draw_semantic_segmentation_masks( torch.from_numpy(rgb_img), - sample["mask"], + sample['mask'], alpha=alpha, colors=self.colormap, ) return array - image1 = get_masked(sample["image1"]) - image2 = get_masked(sample["image2"]) + image1 = get_masked(sample['image1']) + image2 = get_masked(sample['image2']) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(image2) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Pre change") - axs[1].set_title("Post change") + axs[0].set_title('Pre change') + axs[1].set_title('Post change') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 88d95884c2e..a3d3f62be71 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -70,26 +70,26 @@ class PASTIS(NonGeoDataset): """ classes = [ - "background", # all non-agricultural land - "meadow", - "soft_winter_wheat", - "corn", - "winter_barley", - "winter_rapeseed", - "spring_barley", - "sunflower", - "grapevine", - "beet", - "winter_triticale", - "winter_durum_wheat", - "fruits_vegetables_flowers", - "potatoes", - "leguminous_fodder", - "soybeans", - "orchard", - "mixed_cereal", - 
"sorghum", - "void_label", # for parcels mostly outside their patch + 'background', # all non-agricultural land + 'meadow', + 'soft_winter_wheat', + 'corn', + 'winter_barley', + 'winter_rapeseed', + 'spring_barley', + 'sunflower', + 'grapevine', + 'beet', + 'winter_triticale', + 'winter_durum_wheat', + 'fruits_vegetables_flowers', + 'potatoes', + 'leguminous_fodder', + 'soybeans', + 'orchard', + 'mixed_cereal', + 'sorghum', + 'void_label', # for parcels mostly outside their patch ] cmap = { 0: (0, 0, 0, 255), @@ -113,24 +113,24 @@ class PASTIS(NonGeoDataset): 18: (23, 190, 207, 255), 19: (255, 255, 255, 255), } - directory = "PASTIS-R" - filename = "PASTIS-R.zip" - url = "https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1" - md5 = "4887513d6c2d2b07fa935d325bd53e09" + directory = 'PASTIS-R' + filename = 'PASTIS-R.zip' + url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1' + md5 = '4887513d6c2d2b07fa935d325bd53e09' prefix = { - "s2": os.path.join("DATA_S2", "S2_"), - "s1a": os.path.join("DATA_S1A", "S1A_"), - "s1d": os.path.join("DATA_S1D", "S1D_"), - "semantic": os.path.join("ANNOTATIONS", "TARGET_"), - "instance": os.path.join("INSTANCE_ANNOTATIONS", "INSTANCES_"), + 's2': os.path.join('DATA_S2', 'S2_'), + 's1a': os.path.join('DATA_S1A', 'S1A_'), + 's1d': os.path.join('DATA_S1D', 'S1D_'), + 'semantic': os.path.join('ANNOTATIONS', 'TARGET_'), + 'instance': os.path.join('INSTANCE_ANNOTATIONS', 'INSTANCES_'), } def __init__( self, - root: str = "data", + root: str = 'data', folds: Sequence[int] = (1, 2, 3, 4, 5), - bands: str = "s2", - mode: str = "semantic", + bands: str = 's2', + mode: str = 'semantic', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -154,8 +154,8 @@ def __init__( """ for fold in folds: assert 1 <= fold <= 5 - assert bands in ["s1a", "s1d", "s2"] - assert mode in ["semantic", "instance"] + assert bands in ['s1a', 's1d', 's2'] + assert mode in ['semantic', 'instance'] self.root = root self.folds = folds self.bands = bands @@ -187,12 +187,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ image = self._load_image(index) - if self.mode == "semantic": + if self.mode == 'semantic': mask = self._load_semantic_targets(index) - sample = {"image": image, "mask": mask} - elif self.mode == "instance": + sample = {'image': image, 'mask': mask} + elif self.mode == 'instance': mask, boxes, labels = self._load_instance_targets(index) - sample = {"image": image, "mask": mask, "boxes": boxes, "label": labels} + sample = {'image': image, 'mask': mask, 'boxes': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) @@ -233,7 +233,7 @@ def _load_semantic_targets(self, index: int) -> Tensor: """ # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501 # even though the mask file is 3 bands, we just select the first band - array = np.load(self.files[index]["semantic"])[0].astype(np.uint8) + array = np.load(self.files[index]['semantic'])[0].astype(np.uint8) tensor = torch.from_numpy(array).long() return tensor @@ -246,8 +246,8 @@ def _load_instance_targets(self, index: int) -> tuple[Tensor, Tensor, Tensor]: Returns: the instance segmentation mask, box, and label for each instance """ - mask_array = np.load(self.files[index]["semantic"])[0] - instance_array = np.load(self.files[index]["instance"]) + mask_array = np.load(self.files[index]['semantic'])[0] + instance_array 
= np.load(self.files[index]['instance']) mask_tensor = torch.from_numpy(mask_array) instance_tensor = torch.from_numpy(instance_array) @@ -289,23 +289,23 @@ def _load_files(self) -> list[dict[str, str]]: list of dicts containing image and semantic/instance target file paths """ self.idxs = [] - metadata_fn = os.path.join(self.root, self.directory, "metadata.geojson") + metadata_fn = os.path.join(self.root, self.directory, 'metadata.geojson') with fiona.open(metadata_fn) as f: for row in f: - fold = int(row["properties"]["Fold"]) + fold = int(row['properties']['Fold']) if fold in self.folds: - self.idxs.append(row["properties"]["ID_PATCH"]) + self.idxs.append(row['properties']['ID_PATCH']) files = [] for i in self.idxs: - path = os.path.join(self.root, self.directory, "{}") + str(i) + ".npy" + path = os.path.join(self.root, self.directory, '{}') + str(i) + '.npy' files.append( { - "s2": path.format(self.prefix["s2"]), - "s1a": path.format(self.prefix["s1a"]), - "s1d": path.format(self.prefix["s1d"]), - "semantic": path.format(self.prefix["semantic"]), - "instance": path.format(self.prefix["instance"]), + 's2': path.format(self.prefix['s2']), + 's1a': path.format(self.prefix['s1a']), + 's1d': path.format(self.prefix['s1d']), + 'semantic': path.format(self.prefix['semantic']), + 'instance': path.format(self.prefix['instance']), } ) return files @@ -321,7 +321,7 @@ def _verify(self) -> None: filepath = os.path.join(self.root, self.filename) if os.path.exists(filepath): if self.checksum and not check_integrity(filepath, self.md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(filepath) return @@ -359,42 +359,42 @@ def plot( a matplotlib Figure with the rendered sample """ # Keep the RGB bands and convert to T x H x W x C format - images = sample["image"][:, [2, 1, 0], :, :].numpy().transpose(0, 2, 3, 1) - mask = sample["mask"].numpy() + images = sample['image'][:, [2, 1, 0], :, :].numpy().transpose(0, 2, 3, 1) + mask = sample['mask'].numpy() - if self.mode == "instance": - label = sample["label"] + if self.mode == 'instance': + label = sample['label'] mask = label[mask.argmax(axis=0)].numpy() num_panels = 3 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - predictions = sample["prediction"].numpy() + predictions = sample['prediction'].numpy() num_panels += 1 - if self.mode == "instance": + if self.mode == 'instance': predictions = predictions.argmax(axis=0) - label = sample["prediction_labels"] + label = sample['prediction_labels'] predictions = label[predictions].numpy() fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 4)) axs[0].imshow(images[0] / 5000) axs[1].imshow(images[1] / 5000) - axs[2].imshow(mask, vmin=0, vmax=19, cmap=self._cmap, interpolation="none") - axs[0].axis("off") - axs[1].axis("off") - axs[2].axis("off") + axs[2].imshow(mask, vmin=0, vmax=19, cmap=self._cmap, interpolation='none') + axs[0].axis('off') + axs[1].axis('off') + axs[2].axis('off') if showing_predictions: axs[3].imshow( - predictions, vmin=0, vmax=19, cmap=self._cmap, interpolation="none" + predictions, vmin=0, vmax=19, cmap=self._cmap, interpolation='none' ) - axs[3].axis("off") + axs[3].axis('off') if show_titles: - axs[0].set_title("Image 0") - axs[1].set_title("Image 1") - axs[2].set_title("Mask") + axs[0].set_title('Image 0') + axs[1].set_title('Image 1') + axs[2].set_title('Mask') if showing_predictions: - axs[3].set_title("Prediction") + 
axs[3].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index ba445d6711c..2b6bb8d77cb 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -77,14 +77,14 @@ class PatternNet(NonGeoClassificationDataset): * https://doi.org/10.1016/j.isprsjprs.2018.01.004 """ - url = "https://drive.google.com/file/d/127lxXYqzO6Bd0yZhvEbgIfz95HaEnr9K" - md5 = "96d54b3224c5350a98d55d5a7e6984ad" - filename = "PatternNet.zip" - directory = os.path.join("PatternNet", "images") + url = 'https://drive.google.com/file/d/127lxXYqzO6Bd0yZhvEbgIfz95HaEnr9K' + md5 = '96d54b3224c5350a98d55d5a7e6984ad' + filename = 'PatternNet.zip' + directory = os.path.join('PatternNet', 'images') def __init__( self, - root: str = "data", + root: str = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -160,21 +160,21 @@ def plot( .. versionadded:: 0.2 """ - image, label = sample["image"], cast(int, sample["label"].item()) + image, label = sample['image'], cast(int, sample['label'].item()) - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(int, sample["prediction"].item()) + prediction = cast(int, sample['prediction'].item()) fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(image.permute(1, 2, 0)) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {self.classes[label]}" + title = f'Label: {self.classes[label]}' if showing_predictions: - title += f"\nPrediction: {self.classes[prediction]}" + title += f'\nPrediction: {self.classes[prediction]}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index b1b5f21541f..b54cc141301 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -55,60 +55,60 @@ class Potsdam2D(NonGeoDataset): .. 
versionadded:: 0.2 """ # noqa: E501 - filenames = ["4_Ortho_RGBIR.zip", "5_Labels_all.zip"] - md5s = ["c4a8f7d8c7196dd4eba4addd0aae10c1", "cf7403c1a97c0d279414db"] - image_root = "4_Ortho_RGBIR" + filenames = ['4_Ortho_RGBIR.zip', '5_Labels_all.zip'] + md5s = ['c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db'] + image_root = '4_Ortho_RGBIR' splits = { - "train": [ - "top_potsdam_2_10", - "top_potsdam_2_11", - "top_potsdam_2_12", - "top_potsdam_3_10", - "top_potsdam_3_11", - "top_potsdam_3_12", - "top_potsdam_4_10", - "top_potsdam_4_11", - "top_potsdam_4_12", - "top_potsdam_5_10", - "top_potsdam_5_11", - "top_potsdam_5_12", - "top_potsdam_6_10", - "top_potsdam_6_11", - "top_potsdam_6_12", - "top_potsdam_6_7", - "top_potsdam_6_8", - "top_potsdam_6_9", - "top_potsdam_7_10", - "top_potsdam_7_11", - "top_potsdam_7_12", - "top_potsdam_7_7", - "top_potsdam_7_8", - "top_potsdam_7_9", + 'train': [ + 'top_potsdam_2_10', + 'top_potsdam_2_11', + 'top_potsdam_2_12', + 'top_potsdam_3_10', + 'top_potsdam_3_11', + 'top_potsdam_3_12', + 'top_potsdam_4_10', + 'top_potsdam_4_11', + 'top_potsdam_4_12', + 'top_potsdam_5_10', + 'top_potsdam_5_11', + 'top_potsdam_5_12', + 'top_potsdam_6_10', + 'top_potsdam_6_11', + 'top_potsdam_6_12', + 'top_potsdam_6_7', + 'top_potsdam_6_8', + 'top_potsdam_6_9', + 'top_potsdam_7_10', + 'top_potsdam_7_11', + 'top_potsdam_7_12', + 'top_potsdam_7_7', + 'top_potsdam_7_8', + 'top_potsdam_7_9', ], - "test": [ - "top_potsdam_5_15", - "top_potsdam_6_15", - "top_potsdam_6_13", - "top_potsdam_3_13", - "top_potsdam_4_14", - "top_potsdam_6_14", - "top_potsdam_5_14", - "top_potsdam_2_13", - "top_potsdam_4_15", - "top_potsdam_2_14", - "top_potsdam_5_13", - "top_potsdam_4_13", - "top_potsdam_3_14", - "top_potsdam_7_13", + 'test': [ + 'top_potsdam_5_15', + 'top_potsdam_6_15', + 'top_potsdam_6_13', + 'top_potsdam_3_13', + 'top_potsdam_4_14', + 'top_potsdam_6_14', + 'top_potsdam_5_14', + 'top_potsdam_2_13', + 'top_potsdam_4_15', + 'top_potsdam_2_14', + 'top_potsdam_5_13', + 'top_potsdam_4_13', + 'top_potsdam_3_14', + 'top_potsdam_7_13', ], } classes = [ - "Clutter/background", - "Impervious surfaces", - "Building", - "Low Vegetation", - "Tree", - "Car", + 'Clutter/background', + 'Impervious surfaces', + 'Building', + 'Low Vegetation', + 'Tree', + 'Car', ] colormap = [ (255, 0, 0), @@ -121,8 +121,8 @@ class Potsdam2D(NonGeoDataset): def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -149,8 +149,8 @@ def __init__( self.files = [] for name in self.splits[split]: - image = os.path.join(root, self.image_root, name) + "_RGBIR.tif" - mask = os.path.join(root, name) + "_label.tif" + image = os.path.join(root, self.image_root, name) + '_RGBIR.tif' + mask = os.path.join(root, name) + '_label.tif' if os.path.exists(image) and os.path.exists(mask): self.files.append(dict(image=image, mask=mask)) @@ -165,7 +165,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -189,7 +189,7 @@ def _load_image(self, index: int) -> Tensor: Returns: the image """ - path = self.files[index]["image"] + path = self.files[index]['image'] with rasterio.open(path) as f: array = f.read() tensor = 
torch.from_numpy(array).float() @@ -204,9 +204,9 @@ def _load_target(self, index: int) -> Tensor: Returns: the target mask """ - path = self.files[index]["mask"] + path = self.files[index]['mask'] with Image.open(path) as img: - array: "np.typing.NDArray[np.uint8]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.uint8]' = np.array(img.convert('RGB')) array = rgb_to_mask(array, self.colormap) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW @@ -225,7 +225,7 @@ def _verify(self) -> None: filepath = os.path.join(self.root, filename) if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') exists.append(True) extract_archive(filepath) else: @@ -256,13 +256,13 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample["image"][:3], sample["mask"], alpha=alpha, colors=self.colormap + sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap ) - if "prediction" in sample: + if 'prediction' in sample: ncols += 1 image2 = draw_semantic_segmentation_masks( - sample["image"][:3], - sample["prediction"], + sample['image'][:3], + sample['prediction'], alpha=alpha, colors=self.colormap, ) @@ -274,15 +274,15 @@ def plot( ax0 = axs ax0.imshow(image1) - ax0.axis("off") + ax0.axis('off') if ncols > 1: ax1.imshow(image2) - ax1.axis("off") + ax1.axis('off') if show_titles: - ax0.set_title("Ground Truth") + ax0.set_title('Ground Truth') if ncols > 1: - ax1.set_title("Predictions") + ax1.set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/prisma.py b/torchgeo/datasets/prisma.py index 442706578ce..c2e6e66b598 100644 --- a/torchgeo/datasets/prisma.py +++ b/torchgeo/datasets/prisma.py @@ -63,7 +63,7 @@ class PRISMA(RasterDataset): # * 6.3.6: L0A Product Naming Convention # * 7.5: L1 Product Naming Convention # * 7.8.5: FKDP, GKDP, ICU-KDP and CDP Products Naming Convention - filename_glob = "PRS_*" + filename_glob = 'PRS_*' filename_regex = r""" ^PRS _(?P[A-Z\d]+) @@ -75,7 +75,7 @@ class PRISMA(RasterDataset): (_(?P\d))? \. """ - date_format = "%Y%m%d%H%M%S" + date_format = '%Y%m%d%H%M%S' def plot( self, @@ -95,15 +95,15 @@ def plot( """ # RGB band indices based on https://doi.org/10.3390/rs14164080 rgb_indices = [34, 23, 11] - image = sample["image"][rgb_indices].permute(1, 2, 0).float() + image = sample['image'][rgb_indices].permute(1, 2, 0).float() image = percentile_normalization(image, axis=(0, 1)) fig, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - ax.set_title("Image") + ax.set_title('Image') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index c345c79e7a9..78d1664fcb8 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -60,15 +60,15 @@ class ReforesTree(NonGeoDataset): .. 
versionadded:: 0.3 """ - classes = ["other", "banana", "cacao", "citrus", "fruit", "timber"] - url = "https://zenodo.org/record/6813783/files/reforesTree.zip?download=1" + classes = ['other', 'banana', 'cacao', 'citrus', 'fruit', 'timber'] + url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1' - md5 = "f6a4a1d8207aeaa5fbab7b21b683a302" - zipfilename = "reforesTree.zip" + md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302' + zipfilename = 'reforesTree.zip' def __init__( self, - root: str = "data", + root: str = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -94,7 +94,7 @@ def __init__( self.files = self._load_files(self.root) - self.annot_df = pd.read_csv(os.path.join(root, "mapping", "final_dataset.csv")) + self.annot_df = pd.read_csv(os.path.join(root, 'mapping', 'final_dataset.csv')) self.class2idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} @@ -113,7 +113,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: boxes, labels, agb = self._load_target(filepath) - sample = {"image": image, "boxes": boxes, "label": labels, "agb": agb} + sample = {'image': image, 'boxes': boxes, 'label': labels, 'agb': agb} if self.transforms is not None: sample = self.transforms(sample) @@ -137,7 +137,7 @@ def _load_files(self, root: str) -> list[str]: Returns: list of dicts containing paths for each pair of image, annotation """ - image_paths = sorted(glob.glob(os.path.join(root, "tiles", "**", "*.png"))) + image_paths = sorted(glob.glob(os.path.join(root, 'tiles', '**', '*.png'))) return image_paths @@ -151,7 +151,7 @@ def _load_image(self, path: str) -> Tensor: the image """ with Image.open(path) as img: - array: "np.typing.NDArray[np.uint8]" = np.array(img) + array: 'np.typing.NDArray[np.uint8]' = np.array(img) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -166,26 +166,26 @@ def _load_target(self, filepath: str) -> tuple[Tensor, ...]: Returns: dictionary containing boxes, label, and agb value """ - tile_df = self.annot_df[self.annot_df["img_path"] == os.path.basename(filepath)] + tile_df = self.annot_df[self.annot_df['img_path'] == os.path.basename(filepath)] - boxes = torch.Tensor(tile_df[["xmin", "ymin", "xmax", "ymax"]].values.tolist()) + boxes = torch.Tensor(tile_df[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist()) labels = torch.Tensor( - [self.class2idx[label] for label in tile_df["group"].tolist()] + [self.class2idx[label] for label in tile_df['group'].tolist()] ) - agb = torch.Tensor(tile_df["AGB"].tolist()) + agb = torch.Tensor(tile_df['AGB'].tolist()) return boxes, labels, agb def _verify(self) -> None: """Checks the integrity of the dataset structure.""" - filepaths = [os.path.join(self.root, dir) for dir in ["tiles", "mapping"]] + filepaths = [os.path.join(self.root, dir) for dir in ['tiles', 'mapping']] if all([os.path.exists(filepath) for filepath in filepaths]): return filepath = os.path.join(self.root, self.zipfilename) if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, self.md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(filepath) return @@ -221,9 +221,9 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = sample["image"].permute((1, 2, 0)).numpy() + image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - showing_predictions = "prediction_boxes" in sample + 
showing_predictions = 'prediction_boxes' in sample if showing_predictions: ncols += 1 @@ -232,7 +232,7 @@ def plot( axs = [axs] axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') bboxes = [ patches.Rectangle( @@ -240,20 +240,20 @@ def plot( bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=1, - edgecolor="r", - facecolor="none", + edgecolor='r', + facecolor='none', ) - for bbox in sample["boxes"].numpy() + for bbox in sample['boxes'].numpy() ] for bbox in bboxes: axs[0].add_patch(bbox) if show_titles: - axs[0].set_title("Ground Truth") + axs[0].set_title('Ground Truth') if showing_predictions: axs[1].imshow(image) - axs[1].axis("off") + axs[1].axis('off') pred_bboxes = [ patches.Rectangle( @@ -261,16 +261,16 @@ def plot( bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=1, - edgecolor="r", - facecolor="none", + edgecolor='r', + facecolor='none', ) - for bbox in sample["prediction_boxes"].numpy() + for bbox in sample['prediction_boxes'].numpy() ] for bbox in pred_bboxes: axs[1].add_patch(bbox) if show_titles: - axs[1].set_title("Predictions") + axs[1].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 7cb1fcf0dcd..64c2e089203 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -92,27 +92,27 @@ class RESISC45(NonGeoClassificationDataset): * https://doi.org/10.1109/jproc.2017.2675998 """ - url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" - md5 = "d824acb73957502b00efd559fc6cfbbb" - filename = "NWPU-RESISC45.rar" - directory = "NWPU-RESISC45" + url = 'https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv' + md5 = 'd824acb73957502b00efd559fc6cfbbb' + filename = 'NWPU-RESISC45.rar' + directory = 'NWPU-RESISC45' - splits = ["train", "val", "test"] + splits = ['train', 'val', 'test'] split_urls = { - "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 - "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 - "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 + 'train': 'https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt', # noqa: E501 + 'val': 'https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt', # noqa: E501 + 'test': 'https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt', # noqa: E501 } split_md5s = { - "train": "b5a4c05a37de15e4ca886696a85c403e", - "val": "a0770cee4c5ca20b8c32bbd61e114805", - "test": "3dda9e4988b47eb1de9f07993653eb08", + 'train': 'b5a4c05a37de15e4ca886696a85c403e', + 'val': 'a0770cee4c5ca20b8c32bbd61e114805', + 'test': '3dda9e4988b47eb1de9f07993653eb08', } def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -137,7 +137,7 @@ def __init__( self._verify() valid_fns = set() - with open(os.path.join(self.root, f"resisc45-{split}.txt")) as f: + with open(os.path.join(self.root, f'resisc45-{split}.txt')) as f: for fn in f: valid_fns.add(fn.strip()) @@ -183,7 +183,7 @@ def _download(self) -> None: download_url( self.split_urls[split], self.root, - filename=f"resisc45-{split}.txt", + filename=f'resisc45-{split}.txt', md5=self.split_md5s[split] if self.checksum else None, ) @@ -210,22 +210,22 
@@ def plot( .. versionadded:: 0.2 """ - image = np.rollaxis(sample["image"].numpy(), 0, 3) - label = cast(int, sample["label"].item()) + image = np.rollaxis(sample['image'].numpy(), 0, 3) + label = cast(int, sample['label'].item()) label_class = self.classes[label] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(int, sample["prediction"].item()) + prediction = cast(int, sample['prediction'].item()) prediction_class = self.classes[prediction] fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label_class}" + title = f'Label: {label_class}' if showing_predictions: - title += f"\nPrediction: {prediction_class}" + title += f'\nPrediction: {prediction_class}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 102067e5f9b..8f8f93637a9 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -57,38 +57,38 @@ class RwandaFieldBoundary(NonGeoDataset): .. versionadded:: 0.5 """ - dataset_id = "nasa_rwanda_field_boundary_competition" + dataset_id = 'nasa_rwanda_field_boundary_competition' collection_ids = [ - "nasa_rwanda_field_boundary_competition_source_train", - "nasa_rwanda_field_boundary_competition_labels_train", - "nasa_rwanda_field_boundary_competition_source_test", + 'nasa_rwanda_field_boundary_competition_source_train', + 'nasa_rwanda_field_boundary_competition_labels_train', + 'nasa_rwanda_field_boundary_competition_source_test', ] - number_of_patches_per_split = {"train": 57, "test": 13} + number_of_patches_per_split = {'train': 57, 'test': 13} filenames = { - "train_images": "nasa_rwanda_field_boundary_competition_source_train.tar.gz", - "test_images": "nasa_rwanda_field_boundary_competition_source_test.tar.gz", - "train_labels": "nasa_rwanda_field_boundary_competition_labels_train.tar.gz", + 'train_images': 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', + 'test_images': 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', + 'train_labels': 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', } md5s = { - "train_images": "1f9ec08038218e67e11f82a86849b333", - "test_images": "17bb0e56eedde2e7a43c57aa908dc125", - "train_labels": "10e4eb761523c57b6d3bdf9394004f5f", + 'train_images': '1f9ec08038218e67e11f82a86849b333', + 'test_images': '17bb0e56eedde2e7a43c57aa908dc125', + 'train_labels': '10e4eb761523c57b6d3bdf9394004f5f', } - dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12") + dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') - all_bands = ("B01", "B02", "B03", "B04") - rgb_bands = ("B03", "B02", "B01") + all_bands = ('B01', 'B02', 'B03', 'B04') + rgb_bands = ('B03', 'B02', 'B01') - classes = ["No field-boundary", "Field-boundary"] + classes = ['No field-boundary', 'Field-boundary'] - splits = ["train", "test"] + splits = ['train', 'test'] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -113,7 +113,7 @@ def __init__( self._validate_bands(bands) assert split in self.splits if download and api_key is None: - raise RuntimeError("Must provide an API key to download the dataset") + raise RuntimeError('Must 
provide an API key to download the dataset') self.root = root self.bands = bands self.transforms = transforms @@ -132,9 +132,9 @@ def __init__( for band in self.bands: fn = os.path.join( self.root, - f"nasa_rwanda_field_boundary_competition_source_{split}", - f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}", # noqa: E501 - f"{band}.tif", + f'nasa_rwanda_field_boundary_competition_source_{split}', + f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 + f'{band}.tif', ) patch.append(fn) dates.append(patch) @@ -142,9 +142,9 @@ def __init__( self.mask_filenames.append( os.path.join( self.root, - f"nasa_rwanda_field_boundary_competition_labels_{split}", - f"nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}", - "raster_labels.tif", + f'nasa_rwanda_field_boundary_competition_labels_{split}', + f'nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}', + 'raster_labels.tif', ) ) @@ -169,13 +169,13 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: imgs.append(bands) img = torch.from_numpy(np.array(imgs)) - sample = {"image": img} + sample = {'image': img} - if self.split == "train": + if self.split == 'train': with rasterio.open(mask_fn) as f: mask = f.read(1) mask = torch.from_numpy(mask) - sample["mask"] = mask + sample['mask'] = mask if self.transforms is not None: sample = self.transforms(sample) @@ -209,7 +209,7 @@ def _verify(self) -> None: checks = [] for split, num_patches in self.number_of_patches_per_split.items(): path = os.path.join( - self.root, f"nasa_rwanda_field_boundary_competition_source_{split}" + self.root, f'nasa_rwanda_field_boundary_competition_source_{split}' ) if os.path.exists(path): num_files = len(os.listdir(path)) @@ -223,11 +223,11 @@ def _verify(self) -> None: # Check if tar file already exists (if so then extract) have_all_files = True - for group in ["train_images", "train_labels", "test_images"]: + for group in ['train_images', 'train_labels', 'test_images']: filepath = os.path.join(self.root, self.filenames[group]) if os.path.exists(filepath): if self.checksum and not check_integrity(filepath, self.md5s[group]): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') extract_archive(filepath) else: have_all_files = False @@ -246,10 +246,10 @@ def _download(self) -> None: for collection_id in self.collection_ids: download_radiant_mlhub_collection(collection_id, self.root, self.api_key) - for group in ["train_images", "train_labels", "test_images"]: + for group in ['train_images', 'train_labels', 'test_images']: filepath = os.path.join(self.root, self.filenames[group]) if self.checksum and not check_integrity(filepath, self.md5s[group]): - raise RuntimeError("Dataset not found or corrupted.") + raise RuntimeError('Dataset not found or corrupted.') extract_archive(filepath, self.root) def plot( @@ -280,40 +280,40 @@ def plot( else: raise RGBBandsMissingError() - num_time_points = sample["image"].shape[0] + num_time_points = sample['image'].shape[0] assert time_step < num_time_points - image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3) + image = np.rollaxis(sample['image'][time_step, rgb_indices].numpy(), 0, 3) image = np.clip(image / 2000, 0, 1) - if "mask" in sample: - mask = sample["mask"].numpy() + if 'mask' in sample: + mask = sample['mask'].numpy() else: mask = np.zeros_like(image) num_panels = 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if 
showing_predictions: - predictions = sample["prediction"].numpy() + predictions = sample['prediction'].numpy() num_panels += 1 fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4)) axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') if show_titles: - axs[0].set_title(f"t={time_step}") + axs[0].set_title(f't={time_step}') - axs[1].imshow(mask, vmin=0, vmax=1, interpolation="none") - axs[1].axis("off") + axs[1].imshow(mask, vmin=0, vmax=1, interpolation='none') + axs[1].axis('off') if show_titles: - axs[1].set_title("Mask") + axs[1].set_title('Mask') if showing_predictions: - axs[2].imshow(predictions, vmin=0, vmax=1, interpolation="none") - axs[2].axis("off") + axs[2].imshow(predictions, vmin=0, vmax=1, interpolation='none') + axs[2].axis('off') if show_titles: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 5f2918f748e..dc5cd85d599 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -92,87 +92,87 @@ class SeasoNet(NonGeoDataset): metadata = [ { - "name": "spring", - "ext": ".zip", - "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip", # noqa: E501 - "md5": "de4cdba7b6196aff624073991b187561", + 'name': 'spring', + 'ext': '.zip', + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', # noqa: E501 + 'md5': 'de4cdba7b6196aff624073991b187561', }, { - "name": "summer", - "ext": ".zip", - "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip", # noqa: E501 - "md5": "6a54d4e134d27ae4eb03f180ee100550", + 'name': 'summer', + 'ext': '.zip', + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', # noqa: E501 + 'md5': '6a54d4e134d27ae4eb03f180ee100550', }, { - "name": "fall", - "ext": ".zip", - "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip", # noqa: E501 - "md5": "5f94920fe41a63c6bfbab7295f7d6b95", + 'name': 'fall', + 'ext': '.zip', + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', # noqa: E501 + 'md5': '5f94920fe41a63c6bfbab7295f7d6b95', }, { - "name": "winter", - "ext": ".zip", - "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip", # noqa: E501 - "md5": "dc5e3e09e52ab5c72421b1e3186c9a48", + 'name': 'winter', + 'ext': '.zip', + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', # noqa: E501 + 'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48', }, { - "name": "snow", - "ext": ".zip", - "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip", # noqa: E501 - "md5": "e1b300994143f99ebb03f51d6ab1cbe6", + 'name': 'snow', + 'ext': '.zip', + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', # noqa: E501 + 'md5': 'e1b300994143f99ebb03f51d6ab1cbe6', }, { - "name": "splits", - "ext": ".zip", - "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip", # noqa: E501 - "md5": "e4ec4a18bc4efc828f0944a7cf4d5fed", + 'name': 'splits', + 'ext': '.zip', + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', # noqa: E501 + 'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed', }, { - "name": "meta.csv", - "ext": "", - "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv", # noqa: E501 - "md5": 
"43ea07974936a6bf47d989c32e16afe7", + 'name': 'meta.csv', + 'ext': '', + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', # noqa: E501 + 'md5': '43ea07974936a6bf47d989c32e16afe7', }, ] classes = [ - "Continuous urban fabric", - "Discontinuous urban fabric", - "Industrial or commercial units", - "Road and rail networks and associated land", - "Port areas", - "Airports", - "Mineral extraction sites", - "Dump sites", - "Construction sites", - "Green urban areas", - "Sport and leisure facilities", - "Non-irrigated arable land", - "Vineyards", - "Fruit trees and berry plantations", - "Pastures", - "Broad-leaved forest", - "Coniferous forest", - "Mixed forest", - "Natural grasslands", - "Moors and heathland", - "Transitional woodland/shrub", - "Beaches, dunes, sands", - "Bare rock", - "Sparsely vegetated areas", - "Inland marshes", - "Peat bogs", - "Salt marshes", - "Intertidal flats", - "Water courses", - "Water bodies", - "Coastal lagoons", - "Estuaries", - "Sea and ocean", + 'Continuous urban fabric', + 'Discontinuous urban fabric', + 'Industrial or commercial units', + 'Road and rail networks and associated land', + 'Port areas', + 'Airports', + 'Mineral extraction sites', + 'Dump sites', + 'Construction sites', + 'Green urban areas', + 'Sport and leisure facilities', + 'Non-irrigated arable land', + 'Vineyards', + 'Fruit trees and berry plantations', + 'Pastures', + 'Broad-leaved forest', + 'Coniferous forest', + 'Mixed forest', + 'Natural grasslands', + 'Moors and heathland', + 'Transitional woodland/shrub', + 'Beaches, dunes, sands', + 'Bare rock', + 'Sparsely vegetated areas', + 'Inland marshes', + 'Peat bogs', + 'Salt marshes', + 'Intertidal flats', + 'Water courses', + 'Water bodies', + 'Coastal lagoons', + 'Estuaries', + 'Sea and ocean', ] - all_seasons = {"Spring", "Summer", "Fall", "Winter", "Snow"} - all_bands = ("10m_RGB", "10m_IR", "20m", "60m") - band_nums = {"10m_RGB": 3, "10m_IR": 1, "20m": 6, "60m": 2} - splits = ["train", "val", "test"] + all_seasons = {'Spring', 'Summer', 'Fall', 'Winter', 'Snow'} + all_bands = ('10m_RGB', '10m_IR', '20m', '60m') + band_nums = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2} + splits = ['train', 'val', 'test'] cmap = { 0: (230, 000, 77, 255), 1: (255, 000, 000, 255), @@ -212,8 +212,8 @@ class SeasoNet(NonGeoDataset): def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', seasons: Collection[str] = all_seasons, bands: Iterable[str] = all_bands, grids: Iterable[int] = [1, 2], @@ -261,35 +261,35 @@ def __init__( for b in bands: self.channels += self.band_nums[b] - csv = pd.read_csv(os.path.join(self.root, "meta.csv"), index_col="Index") + csv = pd.read_csv(os.path.join(self.root, 'meta.csv'), index_col='Index') if split is not None: # Filter entries by split split_csv = pd.read_csv( - os.path.join(self.root, f"splits/{split}.csv"), header=None + os.path.join(self.root, f'splits/{split}.csv'), header=None )[0] csv = csv.iloc[split_csv] # Filter entries by grids and seasons - csv = csv[csv["Grid"].isin(grids)] - csv = csv[csv["Season"].isin(seasons)] + csv = csv[csv['Grid'].isin(grids)] + csv = csv[csv['Season'].isin(seasons)] # Replace relative data paths with absolute paths - csv["Path"] = csv["Path"].apply( + csv['Path'] = csv['Path'].apply( lambda p: [os.path.join(self.root, p, os.path.basename(p))] ) if self.concat_seasons > 1: # Group entries by location - self.files = csv.groupby(["Latitude", "Longitude"]) - self.files = 
self.files["Path"].agg("sum") + self.files = csv.groupby(['Latitude', 'Longitude']) + self.files = self.files['Path'].agg('sum') # Remove entries with less than concat_seasons available seasons self.files = self.files[ self.files.apply(lambda d: len(d) >= self.concat_seasons) ] else: - self.files = csv["Path"] + self.files = csv['Path'] def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -303,7 +303,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -334,10 +334,10 @@ def _load_image(self, index: int) -> Tensor: for img_idx, path in enumerate(paths): bnd_idx = 0 for band in self.bands: - with rasterio.open(f"{path}_{band}.tif") as f: + with rasterio.open(f'{path}_{band}.tif') as f: array = f.read( out_shape=[f.count] + list(self.image_size), - out_dtype="int32", + out_dtype='int32', resampling=Resampling.bilinear, ) image = torch.from_numpy(array).float() @@ -356,7 +356,7 @@ def _load_target(self, index: int) -> Tensor: the target mask """ path = self.files.iloc[index][0] - with rasterio.open(f"{path}_labels.tif") as f: + with rasterio.open(f'{path}_labels.tif') as f: array = f.read() - 1 tensor = torch.from_numpy(array).squeeze().long() return tensor @@ -365,7 +365,7 @@ def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if all files already exist if all( - os.path.exists(os.path.join(self.root, file_info["name"])) + os.path.exists(os.path.join(self.root, file_info['name'])) for file_info in self.metadata ): return @@ -374,10 +374,10 @@ def _verify(self) -> None: missing = [] extractable = [] for file_info in self.metadata: - file_path = os.path.join(self.root, file_info["name"] + file_info["ext"]) + file_path = os.path.join(self.root, file_info['name'] + file_info['ext']) if not os.path.exists(file_path): missing.append(file_info) - elif file_info["ext"] == ".zip": + elif file_info['ext'] == '.zip': extractable.append(file_path) # Check if the user requested to download the dataset @@ -387,13 +387,13 @@ def _verify(self) -> None: # Download missing files for file_info in missing: download_url( - file_info["url"], + file_info['url'], self.root, - filename=file_info["name"] + file_info["ext"], - md5=file_info["md5"] if self.checksum else None, + filename=file_info['name'] + file_info['ext'], + md5=file_info['md5'] if self.checksum else None, ) - if file_info["ext"] == ".zip": - extractable.append(os.path.join(self.root, file_info["name"] + ".zip")) + if file_info['ext'] == '.zip': + extractable.append(os.path.join(self.root, file_info['name'] + '.zip')) # Extract downloaded files for file_path in extractable: @@ -421,22 +421,22 @@ def plot( Raises: RGBBandsMissingError: If *bands* does not include all RGB bands. 
""" - if "10m_RGB" not in self.bands: + if '10m_RGB' not in self.bands: raise RGBBandsMissingError() ncols = self.concat_seasons + 1 - images, mask = sample["image"], sample["mask"] - show_predictions = "prediction" in sample + images, mask = sample['image'], sample['mask'] + show_predictions = 'prediction' in sample if show_predictions: - prediction = sample["prediction"] + prediction = sample['prediction'] ncols += 1 plt_cmap = ListedColormap(np.array(list(self.cmap.values())) / 255) start = 0 for b in self.bands: - if b == "10m_RGB": + if b == '10m_RGB': break start += self.band_nums[b] rgb_indices = [start + s * self.channels for s in range(self.concat_seasons)] @@ -447,22 +447,22 @@ def plot( image = images[index : index + 3].permute(1, 2, 0).numpy() image = percentile_normalization(image) axs[ax].imshow(image) - axs[ax].axis("off") + axs[ax].axis('off') if show_titles: - axs[ax].set_title(f"Image {ax+1}") + axs[ax].set_title(f'Image {ax+1}') - axs[ax + 1].imshow(mask, vmin=0, vmax=32, cmap=plt_cmap, interpolation="none") - axs[ax + 1].axis("off") + axs[ax + 1].imshow(mask, vmin=0, vmax=32, cmap=plt_cmap, interpolation='none') + axs[ax + 1].axis('off') if show_titles: - axs[ax + 1].set_title("Mask") + axs[ax + 1].set_title('Mask') if show_predictions: axs[ax + 2].imshow( - prediction, vmin=0, vmax=32, cmap=plt_cmap, interpolation="none" + prediction, vmin=0, vmax=32, cmap=plt_cmap, interpolation='none' ) - axs[ax + 2].axis("off") + axs[ax + 2].axis('off') if show_titles: - axs[ax + 2].set_title("Prediction") + axs[ax + 2].set_title('Prediction') if show_legend: lgd = np.unique(mask) @@ -477,6 +477,6 @@ def plot( ) if suptitle is not None: - plt.suptitle(suptitle, size="xx-large") + plt.suptitle(suptitle, size='xx-large') return fig diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index faab3503843..a423c281fa2 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -43,40 +43,40 @@ class SeasonalContrastS2(NonGeoDataset): """ all_bands = [ - "B1", - "B2", - "B3", - "B4", - "B5", - "B6", - "B7", - "B8", - "B8A", - "B9", - "B11", - "B12", + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8', + 'B8A', + 'B9', + 'B11', + 'B12', ] - rgb_bands = ["B4", "B3", "B2"] + rgb_bands = ['B4', 'B3', 'B2'] metadata = { - "100k": { - "url": "https://zenodo.org/record/4728033/files/seco_100k.zip?download=1", - "md5": "ebf2d5e03adc6e657f9a69a20ad863e0", - "filename": "seco_100k.zip", - "directory": "seasonal_contrast_100k", + '100k': { + 'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1', + 'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0', + 'filename': 'seco_100k.zip', + 'directory': 'seasonal_contrast_100k', }, - "1m": { - "url": "https://zenodo.org/record/4728033/files/seco_1m.zip?download=1", - "md5": "187963d852d4d3ce6637743ec3a4bd9e", - "filename": "seco_1m.zip", - "directory": "seasonal_contrast_1m", + '1m': { + 'url': 'https://zenodo.org/record/4728033/files/seco_1m.zip?download=1', + 'md5': '187963d852d4d3ce6637743ec3a4bd9e', + 'filename': 'seco_1m.zip', + 'directory': 'seasonal_contrast_1m', }, } def __init__( self, - root: str = "data", - version: str = "100k", + root: str = 'data', + version: str = '100k', seasons: int = 1, bands: list[str] = rgb_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -130,14 +130,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Image shape changed from 5xCxHxW to SCxHxW """ root = os.path.join( - self.root, self.metadata[self.version]["directory"], 
f"{index:06}" + self.root, self.metadata[self.version]['directory'], f'{index:06}' ) subdirs = [f for f in os.listdir(root) if os.path.isdir(os.path.join(root, f))] subdirs = random.sample(subdirs, self.seasons) images = [self._load_patch(root, subdir) for subdir in subdirs] - sample = {"image": torch.cat(images)} + sample = {'image': torch.cat(images)} if self.transforms is not None: sample = self.transforms(sample) @@ -150,7 +150,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - return (10**5 if self.version == "100k" else 10**6) // 5 + return (10**5 if self.version == '100k' else 10**6) // 5 def _load_patch(self, root: str, subdir: str) -> Tensor: """Load a single image patch. @@ -164,7 +164,7 @@ def _load_patch(self, root: str, subdir: str) -> Tensor: """ all_data = [] for band in self.bands: - fn = os.path.join(root, subdir, f"{band}.tif") + fn = os.path.join(root, subdir, f'{band}.tif') with rasterio.open(fn) as f: band_data = f.read(1).astype(np.float32) height, width = band_data.shape @@ -191,13 +191,13 @@ def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist directory_path = os.path.join( - self.root, self.metadata[self.version]["directory"] + self.root, self.metadata[self.version]['directory'] ) if os.path.exists(directory_path): return # Check if the zip files have already been downloaded - zip_path = os.path.join(self.root, self.metadata[self.version]["filename"]) + zip_path = os.path.join(self.root, self.metadata[self.version]['filename']) if os.path.exists(zip_path): self._extract() return @@ -213,16 +213,16 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" download_url( - self.metadata[self.version]["url"], + self.metadata[self.version]['url'], self.root, - filename=self.metadata[self.version]["filename"], - md5=self.metadata[self.version]["md5"] if self.checksum else None, + filename=self.metadata[self.version]['filename'], + md5=self.metadata[self.version]['md5'] if self.checksum else None, ) def _extract(self) -> None: """Extract the dataset.""" extract_archive( - os.path.join(self.root, self.metadata[self.version]["filename"]) + os.path.join(self.root, self.metadata[self.version]['filename']) ) def plot( @@ -247,7 +247,7 @@ def plot( .. 
versionadded:: 0.2 """ - if "prediction" in sample: + if 'prediction' in sample: raise ValueError("This dataset doesn't support plotting predictions") rgb_indices = [] @@ -263,12 +263,12 @@ def plot( indices = torch.tensor(rgb_indices) for i in range(self.seasons): - image = sample["image"][indices + i * len(self.bands)].numpy() + image = sample['image'][indices + i * len(self.bands)].numpy() image = np.rollaxis(image, 0, 3) image = percentile_normalization(image, 0, 100) axes[i].imshow(image) - axes[i].axis("off") + axes[i].axis('off') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 7ff44d4174f..031ced4256a 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -70,108 +70,108 @@ class SEN12MS(NonGeoDataset): """ # noqa: E501 BAND_SETS: dict[str, tuple[str, ...]] = { - "all": ( - "VV", - "VH", - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", + 'all': ( + 'VV', + 'VH', + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B10', + 'B11', + 'B12', ), - "s1": ("VV", "VH"), - "s2-all": ( - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", + 's1': ('VV', 'VH'), + 's2-all': ( + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B10', + 'B11', + 'B12', ), - "s2-reduced": ("B02", "B03", "B04", "B08", "B10", "B11"), + 's2-reduced': ('B02', 'B03', 'B04', 'B08', 'B10', 'B11'), } band_names = ( - "VV", - "VH", - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", + 'VV', + 'VH', + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B10', + 'B11', + 'B12', ) - rgb_bands = ["B04", "B03", "B02"] + rgb_bands = ['B04', 'B03', 'B02'] filenames = [ - "ROIs1158_spring_lc.tar.gz", - "ROIs1158_spring_s1.tar.gz", - "ROIs1158_spring_s2.tar.gz", - "ROIs1868_summer_lc.tar.gz", - "ROIs1868_summer_s1.tar.gz", - "ROIs1868_summer_s2.tar.gz", - "ROIs1970_fall_lc.tar.gz", - "ROIs1970_fall_s1.tar.gz", - "ROIs1970_fall_s2.tar.gz", - "ROIs2017_winter_lc.tar.gz", - "ROIs2017_winter_s1.tar.gz", - "ROIs2017_winter_s2.tar.gz", - "train_list.txt", - "test_list.txt", + 'ROIs1158_spring_lc.tar.gz', + 'ROIs1158_spring_s1.tar.gz', + 'ROIs1158_spring_s2.tar.gz', + 'ROIs1868_summer_lc.tar.gz', + 'ROIs1868_summer_s1.tar.gz', + 'ROIs1868_summer_s2.tar.gz', + 'ROIs1970_fall_lc.tar.gz', + 'ROIs1970_fall_s1.tar.gz', + 'ROIs1970_fall_s2.tar.gz', + 'ROIs2017_winter_lc.tar.gz', + 'ROIs2017_winter_s1.tar.gz', + 'ROIs2017_winter_s2.tar.gz', + 'train_list.txt', + 'test_list.txt', ] light_filenames = [ - "ROIs1158_spring", - "ROIs1868_summer", - "ROIs1970_fall", - "ROIs2017_winter", - "train_list.txt", - "test_list.txt", + 'ROIs1158_spring', + 'ROIs1868_summer', + 'ROIs1970_fall', + 'ROIs2017_winter', + 'train_list.txt', + 'test_list.txt', ] md5s = [ - "6e2e8fa8b8cba77ddab49fd20ff5c37b", - "fba019bb27a08c1db96b31f718c34d79", - "d58af2c15a16f376eb3308dc9b685af2", - "2c5bd80244440b6f9d54957c6b1f23d4", - "01044b7f58d33570c6b57fec28a3d449", - "4dbaf72ecb704a4794036fe691427ff3", - "9b126a68b0e3af260071b3139cb57cee", - "19132e0aab9d4d6862fd42e8e6760847", - "b8f117818878da86b5f5e06400eb1866", - "0fa0420ef7bcfe4387c7e6fe226dc728", - "bb8cbfc16b95a4f054a3d5380e0130ed", - "3807545661288dcca312c9c538537b63", - "0a68d4e1eb24f128fccdb930000b2546", - 
"c7faad064001e646445c4c634169484d", + '6e2e8fa8b8cba77ddab49fd20ff5c37b', + 'fba019bb27a08c1db96b31f718c34d79', + 'd58af2c15a16f376eb3308dc9b685af2', + '2c5bd80244440b6f9d54957c6b1f23d4', + '01044b7f58d33570c6b57fec28a3d449', + '4dbaf72ecb704a4794036fe691427ff3', + '9b126a68b0e3af260071b3139cb57cee', + '19132e0aab9d4d6862fd42e8e6760847', + 'b8f117818878da86b5f5e06400eb1866', + '0fa0420ef7bcfe4387c7e6fe226dc728', + 'bb8cbfc16b95a4f054a3d5380e0130ed', + '3807545661288dcca312c9c538537b63', + '0a68d4e1eb24f128fccdb930000b2546', + 'c7faad064001e646445c4c634169484d', ] def __init__( self, - root: str = "data", - split: str = "train", - bands: Sequence[str] = BAND_SETS["all"], + root: str = 'data', + split: str = 'train', + bands: Sequence[str] = BAND_SETS['all'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -195,7 +195,7 @@ def __init__( AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found. """ - assert split in ["train", "test"] + assert split in ['train', 'test'] self._validate_bands(bands) self.band_indices = torch.tensor( @@ -213,7 +213,7 @@ def __init__( ) or not self._check_integrity_light(): raise DatasetNotFoundError(self) - with open(os.path.join(self.root, split + "_list.txt")) as f: + with open(os.path.join(self.root, split + '_list.txt')) as f: self.ids = [line.rstrip() for line in f.readlines()] def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -227,14 +227,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ filename = self.ids[index] - lc = self._load_raster(filename, "lc").long() - s1 = self._load_raster(filename, "s1") - s2 = self._load_raster(filename, "s2") + lc = self._load_raster(filename, 'lc').long() + s1 = self._load_raster(filename, 's1') + s2 = self._load_raster(filename, 's2') image = torch.cat(tensors=[s1, s2], dim=0) image = torch.index_select(image, dim=0, index=self.band_indices) - sample: dict[str, Tensor] = {"image": image, "mask": lc[0]} + sample: dict[str, Tensor] = {'image': image, 'mask': lc[0]} if self.transforms is not None: sample = self.transforms(sample) @@ -259,15 +259,15 @@ def _load_raster(self, filename: str, source: str) -> Tensor: Returns: the raster image or target """ - parts = filename.split("_") + parts = filename.split('_') parts[2] = source with rasterio.open( os.path.join( self.root, - "{}_{}".format(*parts), - "{2}_{3}".format(*parts), - "{}_{}_{}_{}_{}".format(*parts), + '{}_{}'.format(*parts), + '{2}_{3}'.format(*parts), + '{}_{}_{}_{}_{}'.format(*parts), ) ) as f: array = f.read() @@ -343,31 +343,31 @@ def plot( else: raise RGBBandsMissingError() - image, mask = sample["image"][rgb_indices].numpy(), sample["mask"] + image, mask = sample['image'][rgb_indices].numpy(), sample['mask'] image = percentile_normalization(image) ncols = 2 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = sample["prediction"] + prediction = sample['prediction'] ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) axs[0].imshow(np.transpose(image, (1, 2, 0))) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask) - axs[1].axis("off") + axs[1].axis('off') if showing_predictions: axs[2].imshow(prediction) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if showing_predictions: - 
axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 8cd7989f8a5..68f64311335 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -116,7 +116,7 @@ class Sentinel1(Sentinel): # l: Entire Area (e) or Clipped Area (c) # m: Dead Reckoning (d) or DEM Matching (m) # ssss: Product ID - filename_glob = "S1*{}.*" + filename_glob = 'S1*{}.*' filename_regex = r""" ^S1(?P[AB]) _(?PSM|IW|EW|WV) @@ -135,16 +135,16 @@ class Sentinel1(Sentinel): _(?P[VH]{2}) \. """ - date_format = "%Y%m%dT%H%M%S" - all_bands = ["HH", "HV", "VV", "VH"] + date_format = '%Y%m%dT%H%M%S' + all_bands = ['HH', 'HV', 'VV', 'VH'] separate_files = True def __init__( self, - paths: str | list[str] = "data", + paths: str | list[str] = 'data', crs: CRS | None = None, res: float = 10, - bands: Sequence[str] = ["VV", "VH"], + bands: Sequence[str] = ['VV', 'VH'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, ) -> None: @@ -211,20 +211,20 @@ def plot( if len(bands) == 1: # Only horizontal or vertical receive, plot as grayscale - image = sample["image"][0] + image = sample['image'][0] image = torch.clamp(image, min=0, max=1) - title = f"({bands[0]})" + title = f'({bands[0]})' else: # Both horizontal and vertical receive, plot as RGB # Deal with reverse order - if bands in [["HV", "HH"], ["VH", "VV"]]: + if bands in [['HV', 'HH'], ['VH', 'VV']]: bands = bands[::-1] - sample["image"] = torch.flip(sample["image"], dims=[0]) + sample['image'] = torch.flip(sample['image'], dims=[0]) - co_polarization = sample["image"][0] # transmit == receive - cross_polarization = sample["image"][1] # transmit != receive + co_polarization = sample['image'][0] # transmit == receive + cross_polarization = sample['image'][1] # transmit != receive ratio = co_polarization / cross_polarization # https://gis.stackexchange.com/a/400780/123758 @@ -234,12 +234,12 @@ def plot( image = torch.stack((co_polarization, cross_polarization, ratio), dim=-1) - title = "({0}, {1}, {0}/{1})".format(*bands) + title = '({0}, {1}, {0}/{1})'.format(*bands) fig, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: ax.set_title(title) @@ -265,7 +265,7 @@ class Sentinel2(Sentinel): # https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-2-msi/naming-convention # https://sentinel.esa.int/documents/247904/685211/Sentinel-2-MSI-L2A-Product-Format-Specifications.pdf - filename_glob = "T*_*_{}*.*" + filename_glob = 'T*_*_{}*.*' filename_regex = r""" ^T(?P\d{{2}}[A-Z]{{3}}) _(?P\d{{8}}T\d{{6}}) @@ -273,31 +273,31 @@ class Sentinel2(Sentinel): (?:_(?P{}m))? 
\..*$ """ - date_format = "%Y%m%dT%H%M%S" + date_format = '%Y%m%dT%H%M%S' # https://gisgeography.com/sentinel-2-bands-combinations/ all_bands = [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B10', + 'B11', + 'B12', ] - rgb_bands = ["B04", "B03", "B02"] + rgb_bands = ['B04', 'B03', 'B02'] separate_files = True def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float = 10, bands: Sequence[str] | None = None, @@ -359,7 +359,7 @@ def plot( else: raise RGBBandsMissingError() - image = sample["image"][rgb_indices].permute(1, 2, 0) + image = sample['image'][rgb_indices].permute(1, 2, 0) # DN = 10000 * REFLECTANCE # https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/ image = torch.clamp(image / 10000, min=0, max=1) @@ -367,10 +367,10 @@ def plot( fig, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - ax.set_title("Image") + ax.set_title('Image') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 71ec3c4d662..a54c797acb4 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -55,26 +55,26 @@ class SKIPPD(NonGeoDataset): .. versionadded:: 0.5 """ - url = "https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}" # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' # noqa: E501 md5 = { - "forecast": "f4f3509ddcc83a55c433be9db2e51077", - "nowcast": "0000761d403e45bb5f86c21d3c69aa80", + 'forecast': 'f4f3509ddcc83a55c433be9db2e51077', + 'nowcast': '0000761d403e45bb5f86c21d3c69aa80', } - data_file_name = "2017_2019_images_pv_processed_{}.hdf5" - zipfile_name = "2017_2019_images_pv_processed_{}.zip" + data_file_name = '2017_2019_images_pv_processed_{}.hdf5' + zipfile_name = '2017_2019_images_pv_processed_{}.zip' - valid_splits = ["trainval", "test"] + valid_splits = ['trainval', 'test'] - valid_tasks = ["nowcast", "forecast"] + valid_tasks = ['nowcast', 'forecast'] - dateformat = "%m/%d/%Y, %H:%M:%S" + dateformat = '%m/%d/%Y, %H:%M:%S' def __init__( self, - root: str = "data", - split: str = "trainval", - task: str = "nowcast", + root: str = 'data', + split: str = 'trainval', + task: str = 'nowcast', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, @@ -97,12 +97,12 @@ def __init__( """ assert ( split in self.valid_splits - ), f"Please choose one of these valid data splits {self.valid_splits}." + ), f'Please choose one of these valid data splits {self.valid_splits}.' self.split = split assert ( task in self.valid_tasks - ), f"Please choose one of these valid tasks {self.valid_tasks}." + ), f'Please choose one of these valid tasks {self.valid_tasks}.' 
self.task = task self.root = root @@ -114,7 +114,7 @@ def __init__( import h5py # noqa: F401 except ImportError: raise ImportError( - "h5py is not installed and is required to use this dataset" + 'h5py is not installed and is required to use this dataset' ) self._verify() @@ -128,9 +128,9 @@ def __len__(self) -> int: import h5py with h5py.File( - os.path.join(self.root, self.data_file_name.format(self.task)), "r" + os.path.join(self.root, self.data_file_name.format(self.task)), 'r' ) as f: - num_datapoints: int = f[self.split]["pv_log"].shape[0] + num_datapoints: int = f[self.split]['pv_log'].shape[0] return num_datapoints @@ -143,7 +143,7 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]: Returns: data and label at that index """ - sample: dict[str, str | Tensor] = {"image": self._load_image(index)} + sample: dict[str, str | Tensor] = {'image': self._load_image(index)} sample.update(self._load_features(index)) if self.transforms is not None: @@ -163,16 +163,16 @@ def _load_image(self, index: int) -> Tensor: import h5py with h5py.File( - os.path.join(self.root, self.data_file_name.format(self.task)), "r" + os.path.join(self.root, self.data_file_name.format(self.task)), 'r' ) as f: - arr = f[self.split]["images_log"][index] + arr = f[self.split]['images_log'][index] # forecast has dimension [16, 64, 64, 3] but reshape to [48, 64, 64] # https://github.com/yuhao-nie/Stanford-solar-forecasting-dataset/blob/main/models/SUNSET_forecast.ipynb - if self.task == "forecast": - arr = rearrange(arr, "t h w c-> (t c) h w") + if self.task == 'forecast': + arr = rearrange(arr, 't h w c-> (t c) h w') else: - arr = rearrange(arr, "h w c -> c h w") + arr = rearrange(arr, 'h w c -> c h w') tensor = torch.from_numpy(arr).to(torch.float32) return tensor @@ -189,16 +189,16 @@ def _load_features(self, index: int) -> dict[str, str | Tensor]: import h5py with h5py.File( - os.path.join(self.root, self.data_file_name.format(self.task)), "r" + os.path.join(self.root, self.data_file_name.format(self.task)), 'r' ) as f: - label = f[self.split]["pv_log"][index] + label = f[self.split]['pv_log'][index] - path = os.path.join(self.root, f"times_{self.split}_{self.task}.npy") + path = os.path.join(self.root, f'times_{self.split}_{self.task}.npy') datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat) features: dict[str, str | Tensor] = { - "label": torch.tensor(label, dtype=torch.float32), - "date": datestring, + 'label': torch.tensor(label, dtype=torch.float32), + 'date': datestring, } return features @@ -256,27 +256,27 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - if self.task == "nowcast": - image, label = sample["image"].permute(1, 2, 0), sample["label"].item() + if self.task == 'nowcast': + image, label = sample['image'].permute(1, 2, 0), sample['label'].item() else: image, label = ( - sample["image"].permute(1, 2, 0).reshape(64, 64, 3, 16)[:, :, :, -1], - sample["label"][-1].item(), + sample['image'].permute(1, 2, 0).reshape(64, 64, 3, 16)[:, :, :, -1], + sample['label'][-1].item(), ) - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = sample["prediction"].item() + prediction = sample['prediction'].item() fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(image / 255) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label:.3f}" + title = f'Label: {label:.3f}' if showing_predictions: - title += f"\nPrediction: {prediction:.3f}" + title += 
f'\nPrediction: {prediction:.3f}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index dafbb886369..9a36c8e0786 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -103,99 +103,99 @@ class So2Sat(NonGeoDataset): or manually downloaded from https://mediatum.ub.tum.de/1613658 """ # noqa: E501 - versions = ["2", "3_random", "3_block", "3_culture_10"] + versions = ['2', '3_random', '3_block', '3_culture_10'] filenames_by_version = { - "2": { - "train": "training.h5", - "validation": "validation.h5", - "test": "testing.h5", + '2': { + 'train': 'training.h5', + 'validation': 'validation.h5', + 'test': 'testing.h5', }, - "3_random": {"train": "random/training.h5", "test": "random/testing.h5"}, - "3_block": {"train": "block/training.h5", "test": "block/testing.h5"}, - "3_culture_10": { - "train": "culture_10/training.h5", - "test": "culture_10/testing.h5", + '3_random': {'train': 'random/training.h5', 'test': 'random/testing.h5'}, + '3_block': {'train': 'block/training.h5', 'test': 'block/testing.h5'}, + '3_culture_10': { + 'train': 'culture_10/training.h5', + 'test': 'culture_10/testing.h5', }, } md5s_by_version = { - "2": { - "train": "702bc6a9368ebff4542d791e53469244", - "validation": "71cfa6795de3e22207229d06d6f8775d", - "test": "e81426102b488623a723beab52b31a8a", + '2': { + 'train': '702bc6a9368ebff4542d791e53469244', + 'validation': '71cfa6795de3e22207229d06d6f8775d', + 'test': 'e81426102b488623a723beab52b31a8a', }, - "3_random": { - "train": "94e2e2e667b406c2adf61e113b42204e", - "test": "1e15c425585ce816342d1cd779d453d8", + '3_random': { + 'train': '94e2e2e667b406c2adf61e113b42204e', + 'test': '1e15c425585ce816342d1cd779d453d8', }, - "3_block": { - "train": "a91d6150e8b059dac86105853f377a11", - "test": "6414af1ec33ace417e879f9c88066d47", + '3_block': { + 'train': 'a91d6150e8b059dac86105853f377a11', + 'test': '6414af1ec33ace417e879f9c88066d47', }, - "3_culture_10": { - "train": "702bc6a9368ebff4542d791e53469244", - "test": "58335ce34ca3a18424e19da84f2832fc", + '3_culture_10': { + 'train': '702bc6a9368ebff4542d791e53469244', + 'test': '58335ce34ca3a18424e19da84f2832fc', }, } classes = [ - "Compact high rise", - "Compact mid rise", - "Compact low rise", - "Open high rise", - "Open mid rise", - "Open low rise", - "Lightweight low rise", - "Large low rise", - "Sparsely built", - "Heavy industry", - "Dense trees", - "Scattered trees", - "Bush, scrub", - "Low plants", - "Bare rock or paved", - "Bare soil or sand", - "Water", + 'Compact high rise', + 'Compact mid rise', + 'Compact low rise', + 'Open high rise', + 'Open mid rise', + 'Open low rise', + 'Lightweight low rise', + 'Large low rise', + 'Sparsely built', + 'Heavy industry', + 'Dense trees', + 'Scattered trees', + 'Bush, scrub', + 'Low plants', + 'Bare rock or paved', + 'Bare soil or sand', + 'Water', ] all_s1_band_names = ( - "S1_B1", - "S1_B2", - "S1_B3", - "S1_B4", - "S1_B5", - "S1_B6", - "S1_B7", - "S1_B8", + 'S1_B1', + 'S1_B2', + 'S1_B3', + 'S1_B4', + 'S1_B5', + 'S1_B6', + 'S1_B7', + 'S1_B8', ) all_s2_band_names = ( - "S2_B02", - "S2_B03", - "S2_B04", - "S2_B05", - "S2_B06", - "S2_B07", - "S2_B08", - "S2_B8A", - "S2_B11", - "S2_B12", + 'S2_B02', + 'S2_B03', + 'S2_B04', + 'S2_B05', + 'S2_B06', + 'S2_B07', + 'S2_B08', + 'S2_B8A', + 'S2_B11', + 'S2_B12', ) all_band_names = all_s1_band_names + all_s2_band_names - rgb_bands = ["S2_B04", "S2_B03", "S2_B02"] + rgb_bands = ['S2_B04', 'S2_B03', 'S2_B02'] BAND_SETS = { - "all": all_band_names, - "s1": 
all_s1_band_names, - "s2": all_s2_band_names, - "rgb": rgb_bands, + 'all': all_band_names, + 's1': all_s1_band_names, + 's2': all_s2_band_names, + 'rgb': rgb_bands, } def __init__( self, - root: str = "data", - version: str = "2", - split: str = "train", - bands: Sequence[str] = BAND_SETS["all"], + root: str = 'data', + version: str = '2', + split: str = 'train', + bands: Sequence[str] = BAND_SETS['all'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -225,13 +225,13 @@ def __init__( import h5py # noqa: F401 except ImportError: raise ImportError( - "h5py is not installed and is required to use this dataset" + 'h5py is not installed and is required to use this dataset' ) assert version in self.versions assert split in self.filenames_by_version[version] self._validate_bands(bands) - self.s1_band_indices: "np.typing.NDArray[np.int_]" = np.array( + self.s1_band_indices: 'np.typing.NDArray[np.int_]' = np.array( [ self.all_s1_band_names.index(b) for b in bands @@ -241,7 +241,7 @@ def __init__( self.s1_band_names = [self.all_s1_band_names[i] for i in self.s1_band_indices] - self.s2_band_indices: "np.typing.NDArray[np.int_]" = np.array( + self.s2_band_indices: 'np.typing.NDArray[np.int_]' = np.array( [ self.all_s2_band_names.index(b) for b in bands @@ -264,8 +264,8 @@ def __init__( if not self._check_integrity(): raise DatasetNotFoundError(self) - with h5py.File(self.fn, "r") as f: - self.size: int = f["label"].shape[0] + with h5py.File(self.fn, 'r') as f: + self.size: int = f['label'].shape[0] def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -278,14 +278,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ import h5py - with h5py.File(self.fn, "r") as f: - s1 = f["sen1"][index].astype(np.float64) # convert from dict[str, Tensor]: s1 = torch.from_numpy(s1) s2 = torch.from_numpy(s2) - sample = {"image": torch.cat([s1, s2]).float(), "label": label} + sample = {'image': torch.cat([s1, s2]).float(), 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -365,25 +365,25 @@ def plot( else: raise RGBBandsMissingError() - image = np.take(sample["image"].numpy(), indices=rgb_indices, axis=0) + image = np.take(sample['image'].numpy(), indices=rgb_indices, axis=0) image = np.rollaxis(image, 0, 3) image = percentile_normalization(image, 0, 100) - label = cast(int, sample["label"].item()) + label = cast(int, sample['label'].item()) label_class = self.classes[label] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(int, sample["prediction"].item()) + prediction = cast(int, sample['prediction'].item()) prediction_class = self.classes[prediction] fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label_class}" + title = f'Label: {label_class}' if showing_predictions: - title += f"\nPrediction: {prediction_class}" + title += f'\nPrediction: {prediction_class}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index c98c4803fb1..b5e0cf8d41e 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -65,22 +65,22 @@ class SouthAfricaCropType(RasterDataset): _(?P[0-9]{4}_[0-9]{2}_[0-9]{2}) _(?P(B[0-9A-Z]{2} | VH | VV)) _10m""" - date_format = "%Y_%m_%d" - 
rgb_bands = ["B04", "B03", "B02"] - s1_bands = ["VH", "VV"] + date_format = '%Y_%m_%d' + rgb_bands = ['B04', 'B03', 'B02'] + s1_bands = ['VH', 'VV'] s2_bands = [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12", + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', ] all_bands: list[str] = s1_bands + s2_bands cmap = { @@ -98,7 +98,7 @@ class SouthAfricaCropType(RasterDataset): def __init__( self, - paths: str | Iterable[str] = "data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: list[str] = all_bands, @@ -119,8 +119,8 @@ def __init__( """ assert ( set(classes) <= self.cmap.keys() - ), f"Only the following classes are valid: {list(self.cmap.keys())}." - assert 0 in classes, "Classes must include the background class: 0" + ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths self.classes = classes @@ -151,7 +151,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: if not filepaths: raise IndexError( - f"query: {query} not found in index with bounds: {self.bounds}" + f'query: {query} not found in index with bounds: {self.bounds}' ) data_list: list[Tensor] = [] @@ -166,33 +166,33 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: filename = os.path.basename(filepath) match = re.match(filename_regex, filename) if match: - field_id = match.group("field_id") - date = match.group("date") - band = match.group("band") - band_type = "s1" if band in self.s1_bands else "s2" + field_id = match.group('field_id') + date = match.group('date') + band = match.group('band') + band_type = 's1' if band in self.s1_bands else 's2' if field_id not in field_ids: field_ids.append(field_id) - imagery_dates[field_id] = {"s1": "", "s2": ""} + imagery_dates[field_id] = {'s1': '', 's2': ''} if ( - date.split("_")[1] == "07" + date.split('_')[1] == '07' and not imagery_dates[field_id][band_type] ): imagery_dates[field_id][band_type] = date # Create Tensors for each band using stored dates for band in self.bands: - band_type = "s1" if band in self.s1_bands else "s2" + band_type = 's1' if band in self.s1_bands else 's2' band_filepaths = [] for field_id in field_ids: date = imagery_dates[field_id][band_type] filepath = os.path.join( self.paths, - "train", - "imagery", + 'train', + 'imagery', band_type, field_id, date, - f"{field_id}_{date}_{band}_10m.tif", + f'{field_id}_{date}_{band}_10m.tif', ) band_filepaths.append(filepath) data_list.append(self._merge_files(band_filepaths, query)) @@ -202,17 +202,17 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask_filepaths: list[str] = [] for field_id in field_ids: file_path = filepath = os.path.join( - self.paths, "train", "labels", f"{field_id}.tif" + self.paths, 'train', 'labels', f'{field_id}.tif' ) mask_filepaths.append(file_path) mask = self._merge_files(mask_filepaths, query) sample = { - "crs": self.crs, - "bbox": query, - "image": image.float(), - "mask": mask.long(), + 'crs': self.crs, + 'bbox': query, + 'image': image.float(), + 'mask': mask.long(), } if self.transforms is not None: @@ -246,31 +246,31 @@ def plot( else: raise RGBBandsMissingError() - image = sample["image"][rgb_indices].permute(1, 2, 0) + image = sample['image'][rgb_indices].permute(1, 2, 0) image = (image - image.min()) / (image.max() - image.min()) - mask = sample["mask"].squeeze() + 
mask = sample['mask'].squeeze() ncols = 2 - showing_prediction = "prediction" in sample + showing_prediction = 'prediction' in sample if showing_prediction: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4)) axs[0].imshow(image) - axs[0].axis("off") - axs[1].imshow(self.ordinal_cmap[mask], interpolation="none") - axs[1].axis("off") + axs[0].axis('off') + axs[1].imshow(self.ordinal_cmap[mask], interpolation='none') + axs[1].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') if showing_prediction: axs[2].imshow(pred) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index 131841d775b..bde34deebb7 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -38,40 +38,40 @@ class SouthAmericaSoybean(RasterDataset): .. versionadded:: 0.6 """ - filename_glob = "South_America_Soybean_*.*" - filename_regex = r"South_America_Soybean_(?P\d{4})" + filename_glob = 'South_America_Soybean_*.*' + filename_regex = r'South_America_Soybean_(?P\d{4})' - date_format = "%Y" + date_format = '%Y' is_image = False - url = "https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif" + url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif' md5s = { - 2021: "edff3ada13a1a9910d1fe844d28ae4f", - 2020: "0709dec807f576c9707c8c7e183db31", - 2019: "441836493bbcd5e123cff579a58f5a4f", - 2018: "503c2d0a803c2a2629ebbbd9558a3013", - 2017: "4d0487ac1105d171e5f506f1766ea777", - 2016: "770c558f6ac40550d0e264da5e44b3e", - 2015: "6beb96a61fe0e9ce8c06263e500dde8f", - 2014: "824ff91c62a4ba9f4ccfd281729830e5", - 2013: "0263e19b3cae6fdaba4e3b450cef985e", - 2012: "9f3a71097c9836fcff18a13b9ba608b2", - 2011: "b73352ebea3d5658959e9044ec526143", - 2010: "9264532d36ffa93493735a6e44caef0d", - 2009: "341387c1bb42a15140c80702e4cca02d", - 2008: "96fc3f737ab3ce9bcd16cbf7761427e2", - 2007: "bb8549b6674163fe20ffd47ec4ce8903", - 2006: "eabaa525414ecbff89301d3d5c706f0b", - 2005: "89faae27f9b5afbd06935a465e5fe414", - 2004: "f9882ca9c70e054e50172835cb75a8c3", - 2003: "cad5ed461ff4ab45c90177841aaecad2", - 2002: "8a4a9dcea54b3ec7de07657b9f2c0893", - 2001: "2914b0af7590a0ca4dfa9ccefc99020f", + 2021: 'edff3ada13a1a9910d1fe844d28ae4f', + 2020: '0709dec807f576c9707c8c7e183db31', + 2019: '441836493bbcd5e123cff579a58f5a4f', + 2018: '503c2d0a803c2a2629ebbbd9558a3013', + 2017: '4d0487ac1105d171e5f506f1766ea777', + 2016: '770c558f6ac40550d0e264da5e44b3e', + 2015: '6beb96a61fe0e9ce8c06263e500dde8f', + 2014: '824ff91c62a4ba9f4ccfd281729830e5', + 2013: '0263e19b3cae6fdaba4e3b450cef985e', + 2012: '9f3a71097c9836fcff18a13b9ba608b2', + 2011: 'b73352ebea3d5658959e9044ec526143', + 2010: '9264532d36ffa93493735a6e44caef0d', + 2009: '341387c1bb42a15140c80702e4cca02d', + 2008: '96fc3f737ab3ce9bcd16cbf7761427e2', + 2007: 'bb8549b6674163fe20ffd47ec4ce8903', + 2006: 'eabaa525414ecbff89301d3d5c706f0b', + 2005: '89faae27f9b5afbd06935a465e5fe414', + 2004: 'f9882ca9c70e054e50172835cb75a8c3', + 2003: 'cad5ed461ff4ab45c90177841aaecad2', + 2002: '8a4a9dcea54b3ec7de07657b9f2c0893', + 2001: '2914b0af7590a0ca4dfa9ccefc99020f', } def __init__( self, - paths: str | Iterable[str] = 
"data", + paths: str | Iterable[str] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2021], @@ -145,29 +145,29 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - mask = sample["mask"].squeeze() + mask = sample['mask'].squeeze() ncols = 1 - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze() + pred = sample['prediction'].squeeze() ncols = 2 fig, axs = plt.subplots( nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False ) - axs[0, 0].imshow(mask, interpolation="none") - axs[0, 0].axis("off") + axs[0, 0].imshow(mask, interpolation='none') + axs[0, 0].axis('off') if show_titles: - axs[0, 0].set_title("Mask") + axs[0, 0].set_title('Mask') if showing_predictions: - axs[0, 1].imshow(pred, interpolation="none") - axs[0, 1].axis("off") + axs[0, 1].imshow(pred, interpolation='none') + axs[0, 1].axis('off') if show_titles: - axs[0, 1].set_title("Prediction") + axs[0, 1].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index c838b96cece..f4fd3d2b8a7 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -135,13 +135,13 @@ def _load_files(self, root: str) -> list[dict[str, str]]: """ files = [] for collection in self.collections: - images = glob.glob(os.path.join(root, collection, "*", self.filename)) + images = glob.glob(os.path.join(root, collection, '*', self.filename)) images = sorted(images) for imgpath in images: lbl_path = os.path.join( - f"{os.path.dirname(imgpath)}-labels", self.label_glob + f'{os.path.dirname(imgpath)}-labels', self.label_glob ) - files.append({"image_path": imgpath, "label_path": lbl_path}) + files.append({'image_path': imgpath, 'label_path': lbl_path}) return files def _load_image(self, path: str) -> tuple[Tensor, Affine, CRS]: @@ -177,13 +177,13 @@ def _load_mask( with fiona.open(path) as src: vector_crs = CRS(src.crs) if raster_crs == vector_crs: - labels = [feature["geometry"] for feature in src] + labels = [feature['geometry'] for feature in src] else: labels = [ transform_geom( vector_crs.to_string(), raster_crs.to_string(), - feature["geometry"], + feature['geometry'], ) for feature in src ] @@ -224,12 +224,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - img, tfm, raster_crs = self._load_image(files["image_path"]) + img, tfm, raster_crs = self._load_image(files['image_path']) h, w = img.shape[1:] - mask = self._load_mask(files["label_path"], tfm, raster_crs, (h, w)) + mask = self._load_mask(files['label_path'], tfm, raster_crs, (h, w)) ch, cw = self.chip_size[self.image] - sample = {"image": img[:, :ch, :cw], "mask": mask[:ch, :cw]} + sample = {'image': img[:, :ch, :cw], 'mask': mask[:ch, :cw]} if self.transforms is not None: sample = self.transforms(sample) @@ -245,7 +245,7 @@ def _check_integrity(self) -> list[str]: # Check if collections exist missing_collections = [] for collection in self.collections: - stacpath = os.path.join(self.root, collection, "collection.json") + stacpath = os.path.join(self.root, collection, 'collection.json') if not os.path.exists(stacpath): missing_collections.append(collection) @@ -255,9 +255,9 @@ def _check_integrity(self) -> list[str]: to_be_downloaded = [] for collection in missing_collections: - archive_path = os.path.join(self.root, f"{collection}.tar.gz") + archive_path = 
os.path.join(self.root, f'{collection}.tar.gz') if os.path.exists(archive_path): - print(f"Found {collection} archive") + print(f'Found {collection} archive') if ( self.checksum and check_integrity( @@ -265,13 +265,13 @@ def _check_integrity(self) -> list[str]: ) or not self.checksum ): - print("Extracting...") + print('Extracting...') extract_archive(archive_path) else: - print(f"Collection {collection} is corrupted") + print(f'Collection {collection} is corrupted') to_be_downloaded.append(collection) else: - print(f"{collection} not found") + print(f'{collection} not found') to_be_downloaded.append(collection) return to_be_downloaded @@ -285,16 +285,16 @@ def _download(self, collections: list[str], api_key: str | None = None) -> None: """ for collection in collections: download_radiant_mlhub_collection(collection, self.root, api_key) - archive_path = os.path.join(self.root, f"{collection}.tar.gz") + archive_path = os.path.join(self.root, f'{collection}.tar.gz') if ( not self.checksum or not check_integrity( archive_path, self.collection_md5_dict[collection] ) ) and self.checksum: - raise RuntimeError(f"Collection {collection} corrupted") + raise RuntimeError(f'Collection {collection} corrupted') - print("Extracting...") + print('Extracting...') extract_archive(archive_path) def plot( @@ -316,43 +316,43 @@ def plot( .. versionadded:: 0.2 """ # image can be 1 channel or >3 channels - if sample["image"].shape[0] == 1: - image = np.rollaxis(sample["image"].numpy(), 0, 3) + if sample['image'].shape[0] == 1: + image = np.rollaxis(sample['image'].numpy(), 0, 3) else: - image = np.rollaxis(sample["image"][:3].numpy(), 0, 3) + image = np.rollaxis(sample['image'][:3].numpy(), 0, 3) image = percentile_normalization(image, axis=(0, 1)) ncols = 1 - show_mask = "mask" in sample - show_predictions = "prediction" in sample + show_mask = 'mask' in sample + show_predictions = 'prediction' in sample if show_mask: - mask = sample["mask"].numpy() + mask = sample['mask'].numpy() ncols += 1 if show_predictions: - prediction = sample["prediction"].numpy() + prediction = sample['prediction'].numpy() ncols += 1 fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) if not isinstance(axs, np.ndarray): axs = [axs] axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') if show_titles: - axs[0].set_title("Image") + axs[0].set_title('Image') if show_mask: - axs[1].imshow(mask, interpolation="none") - axs[1].axis("off") + axs[1].imshow(mask, interpolation='none') + axs[1].axis('off') if show_titles: - axs[1].set_title("Label") + axs[1].set_title('Label') if show_predictions: - axs[2].imshow(prediction, interpolation="none") - axs[2].axis("off") + axs[2].imshow(prediction, interpolation='none') + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) @@ -390,16 +390,16 @@ class SpaceNet1(SpaceNet): """ - dataset_id = "spacenet1" - imagery = {"rgb": "RGB.tif", "8band": "8Band.tif"} - chip_size = {"rgb": (406, 438), "8band": (101, 110)} - label_glob = "labels.geojson" - collection_md5_dict = {"sn1_AOI_1_RIO": "e6ea35331636fa0c036c04b3d1cbf226"} + dataset_id = 'spacenet1' + imagery = {'rgb': 'RGB.tif', '8band': '8Band.tif'} + chip_size = {'rgb': (406, 438), '8band': (101, 110)} + label_glob = 'labels.geojson' + collection_md5_dict = {'sn1_AOI_1_RIO': 'e6ea35331636fa0c036c04b3d1cbf226'} def __init__( self, - root: str = "data", - image: str = "rgb", + root: str = 'data', + image: str = 'rgb', transforms: 
Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, api_key: str | None = None, @@ -419,8 +419,8 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ - collections = ["sn1_AOI_1_RIO"] - assert image in {"rgb", "8band"} + collections = ['sn1_AOI_1_RIO'] + assert image in {'rgb', '8band'} super().__init__( root, image, collections, transforms, download, api_key, checksum ) @@ -489,32 +489,32 @@ class SpaceNet2(SpaceNet): """ - dataset_id = "spacenet2" + dataset_id = 'spacenet2' collection_md5_dict = { - "sn2_AOI_2_Vegas": "a5a8de355290783b88ac4d69c7ef0694", - "sn2_AOI_3_Paris": "8299186b7bbfb9a256d515bad1b7f146", - "sn2_AOI_4_Shanghai": "4e3e80f2f437faca10ca2e6e6df0ef99", - "sn2_AOI_5_Khartoum": "8070ff9050f94cd9f0efe9417205d7c3", + 'sn2_AOI_2_Vegas': 'a5a8de355290783b88ac4d69c7ef0694', + 'sn2_AOI_3_Paris': '8299186b7bbfb9a256d515bad1b7f146', + 'sn2_AOI_4_Shanghai': '4e3e80f2f437faca10ca2e6e6df0ef99', + 'sn2_AOI_5_Khartoum': '8070ff9050f94cd9f0efe9417205d7c3', } imagery = { - "MS": "MS.tif", - "PAN": "PAN.tif", - "PS-MS": "PS-MS.tif", - "PS-RGB": "PS-RGB.tif", + 'MS': 'MS.tif', + 'PAN': 'PAN.tif', + 'PS-MS': 'PS-MS.tif', + 'PS-RGB': 'PS-RGB.tif', } chip_size = { - "MS": (162, 162), - "PAN": (650, 650), - "PS-MS": (650, 650), - "PS-RGB": (650, 650), + 'MS': (162, 162), + 'PAN': (650, 650), + 'PS-MS': (650, 650), + 'PS-RGB': (650, 650), } - label_glob = "label.geojson" + label_glob = 'label.geojson' def __init__( self, - root: str = "data", - image: str = "PS-RGB", + root: str = 'data', + image: str = 'PS-RGB', collections: list[str] = [], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, @@ -539,7 +539,7 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} + assert image in {'MS', 'PAN', 'PS-MS', 'PS-RGB'} super().__init__( root, image, collections, transforms, download, api_key, checksum ) @@ -609,32 +609,32 @@ class SpaceNet3(SpaceNet): .. versionadded:: 0.3 """ - dataset_id = "spacenet3" + dataset_id = 'spacenet3' collection_md5_dict = { - "sn3_AOI_2_Vegas": "8ce7e6abffb8849eb88885035f061ee8", - "sn3_AOI_3_Paris": "90b9ebd64cd83dc8d3d4773f45050d8f", - "sn3_AOI_4_Shanghai": "3ea291df34548962dfba8b5ed37d700c", - "sn3_AOI_5_Khartoum": "b8d549ac9a6d7456c0f7a8e6de23d9f9", + 'sn3_AOI_2_Vegas': '8ce7e6abffb8849eb88885035f061ee8', + 'sn3_AOI_3_Paris': '90b9ebd64cd83dc8d3d4773f45050d8f', + 'sn3_AOI_4_Shanghai': '3ea291df34548962dfba8b5ed37d700c', + 'sn3_AOI_5_Khartoum': 'b8d549ac9a6d7456c0f7a8e6de23d9f9', } imagery = { - "MS": "MS.tif", - "PAN": "PAN.tif", - "PS-MS": "PS-MS.tif", - "PS-RGB": "PS-RGB.tif", + 'MS': 'MS.tif', + 'PAN': 'PAN.tif', + 'PS-MS': 'PS-MS.tif', + 'PS-RGB': 'PS-RGB.tif', } chip_size = { - "MS": (325, 325), - "PAN": (1300, 1300), - "PS-MS": (1300, 1300), - "PS-RGB": (1300, 1300), + 'MS': (325, 325), + 'PAN': (1300, 1300), + 'PS-MS': (1300, 1300), + 'PS-RGB': (1300, 1300), } - label_glob = "labels.geojson" + label_glob = 'labels.geojson' def __init__( self, - root: str = "data", - image: str = "PS-RGB", + root: str = 'data', + image: str = 'PS-RGB', speed_mask: bool | None = False, collections: list[str] = [], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -662,7 +662,7 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. 
""" - assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} + assert image in {'MS', 'PAN', 'PS-MS', 'PS-RGB'} self.speed_mask = speed_mask super().__init__( root, image, collections, transforms, download, api_key, checksum @@ -686,7 +686,7 @@ def _load_mask( max_speed_bin = 65 speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1) bin_size_mph = 10.0 - speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array( + speed_cls_arr: 'np.typing.NDArray[np.int_]' = np.array( [math.ceil(s / bin_size_mph) for s in speed_arr_bin] ) @@ -700,14 +700,14 @@ def _load_mask( geom = transform_geom( vector_crs.to_string(), raster_crs.to_string(), - feature["geometry"], + feature['geometry'], ) else: - geom = feature["geometry"] + geom = feature['geometry'] if self.speed_mask: val = speed_cls_arr[ - int(feature["properties"]["inferred_speed_mph"]) - 1 + int(feature['properties']['inferred_speed_mph']) - 1 ] else: val = 1 @@ -750,55 +750,55 @@ def plot( """ # image can be 1 channel or >3 channels - if sample["image"].shape[0] == 1: - image = np.rollaxis(sample["image"].numpy(), 0, 3) + if sample['image'].shape[0] == 1: + image = np.rollaxis(sample['image'].numpy(), 0, 3) else: - image = np.rollaxis(sample["image"][:3].numpy(), 0, 3) + image = np.rollaxis(sample['image'][:3].numpy(), 0, 3) image = percentile_normalization(image, axis=(0, 1)) ncols = 1 - show_mask = "mask" in sample - show_predictions = "prediction" in sample + show_mask = 'mask' in sample + show_predictions = 'prediction' in sample if show_mask: - mask = sample["mask"].numpy() + mask = sample['mask'].numpy() ncols += 1 if show_predictions: - prediction = sample["prediction"].numpy() + prediction = sample['prediction'].numpy() ncols += 1 fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) if not isinstance(axs, np.ndarray): axs = [axs] axs[0].imshow(image) - axs[0].axis("off") + axs[0].axis('off') if show_titles: - axs[0].set_title("Image") + axs[0].set_title('Image') if show_mask: if self.speed_mask: - cmap = copy.copy(plt.get_cmap("autumn_r")) - cmap.set_under(color="black") - axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none") + cmap = copy.copy(plt.get_cmap('autumn_r')) + cmap.set_under(color='black') + axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation='none') else: - axs[1].imshow(mask, cmap="Greys_r", interpolation="none") - axs[1].axis("off") + axs[1].imshow(mask, cmap='Greys_r', interpolation='none') + axs[1].axis('off') if show_titles: - axs[1].set_title("Label") + axs[1].set_title('Label') if show_predictions: if self.speed_mask: - cmap = copy.copy(plt.get_cmap("autumn_r")) - cmap.set_under(color="black") + cmap = copy.copy(plt.get_cmap('autumn_r')) + cmap.set_under(color='black') axs[2].imshow( - prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none" + prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation='none' ) else: - axs[2].imshow(prediction, cmap="Greys_r", interpolation="none") - axs[2].axis("off") + axs[2].imshow(prediction, cmap='Greys_r', interpolation='none') + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) @@ -839,53 +839,53 @@ class SpaceNet4(SpaceNet): """ - dataset_id = "spacenet4" - collection_md5_dict = {"sn4_AOI_6_Atlanta": "c597d639cba5257927a97e3eff07b753"} + dataset_id = 'spacenet4' + collection_md5_dict = {'sn4_AOI_6_Atlanta': 'c597d639cba5257927a97e3eff07b753'} - imagery = {"MS": "MS.tif", "PAN": "PAN.tif", "PS-RGBNIR": "PS-RGBNIR.tif"} - chip_size = {"MS": (225, 
225), "PAN": (900, 900), "PS-RGBNIR": (900, 900)} - label_glob = "labels.geojson" + imagery = {'MS': 'MS.tif', 'PAN': 'PAN.tif', 'PS-RGBNIR': 'PS-RGBNIR.tif'} + chip_size = {'MS': (225, 225), 'PAN': (900, 900), 'PS-RGBNIR': (900, 900)} + label_glob = 'labels.geojson' angle_catalog_map = { - "nadir": [ - "1030010003D22F00", - "10300100023BC100", - "1030010003993E00", - "1030010003CAF100", - "1030010002B7D800", - "10300100039AB000", - "1030010002649200", - "1030010003C92000", - "1030010003127500", - "103001000352C200", - "103001000307D800", + 'nadir': [ + '1030010003D22F00', + '10300100023BC100', + '1030010003993E00', + '1030010003CAF100', + '1030010002B7D800', + '10300100039AB000', + '1030010002649200', + '1030010003C92000', + '1030010003127500', + '103001000352C200', + '103001000307D800', ], - "off-nadir": [ - "1030010003472200", - "1030010003315300", - "10300100036D5200", - "103001000392F600", - "1030010003697400", - "1030010003895500", - "1030010003832800", + 'off-nadir': [ + '1030010003472200', + '1030010003315300', + '10300100036D5200', + '103001000392F600', + '1030010003697400', + '1030010003895500', + '1030010003832800', ], - "very-off-nadir": [ - "10300100035D1B00", - "1030010003CCD700", - "1030010003713C00", - "10300100033C5200", - "1030010003492700", - "10300100039E6200", - "1030010003BDDC00", - "1030010003CD4300", - "1030010003193D00", + 'very-off-nadir': [ + '10300100035D1B00', + '1030010003CCD700', + '1030010003713C00', + '10300100033C5200', + '1030010003492700', + '10300100039E6200', + '1030010003BDDC00', + '1030010003CD4300', + '1030010003193D00', ], } def __init__( self, - root: str = "data", - image: str = "PS-RGBNIR", + root: str = 'data', + image: str = 'PS-RGBNIR', angles: list[str] = [], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, @@ -908,8 +908,8 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. 
""" - collections = ["sn4_AOI_6_Atlanta"] - assert image in {"MS", "PAN", "PS-RGBNIR"} + collections = ['sn4_AOI_6_Atlanta'] + assert image in {'MS', 'PAN', 'PS-RGBNIR'} self.angles = angles if self.angles: for angle in self.angles: @@ -931,33 +931,33 @@ def _load_files(self, root: str) -> list[dict[str, str]]: nadir = [] offnadir = [] veryoffnadir = [] - images = glob.glob(os.path.join(root, self.collections[0], "*", self.filename)) + images = glob.glob(os.path.join(root, self.collections[0], '*', self.filename)) images = sorted(images) - catalog_id_pattern = re.compile(r"(_[A-Z0-9])\w+$") + catalog_id_pattern = re.compile(r'(_[A-Z0-9])\w+$') for imgpath in images: imgdir = os.path.basename(os.path.dirname(imgpath)) match = catalog_id_pattern.search(imgdir) - assert match is not None, "Invalid image directory" + assert match is not None, 'Invalid image directory' catalog_id = match.group()[1:] - lbl_dir = os.path.dirname(imgpath).split("-nadir")[0] + lbl_dir = os.path.dirname(imgpath).split('-nadir')[0] - lbl_path = os.path.join(f"{lbl_dir}-labels", self.label_glob) + lbl_path = os.path.join(f'{lbl_dir}-labels', self.label_glob) assert os.path.exists(lbl_path) - _file = {"image_path": imgpath, "label_path": lbl_path} - if catalog_id in self.angle_catalog_map["very-off-nadir"]: + _file = {'image_path': imgpath, 'label_path': lbl_path} + if catalog_id in self.angle_catalog_map['very-off-nadir']: veryoffnadir.append(_file) - elif catalog_id in self.angle_catalog_map["off-nadir"]: + elif catalog_id in self.angle_catalog_map['off-nadir']: offnadir.append(_file) - elif catalog_id in self.angle_catalog_map["nadir"]: + elif catalog_id in self.angle_catalog_map['nadir']: nadir.append(_file) angle_file_map = { - "nadir": nadir, - "off-nadir": offnadir, - "very-off-nadir": veryoffnadir, + 'nadir': nadir, + 'off-nadir': offnadir, + 'very-off-nadir': veryoffnadir, } if not self.angles: @@ -1030,30 +1030,30 @@ class SpaceNet5(SpaceNet3): .. versionadded:: 0.2 """ - dataset_id = "spacenet5" + dataset_id = 'spacenet5' collection_md5_dict = { - "sn5_AOI_7_Moscow": "b18107f878152fe7e75444373c320cba", - "sn5_AOI_8_Mumbai": "1f1e2b3c26fbd15bfbcdbb6b02ae051c", + 'sn5_AOI_7_Moscow': 'b18107f878152fe7e75444373c320cba', + 'sn5_AOI_8_Mumbai': '1f1e2b3c26fbd15bfbcdbb6b02ae051c', } imagery = { - "MS": "MS.tif", - "PAN": "PAN.tif", - "PS-MS": "PS-MS.tif", - "PS-RGB": "PS-RGB.tif", + 'MS': 'MS.tif', + 'PAN': 'PAN.tif', + 'PS-MS': 'PS-MS.tif', + 'PS-RGB': 'PS-RGB.tif', } chip_size = { - "MS": (325, 325), - "PAN": (1300, 1300), - "PS-MS": (1300, 1300), - "PS-RGB": (1300, 1300), + 'MS': (325, 325), + 'PAN': (1300, 1300), + 'PS-MS': (1300, 1300), + 'PS-RGB': (1300, 1300), } - label_glob = "labels.geojson" + label_glob = 'labels.geojson' def __init__( self, - root: str = "data", - image: str = "PS-RGB", + root: str = 'data', + image: str = 'PS-RGB', speed_mask: bool | None = False, collections: list[str] = [], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -1162,30 +1162,30 @@ class SpaceNet6(SpaceNet): .. 
versionadded:: 0.4 """ - dataset_id = "spacenet6" - collections = ["sn6_AOI_11_Rotterdam"] + dataset_id = 'spacenet6' + collections = ['sn6_AOI_11_Rotterdam'] # This is actually the metadata hash - collection_md5_dict = {"sn6_AOI_11_Rotterdam": "66f7312218fec67a1e0b3b02b22c95cc"} + collection_md5_dict = {'sn6_AOI_11_Rotterdam': '66f7312218fec67a1e0b3b02b22c95cc'} imagery = { - "PAN": "PAN.tif", - "RGBNIR": "RGBNIR.tif", - "PS-RGB": "PS-RGB.tif", - "PS-RGBNIR": "PS-RGBNIR.tif", - "SAR-Intensity": "SAR-Intensity.tif", + 'PAN': 'PAN.tif', + 'RGBNIR': 'RGBNIR.tif', + 'PS-RGB': 'PS-RGB.tif', + 'PS-RGBNIR': 'PS-RGBNIR.tif', + 'SAR-Intensity': 'SAR-Intensity.tif', } chip_size = { - "PAN": (900, 900), - "RGBNIR": (450, 450), - "PS-RGB": (900, 900), - "PS-RGBNIR": (900, 900), - "SAR-Intensity": (900, 900), + 'PAN': (900, 900), + 'RGBNIR': (450, 450), + 'PS-RGB': (900, 900), + 'PS-RGBNIR': (900, 900), + 'SAR-Intensity': (900, 900), } - label_glob = "labels.geojson" + label_glob = 'labels.geojson' def __init__( self, - root: str = "data", - image: str = "PS-RGB", + root: str = 'data', + image: str = 'PS-RGB', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, api_key: str | None = None, @@ -1223,10 +1223,10 @@ def __download(self, api_key: str | None = None) -> None: """ if os.path.exists( os.path.join( - self.root, self.dataset_id, self.collections[0], "collection.json" + self.root, self.dataset_id, self.collections[0], 'collection.json' ) ): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key) @@ -1268,22 +1268,22 @@ class SpaceNet7(SpaceNet): .. versionadded:: 0.2 """ - dataset_id = "spacenet7" + dataset_id = 'spacenet7' collection_md5_dict = { - "sn7_train_source": "9f8cc109d744537d087bd6ff33132340", - "sn7_train_labels": "16f873e3f0f914d95a916fb39b5111b5", - "sn7_test_source": "e97914f58e962bba3e898f08a14f83b2", + 'sn7_train_source': '9f8cc109d744537d087bd6ff33132340', + 'sn7_train_labels': '16f873e3f0f914d95a916fb39b5111b5', + 'sn7_test_source': 'e97914f58e962bba3e898f08a14f83b2', } - imagery = {"img": "mosaic.tif"} - chip_size = {"img": (1023, 1023)} + imagery = {'img': 'mosaic.tif'} + chip_size = {'img': (1023, 1023)} - label_glob = "labels.geojson" + label_glob = 'labels.geojson' def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, api_key: str | None = None, @@ -1305,16 +1305,16 @@ def __init__( """ self.root = root self.split = split - self.filename = self.imagery["img"] + self.filename = self.imagery['img'] self.transforms = transforms self.checksum = checksum - assert split in {"train", "test"}, "Invalid split" + assert split in {'train', 'test'}, 'Invalid split' - if split == "test": - self.collections = ["sn7_test_source"] + if split == 'test': + self.collections = ['sn7_test_source'] else: - self.collections = ["sn7_train_source", "sn7_train_labels"] + self.collections = ['sn7_train_source', 'sn7_train_labels'] to_be_downloaded = self._check_integrity() @@ -1336,21 +1336,21 @@ def _load_files(self, root: str) -> list[dict[str, str]]: list of dicts containing paths for images and labels (if train split) """ files = [] - if self.split == "train": + if self.split == 'train': imgs = sorted( - glob.glob(os.path.join(root, "sn7_train_source", "*", self.filename)) + 
glob.glob(os.path.join(root, 'sn7_train_source', '*', self.filename)) ) lbls = sorted( - glob.glob(os.path.join(root, "sn7_train_labels", "*", self.label_glob)) + glob.glob(os.path.join(root, 'sn7_train_labels', '*', self.label_glob)) ) for img, lbl in zip(imgs, lbls): - files.append({"image_path": img, "label_path": lbl}) + files.append({'image_path': img, 'label_path': lbl}) else: imgs = sorted( - glob.glob(os.path.join(root, "sn7_test_source", "*", self.filename)) + glob.glob(os.path.join(root, 'sn7_test_source', '*', self.filename)) ) for img in imgs: - files.append({"image_path": img}) + files.append({'image_path': img}) return files def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -1363,14 +1363,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data at that index """ files = self.files[index] - img, tfm, raster_crs = self._load_image(files["image_path"]) + img, tfm, raster_crs = self._load_image(files['image_path']) h, w = img.shape[1:] - ch, cw = self.chip_size["img"] - sample = {"image": img[:, :ch, :cw]} - if self.split == "train": - mask = self._load_mask(files["label_path"], tfm, raster_crs, (h, w)) - sample["mask"] = mask[:ch, :cw] + ch, cw = self.chip_size['img'] + sample = {'image': img[:, :ch, :cw]} + if self.split == 'train': + mask = self._load_mask(files['label_path'], tfm, raster_crs, (h, w)) + sample['mask'] = mask[:ch, :cw] if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index ef875fd9c1b..7de376eb57c 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -16,11 +16,11 @@ from .utils import BoundingBox __all__ = ( - "random_bbox_assignment", - "random_bbox_splitting", - "random_grid_cell_assignment", - "roi_split", - "time_series_split", + 'random_bbox_assignment', + 'random_bbox_splitting', + 'random_grid_cell_assignment', + 'roi_split', + 'time_series_split', ) @@ -73,7 +73,7 @@ def random_bbox_assignment( ) if any(n <= 0 for n in lengths): - raise ValueError("All items in input lengths must be greater than 0.") + raise ValueError('All items in input lengths must be greater than 0.') if isclose(sum(lengths), 1): lengths = _fractions_to_lengths(lengths, len(dataset)) @@ -123,10 +123,10 @@ def random_bbox_splitting( .. versionadded:: 0.5 """ if not isclose(sum(fractions), 1): - raise ValueError("Sum of input fractions must equal 1.") + raise ValueError('Sum of input fractions must equal 1.') if any(n <= 0 for n in fractions): - raise ValueError("All items in input fractions must be greater than 0.") + raise ValueError('All items in input fractions must be greater than 0.') new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions @@ -191,13 +191,13 @@ def random_grid_cell_assignment( .. 
versionadded:: 0.5 """ if not isclose(sum(fractions), 1): - raise ValueError("Sum of input fractions must equal 1.") + raise ValueError('Sum of input fractions must equal 1.') if any(n <= 0 for n in fractions): - raise ValueError("All items in input fractions must be greater than 0.") + raise ValueError('All items in input fractions must be greater than 0.') if grid_size < 2: - raise ValueError("Input grid_size must be greater than 1.") + raise ValueError('Input grid_size must be greater than 1.') new_indexes = [ Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions @@ -316,7 +316,7 @@ def time_series_split( ) if any(n <= 0 for n in lengths): - raise ValueError("All items in input lengths must be greater than 0.") + raise ValueError('All items in input lengths must be greater than 0.') if isclose(sum(lengths), 1): lengths = [totalt * f for f in lengths] @@ -336,7 +336,7 @@ def time_series_split( for i, (start, end) in enumerate(lengths): if start >= end: raise ValueError( - "Pairs of timestamps in lengths must have end greater than start." + 'Pairs of timestamps in lengths must have end greater than start.' ) if start < mint or end > maxt: diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index 39ddbc187cd..04cb5923034 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -99,70 +99,70 @@ class _Metadata(TypedDict): rgb_bands: list[int] metadata: dict[str, _Metadata] = { - "tm_toa": {"num_bands": 7, "rgb_bands": [2, 1, 0]}, - "etm_toa": {"num_bands": 9, "rgb_bands": [2, 1, 0]}, - "etm_sr": {"num_bands": 6, "rgb_bands": [2, 1, 0]}, - "oli_tirs_toa": {"num_bands": 11, "rgb_bands": [3, 2, 1]}, - "oli_sr": {"num_bands": 7, "rgb_bands": [3, 2, 1]}, + 'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]}, + 'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]}, + 'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]}, + 'oli_tirs_toa': {'num_bands': 11, 'rgb_bands': [3, 2, 1]}, + 'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]}, } - url = "https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}" # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' # noqa: E501 checksums = { - "tm_toa": { - "aa": "553795b8d73aa253445b1e67c5b81f11", - "ab": "e9e0739b5171b37d16086cb89ab370e8", - "ac": "6cb27189f6abe500c67343bfcab2432c", - "ad": "15a885d4f544d0c1849523f689e27402", - "ae": "35523336bf9f8132f38ff86413dcd6dc", - "af": "fa1108436034e6222d153586861f663b", - "ag": "d5c91301c115c00acaf01ceb3b78c0fe", + 'tm_toa': { + 'aa': '553795b8d73aa253445b1e67c5b81f11', + 'ab': 'e9e0739b5171b37d16086cb89ab370e8', + 'ac': '6cb27189f6abe500c67343bfcab2432c', + 'ad': '15a885d4f544d0c1849523f689e27402', + 'ae': '35523336bf9f8132f38ff86413dcd6dc', + 'af': 'fa1108436034e6222d153586861f663b', + 'ag': 'd5c91301c115c00acaf01ceb3b78c0fe', }, - "etm_toa": { - "aa": "587c3efc7d0a0c493dfb36139d91ccdf", - "ab": "ec34f33face893d2d8fd152496e1df05", - "ac": "947acc2c6bc3c1d1415ac92bab695380", - "ad": "e31273dec921e187f5c0dc73af5b6102", - "ae": "43390a47d138593095e9a6775ae7dc75", - "af": "082881464ca6dcbaa585f72de1ac14fd", - "ag": "de2511aaebd640bd5e5404c40d7494cb", - "ah": "124c5fbcda6871f27524ae59480dabc5", - "ai": "12b5f94824b7f102df30a63b1139fc57", + 'etm_toa': { + 'aa': '587c3efc7d0a0c493dfb36139d91ccdf', + 'ab': 'ec34f33face893d2d8fd152496e1df05', + 'ac': '947acc2c6bc3c1d1415ac92bab695380', + 'ad': 
'e31273dec921e187f5c0dc73af5b6102', + 'ae': '43390a47d138593095e9a6775ae7dc75', + 'af': '082881464ca6dcbaa585f72de1ac14fd', + 'ag': 'de2511aaebd640bd5e5404c40d7494cb', + 'ah': '124c5fbcda6871f27524ae59480dabc5', + 'ai': '12b5f94824b7f102df30a63b1139fc57', }, - "etm_sr": { - "aa": "baa36a9b8e42e234bb44ab4046f8f2ac", - "ab": "9fb0f948c76154caabe086d2d0008fdf", - "ac": "99a55367178373805d357a096d68e418", - "ad": "59d53a643b9e28911246d4609744ef25", - "ae": "7abfcfc57528cb9c619c66ee307a2cc9", - "af": "bb23cf26cc9fe156e7a68589ec69f43e", - "ag": "97347e5a81d24c93cf33d99bb46a5b91", + 'etm_sr': { + 'aa': 'baa36a9b8e42e234bb44ab4046f8f2ac', + 'ab': '9fb0f948c76154caabe086d2d0008fdf', + 'ac': '99a55367178373805d357a096d68e418', + 'ad': '59d53a643b9e28911246d4609744ef25', + 'ae': '7abfcfc57528cb9c619c66ee307a2cc9', + 'af': 'bb23cf26cc9fe156e7a68589ec69f43e', + 'ag': '97347e5a81d24c93cf33d99bb46a5b91', }, - "oli_tirs_toa": { - "aa": "4711369b861c856ebfadbc861e928d3a", - "ab": "660a96cda1caf54df837c4b3c6c703f6", - "ac": "c9b6a1117916ba318ac3e310447c60dc", - "ad": "b8502e9e92d4a7765a287d21d7c9146c", - "ae": "5c11c14cfe45f78de4f6d6faf03f3146", - "af": "5b0ed3901be1000137ddd3a6d58d5109", - "ag": "a3b6734f8fe6763dcf311c9464a05d5b", - "ah": "5e55f92e3238a8ab3e471be041f8111b", - "ai": "e20617f73d0232a0c0472ce336d4c92f", + 'oli_tirs_toa': { + 'aa': '4711369b861c856ebfadbc861e928d3a', + 'ab': '660a96cda1caf54df837c4b3c6c703f6', + 'ac': 'c9b6a1117916ba318ac3e310447c60dc', + 'ad': 'b8502e9e92d4a7765a287d21d7c9146c', + 'ae': '5c11c14cfe45f78de4f6d6faf03f3146', + 'af': '5b0ed3901be1000137ddd3a6d58d5109', + 'ag': 'a3b6734f8fe6763dcf311c9464a05d5b', + 'ah': '5e55f92e3238a8ab3e471be041f8111b', + 'ai': 'e20617f73d0232a0c0472ce336d4c92f', }, - "oli_sr": { - "aa": "ca338511c9da4dcbfddda28b38ca9e0a", - "ab": "7f4100aa9791156958dccf1bb2a88ae0", - "ac": "6b0f18be2b63ba9da194cc7886dbbc01", - "ad": "57efbcc894d8da8c4975c29437d8b775", - "ae": "2594a0a856897f3f5a902c830186872d", - "af": "a03839311a2b3dc17dfb9fb9bc4f9751", - "ag": "6a329d8fd9fdd591e400ab20f9d11dea", + 'oli_sr': { + 'aa': 'ca338511c9da4dcbfddda28b38ca9e0a', + 'ab': '7f4100aa9791156958dccf1bb2a88ae0', + 'ac': '6b0f18be2b63ba9da194cc7886dbbc01', + 'ad': '57efbcc894d8da8c4975c29437d8b775', + 'ae': '2594a0a856897f3f5a902c830186872d', + 'af': 'a03839311a2b3dc17dfb9fb9bc4f9751', + 'ag': '6a329d8fd9fdd591e400ab20f9d11dea', }, } def __init__( self, - root: str = "data", - split: str = "oli_sr", + root: str = 'data', + split: str = 'oli_sr', seasons: int = 1, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -187,7 +187,7 @@ def __init__( assert seasons in range(1, 5) self.root = root - self.subdir = os.path.join(root, f"ssl4eo_l_{split}") + self.subdir = os.path.join(root, f'ssl4eo_l_{split}') self.split = split self.seasons = seasons self.transforms = transforms @@ -214,12 +214,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: images = [] for subdir in subdirs: directory = os.path.join(root, subdir) - filename = os.path.join(directory, "all_bands.tif") + filename = os.path.join(directory, 'all_bands.tif') with rasterio.open(filename) as f: image = f.read() images.append(torch.from_numpy(image.astype(np.float32))) - sample = {"image": torch.cat(images)} + sample = {'image': torch.cat(images)} if self.transforms is not None: sample = self.transforms(sample) @@ -237,14 +237,14 @@ def __len__(self) -> int: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files 
already exist - path = os.path.join(self.subdir, "00000*", "*", "all_bands.tif") + path = os.path.join(self.subdir, '00000*', '*', 'all_bands.tif') if glob.glob(path): return # Check if the tar.gz files have already been downloaded exists = [] for suffix in self.checksums[self.split]: - path = self.subdir + f".tar.gz{suffix}" + path = self.subdir + f'.tar.gz{suffix}' exists.append(os.path.exists(path)) if all(exists): @@ -272,10 +272,10 @@ def _extract(self) -> None: """Extract the dataset.""" # Concatenate all tarballs together chunk_size = 2**15 # same as torchvision - path = self.subdir + ".tar.gz" - with open(path, "wb") as f: + path = self.subdir + '.tar.gz' + with open(path, 'wb') as f: for suffix in self.checksums[self.split]: - with open(path + suffix, "rb") as g: + with open(path + suffix, 'rb') as g: while chunk := g.read(chunk_size): f.write(chunk) @@ -301,18 +301,18 @@ def plot( fig, axes = plt.subplots( ncols=self.seasons, squeeze=False, figsize=(4 * self.seasons, 4) ) - num_bands = self.metadata[self.split]["num_bands"] - rgb_bands = self.metadata[self.split]["rgb_bands"] + num_bands = self.metadata[self.split]['num_bands'] + rgb_bands = self.metadata[self.split]['rgb_bands'] for i in range(self.seasons): - image = sample["image"][i * num_bands : (i + 1) * num_bands].byte() + image = sample['image'][i * num_bands : (i + 1) * num_bands].byte() image = image[rgb_bands].permute(1, 2, 0) axes[0, i].imshow(image) - axes[0, i].axis("off") + axes[0, i].axis('off') if show_titles: - axes[0, i].set_title(f"Split {self.split}, Season {i + 1}") + axes[0, i].set_title(f'Split {self.split}, Season {i + 1}') if suptitle is not None: plt.suptitle(suptitle) @@ -357,54 +357,54 @@ class _Metadata(TypedDict): bands: list[str] metadata: dict[str, _Metadata] = { - "s1": { - "filename": "s1.tar.gz", - "md5": "51ee23b33eb0a2f920bda25225072f3a", - "bands": ["VV", "VH"], + 's1': { + 'filename': 's1.tar.gz', + 'md5': '51ee23b33eb0a2f920bda25225072f3a', + 'bands': ['VV', 'VH'], }, - "s2c": { - "filename": "s2_l1c.tar.gz", - "md5": "b4f8b03c365e4a85780ded600b7497ab", - "bands": [ - "B1", - "B2", - "B3", - "B4", - "B5", - "B6", - "B7", - "B8", - "B8A", - "B9", - "B10", - "B11", - "B12", + 's2c': { + 'filename': 's2_l1c.tar.gz', + 'md5': 'b4f8b03c365e4a85780ded600b7497ab', + 'bands': [ + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8', + 'B8A', + 'B9', + 'B10', + 'B11', + 'B12', ], }, - "s2a": { - "filename": "s2_l2a.tar.gz", - "md5": "85496cd9d6742aee03b6a1c99cee0ac1", - "bands": [ - "B1", - "B2", - "B3", - "B4", - "B5", - "B6", - "B7", - "B8", - "B8A", - "B9", - "B11", - "B12", + 's2a': { + 'filename': 's2_l2a.tar.gz', + 'md5': '85496cd9d6742aee03b6a1c99cee0ac1', + 'bands': [ + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8', + 'B8A', + 'B9', + 'B11', + 'B12', ], }, } def __init__( self, - root: str = "data", - split: str = "s2c", + root: str = 'data', + split: str = 's2c', seasons: int = 1, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -434,7 +434,7 @@ def __init__( self.transforms = transforms self.checksum = checksum - self.bands = self.metadata[self.split]["bands"] + self.bands = self.metadata[self.split]['bands'] self._verify() @@ -447,7 +447,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: image sample """ - root = os.path.join(self.root, self.split, f"{index:07}") + root = os.path.join(self.root, self.split, f'{index:07}') subdirs = os.listdir(root) subdirs = random.sample(subdirs, 
self.seasons) @@ -455,12 +455,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: for subdir in subdirs: directory = os.path.join(root, subdir) for band in self.bands: - filename = os.path.join(directory, f"{band}.tif") + filename = os.path.join(directory, f'{band}.tif') with rasterio.open(filename) as f: image = f.read(out_shape=(1, self.size, self.size)) images.append(torch.from_numpy(image.astype(np.float32))) - sample = {"image": torch.cat(images)} + sample = {'image': torch.cat(images)} if self.transforms is not None: sample = self.transforms(sample) @@ -483,9 +483,9 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - filename = self.metadata[self.split]["filename"] + filename = self.metadata[self.split]['filename'] zip_path = os.path.join(self.root, filename) - md5 = self.metadata[self.split]["md5"] if self.checksum else None + md5 = self.metadata[self.split]['md5'] if self.checksum else None integrity = check_integrity(zip_path, md5) if integrity: self._extract() @@ -494,7 +494,7 @@ def _verify(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - filename = self.metadata[self.split]["filename"] + filename = self.metadata[self.split]['filename'] extract_archive(os.path.join(self.root, filename)) def plot( @@ -513,7 +513,7 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - nrows = 2 if self.split == "s1" else 1 + nrows = 2 if self.split == 's1' else 1 fig, axes = plt.subplots( nrows=nrows, ncols=self.seasons, @@ -522,9 +522,9 @@ def plot( ) for i in range(self.seasons): - image = sample["image"][i * len(self.bands) : (i + 1) * len(self.bands)] + image = sample['image'][i * len(self.bands) : (i + 1) * len(self.bands)] - if self.split == "s1": + if self.split == 's1': axes[0, i].imshow(image[0]) axes[1, i].imshow(image[1]) else: @@ -532,10 +532,10 @@ def plot( image = torch.clamp(image / 3000, min=0, max=1) axes[0, i].imshow(image) - axes[0, i].axis("off") + axes[0, i].axis('off') if show_titles: - axes[0, i].set_title(f"Split {self.split}, Season {i + 1}") + axes[0, i].set_title(f'Split {self.split}, Season {i + 1}') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index d37097928d0..fca210ff152 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -47,69 +47,69 @@ class SSL4EOLBenchmark(NonGeoDataset): .. 
versionadded:: 0.5 """ # noqa: E501 - url = "https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz" # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' # noqa: E501 - valid_sensors = ["tm_toa", "etm_toa", "etm_sr", "oli_tirs_toa", "oli_sr"] - valid_products = ["cdl", "nlcd"] - valid_splits = ["train", "val", "test"] + valid_sensors = ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr'] + valid_products = ['cdl', 'nlcd'] + valid_splits = ['train', 'val', 'test'] - image_root = "ssl4eo_l_{}_benchmark" + image_root = 'ssl4eo_l_{}_benchmark' img_md5s = { - "tm_toa": "8e3c5bcd56d3780a442f1332013b8d15", - "etm_toa": "1b051c7fe4d61c581b341370c9e76f1f", - "etm_sr": "34a24fa89a801654f8d01e054662c8cd", - "oli_tirs_toa": "6e9d7cf0392e1de2cbdb39962ba591aa", - "oli_sr": "0700cd15cc2366fe68c2f8c02fa09a15", + 'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15', + 'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f', + 'etm_sr': '34a24fa89a801654f8d01e054662c8cd', + 'oli_tirs_toa': '6e9d7cf0392e1de2cbdb39962ba591aa', + 'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15', } mask_dir_dict = { - "tm_toa": "ssl4eo_l_tm_{}", - "etm_toa": "ssl4eo_l_etm_{}", - "etm_sr": "ssl4eo_l_etm_{}", - "oli_tirs_toa": "ssl4eo_l_oli_{}", - "oli_sr": "ssl4eo_l_oli_{}", + 'tm_toa': 'ssl4eo_l_tm_{}', + 'etm_toa': 'ssl4eo_l_etm_{}', + 'etm_sr': 'ssl4eo_l_etm_{}', + 'oli_tirs_toa': 'ssl4eo_l_oli_{}', + 'oli_sr': 'ssl4eo_l_oli_{}', } mask_md5s = { - "tm": { - "cdl": "3d676770ffb56c7e222a7192a652a846", - "nlcd": "261149d7614fcfdcb3be368eefa825c7", + 'tm': { + 'cdl': '3d676770ffb56c7e222a7192a652a846', + 'nlcd': '261149d7614fcfdcb3be368eefa825c7', }, - "etm": { - "cdl": "008098c968544049eaf7b307e14241de", - "nlcd": "9c031049d665202ba42ac1d89b687999", + 'etm': { + 'cdl': '008098c968544049eaf7b307e14241de', + 'nlcd': '9c031049d665202ba42ac1d89b687999', }, - "oli": { - "cdl": "1cb057de6eafeca975deb35cb9fb036f", - "nlcd": "9de0d6d4d0b94313b80450f650813922", + 'oli': { + 'cdl': '1cb057de6eafeca975deb35cb9fb036f', + 'nlcd': '9de0d6d4d0b94313b80450f650813922', }, } year_dict = { - "tm_toa": 2011, - "etm_toa": 2019, - "etm_sr": 2019, - "oli_tirs_toa": 2019, - "oli_sr": 2019, + 'tm_toa': 2011, + 'etm_toa': 2019, + 'etm_sr': 2019, + 'oli_tirs_toa': 2019, + 'oli_sr': 2019, } rgb_indices = { - "tm_toa": [2, 1, 0], - "etm_toa": [2, 1, 0], - "etm_sr": [2, 1, 0], - "oli_tirs_toa": [3, 2, 1], - "oli_sr": [3, 2, 1], + 'tm_toa': [2, 1, 0], + 'etm_toa': [2, 1, 0], + 'etm_sr': [2, 1, 0], + 'oli_tirs_toa': [3, 2, 1], + 'oli_sr': [3, 2, 1], } split_percentages = [0.7, 0.15, 0.15] - cmaps = {"nlcd": NLCD.cmap, "cdl": CDL.cmap} + cmaps = {'nlcd': NLCD.cmap, 'cdl': CDL.cmap} def __init__( self, - root: str = "data", - sensor: str = "oli_sr", - product: str = "cdl", - split: str = "train", + root: str = 'data', + sensor: str = 'oli_sr', + product: str = 'cdl', + split: str = 'train', classes: list[int] | None = None, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -135,15 +135,15 @@ def __init__( """ assert ( sensor in self.valid_sensors - ), f"Only supports one of {self.valid_sensors}, but found {sensor}." + ), f'Only supports one of {self.valid_sensors}, but found {sensor}.' self.sensor = sensor assert ( product in self.valid_products - ), f"Only supports one of {self.valid_products}, but found {product}." + ), f'Only supports one of {self.valid_products}, but found {product}.' 
self.product = product assert ( split in self.valid_splits - ), f"Only supports one of {self.valid_splits}, but found {split}." + ), f'Only supports one of {self.valid_splits}, but found {split}.' self.split = split self.cmap = self.cmaps[product] @@ -152,8 +152,8 @@ def __init__( assert ( set(classes) <= self.cmap.keys() - ), f"Only the following classes are valid: {list(self.cmap.keys())}." - assert 0 in classes, "Classes must include the background class: 0" + ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert 0 in classes, 'Classes must include the background class: 0' self.root = root self.classes = classes @@ -178,7 +178,7 @@ def __init__( sample_indices = np.arange(len(self.sample_collection)) np.random.shuffle(sample_indices) groups = np.split(sample_indices, cutoffs) - split_indices = {"train": groups[0], "val": groups[1], "test": groups[2]}[ + split_indices = {'train': groups[0], 'val': groups[1], 'test': groups[2]}[ self.split ] @@ -192,14 +192,14 @@ def __init__( def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - img_pathname = os.path.join(self.root, self.img_dir_name, "**", "all_bands.tif") + img_pathname = os.path.join(self.root, self.img_dir_name, '**', 'all_bands.tif') exists = [] exists.append(bool(glob.glob(img_pathname, recursive=True))) mask_pathname = os.path.join( self.root, self.mask_dir_name, - "**", - f"{self.product}_{self.year_dict[self.sensor]}.tif", + '**', + f'{self.product}_{self.year_dict[self.sensor]}.tif', ) exists.append(bool(glob.glob(mask_pathname, recursive=True))) @@ -207,10 +207,10 @@ def _verify(self) -> None: return # Check if the tar.gz files have already been downloaded exists = [] - img_pathname = os.path.join(self.root, f"{self.img_dir_name}.tar.gz") + img_pathname = os.path.join(self.root, f'{self.img_dir_name}.tar.gz') exists.append(os.path.exists(img_pathname)) - mask_pathname = os.path.join(self.root, f"{self.mask_dir_name}.tar.gz") + mask_pathname = os.path.join(self.root, f'{self.mask_dir_name}.tar.gz') exists.append(os.path.exists(mask_pathname)) if all(exists): @@ -238,7 +238,7 @@ def _download(self) -> None: self.url.format(self.mask_dir_name), self.root, md5=( - self.mask_md5s[self.sensor.split("_")[0]][self.product] + self.mask_md5s[self.sensor.split('_')[0]][self.product] if self.checksum else None ), @@ -246,10 +246,10 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - img_pathname = os.path.join(self.root, f"{self.img_dir_name}.tar.gz") + img_pathname = os.path.join(self.root, f'{self.img_dir_name}.tar.gz') extract_archive(img_pathname) - mask_pathname = os.path.join(self.root, f"{self.mask_dir_name}.tar.gz") + mask_pathname = os.path.join(self.root, f'{self.mask_dir_name}.tar.gz') extract_archive(mask_pathname) def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -264,8 +264,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: img_path, mask_path = self.sample_collection[index] sample = { - "image": self._load_image(img_path), - "mask": self._load_mask(mask_path), + 'image': self._load_image(img_path), + 'mask': self._load_mask(mask_path), } if self.transforms is not None: @@ -284,14 +284,14 @@ def __len__(self) -> int: def retrieve_sample_collection(self) -> list[tuple[str, str]]: """Retrieve paths to samples in data directory.""" img_paths = glob.glob( - os.path.join(self.root, self.img_dir_name, "**", "all_bands.tif"), + os.path.join(self.root, self.img_dir_name, '**', 
'all_bands.tif'), recursive=True, ) img_paths = sorted(img_paths) sample_collection: list[tuple[str, str]] = [] for img_path in img_paths: mask_path = img_path.replace(self.img_dir_name, self.mask_dir_name).replace( - "all_bands.tif", f"{self.product}_{self.year_dict[self.sensor]}.tif" + 'all_bands.tif', f'{self.product}_{self.year_dict[self.sensor]}.tif' ) sample_collection.append((img_path, mask_path)) return sample_collection @@ -340,29 +340,29 @@ def plot( a matplotlib Figure with the rendered sample """ ncols = 2 - image = sample["image"][self.rgb_indices[self.sensor]].permute(1, 2, 0) + image = sample['image'][self.rgb_indices[self.sensor]].permute(1, 2, 0) image = image / 255 - mask = sample["mask"].squeeze(0) + mask = sample['mask'].squeeze(0) - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - pred = sample["prediction"].squeeze(0) + pred = sample['prediction'].squeeze(0) ncols = 3 fig, ax = plt.subplots(ncols=ncols, figsize=(4 * ncols, 4)) ax[0].imshow(image) - ax[0].axis("off") - ax[1].imshow(self.ordinal_cmap[mask], interpolation="none") - ax[1].axis("off") + ax[0].axis('off') + ax[1].imshow(self.ordinal_cmap[mask], interpolation='none') + ax[1].axis('off') if show_titles: - ax[0].set_title("Image") - ax[1].set_title("Mask") + ax[0].set_title('Image') + ax[1].set_title('Mask') if showing_predictions: - ax[2].imshow(self.ordinal_cmap[pred], interpolation="none") + ax[2].imshow(self.ordinal_cmap[pred], interpolation='none') if show_titles: - ax[2].set_title("Prediction") + ax[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 03cdf06b3d1..63e81dd3e37 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -46,21 +46,21 @@ class SustainBenchCropYield(NonGeoDataset): .. versionadded:: 0.5 """ # noqa: E501 - valid_countries = ["usa", "brazil", "argentina"] + valid_countries = ['usa', 'brazil', 'argentina'] - md5 = "362bad07b51a1264172b8376b39d1fc9" + md5 = '362bad07b51a1264172b8376b39d1fc9' - url = "https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link" # noqa: E501 + url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' # noqa: E501 - dir = "soybeans" + dir = 'soybeans' - valid_splits = ["train", "dev", "test"] + valid_splits = ['train', 'dev', 'test'] def __init__( self, - root: str = "data", - split: str = "train", - countries: list[str] = ["usa"], + root: str = 'data', + split: str = 'train', + countries: list[str] = ['usa'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, @@ -83,12 +83,12 @@ def __init__( """ assert set(countries).issubset( self.valid_countries - ), f"Please choose a subset of these valid countried: {self.valid_countries}." + ), f'Please choose a subset of these valid countries: {self.valid_countries}.' self.countries = countries assert ( split in self.valid_splits - ), f"Pleas choose one of these valid data splits {self.valid_splits}." + ), f'Please choose one of these valid data splits {self.valid_splits}.'
self.split = split self.root = root @@ -103,16 +103,16 @@ def __init__( for country in self.countries: image_file_path = os.path.join( - self.root, self.dir, country, f"{self.split}_hists.npz" + self.root, self.dir, country, f'{self.split}_hists.npz' ) - target_file_path = image_file_path.replace("_hists", "_yields") - years_file_path = image_file_path.replace("_hists", "_years") - ndvi_file_path = image_file_path.replace("_hists", "_ndvi") - - npz_file = np.load(image_file_path)["data"] - target_npz_file = np.load(target_file_path)["data"] - year_npz_file = np.load(years_file_path)["data"] - ndvi_npz_file = np.load(ndvi_file_path)["data"] + target_file_path = image_file_path.replace('_hists', '_yields') + years_file_path = image_file_path.replace('_hists', '_years') + ndvi_file_path = image_file_path.replace('_hists', '_ndvi') + + npz_file = np.load(image_file_path)['data'] + target_npz_file = np.load(target_file_path)['data'] + year_npz_file = np.load(years_file_path)['data'] + ndvi_npz_file = np.load(ndvi_file_path)['data'] num_data_points = npz_file.shape[0] for idx in range(num_data_points): sample = npz_file[idx] @@ -124,9 +124,9 @@ def __init__( ndvi = ndvi_npz_file[idx] features = { - "label": torch.tensor(target).to(torch.float32), - "year": torch.tensor(int(year)), - "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32), + 'label': torch.tensor(target).to(torch.float32), + 'year': torch.tensor(int(year)), + 'ndvi': torch.from_numpy(ndvi).to(dtype=torch.float32), } self.features.append(features) @@ -147,7 +147,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at that index """ - sample: dict[str, Tensor] = {"image": self.images[index]} + sample: dict[str, Tensor] = {'image': self.images[index]} sample.update(self.features[index]) if self.transforms is not None: @@ -163,7 +163,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.dir) + ".zip" + pathname = os.path.join(self.root, self.dir) + '.zip' if os.path.exists(pathname): self._extract() return @@ -181,14 +181,14 @@ def _download(self) -> None: download_url( self.url, self.root, - filename=self.dir + ".zip", + filename=self.dir + '.zip', md5=self.md5 if self.checksum else None, ) self._extract() def _extract(self) -> None: """Extract the dataset.""" - zipfile_path = os.path.join(self.root, self.dir) + ".zip" + zipfile_path = os.path.join(self.root, self.dir) + '.zip' extract_archive(zipfile_path, self.root) def plot( @@ -210,21 +210,21 @@ def plot( a matplotlib Figure with the rendered sample """ - image, label = sample["image"], sample["label"].item() + image, label = sample['image'], sample['label'].item() - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = sample["prediction"].item() + prediction = sample['prediction'].item() fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(image.permute(1, 2, 0)[:, :, band_idx]) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label:.3f}" + title = f'Label: {label:.3f}' if showing_predictions: - title += f"\nPrediction: {prediction:.3f}" + title += f'\nPrediction: {prediction:.3f}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index d37d8f1ba4a..abcf3679563 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -65,28 +65,28 @@ class 
UCMerced(NonGeoClassificationDataset): * https://dl.acm.org/doi/10.1145/1869790.1869829 """ - url = "https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip" # noqa: E501 - filename = "UCMerced_LandUse.zip" - md5 = "5b7ec56793786b6dc8a908e8854ac0e4" + url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' # noqa: E501 + filename = 'UCMerced_LandUse.zip' + md5 = '5b7ec56793786b6dc8a908e8854ac0e4' - base_dir = os.path.join("UCMerced_LandUse", "Images") + base_dir = os.path.join('UCMerced_LandUse', 'Images') - splits = ["train", "val", "test"] + splits = ['train', 'val', 'test'] split_urls = { - "train": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt", # noqa: E501 - "val": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt", # noqa: E501 - "test": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt", # noqa: E501 + 'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', # noqa: E501 + 'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', # noqa: E501 + 'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', # noqa: E501 } split_md5s = { - "train": "f2fb12eb2210cfb53f93f063a35ff374", - "val": "11ecabfc52782e5ea6a9c7c0d263aca0", - "test": "046aff88472d8fc07c4678d03749e28d", + 'train': 'f2fb12eb2210cfb53f93f063a35ff374', + 'val': '11ecabfc52782e5ea6a9c7c0d263aca0', + 'test': '046aff88472d8fc07c4678d03749e28d', } def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -112,7 +112,7 @@ def __init__( self._verify() valid_fns = set() - with open(os.path.join(self.root, f"uc_merced-{split}.txt")) as f: + with open(os.path.join(self.root, f'uc_merced-{split}.txt')) as f: for fn in f: valid_fns.add(fn.strip()) @@ -181,7 +181,7 @@ def _download(self) -> None: download_url( self.split_urls[split], self.root, - filename=f"uc_merced-{split}.txt", + filename=f'uc_merced-{split}.txt', md5=self.split_md5s[split] if self.checksum else None, ) @@ -208,27 +208,27 @@ def plot( .. 
versionadded:: 0.2 """ - image = np.rollaxis(sample["image"].numpy(), 0, 3) + image = np.rollaxis(sample['image'].numpy(), 0, 3) # Normalize the image if the max value is greater than 1 if image.max() > 1: image = image.astype(np.float32) / 255.0 # Scale to [0, 1] - label = cast(int, sample["label"].item()) + label = cast(int, sample['label'].item()) label_class = self.classes[label] - showing_predictions = "prediction" in sample + showing_predictions = 'prediction' in sample if showing_predictions: - prediction = cast(int, sample["prediction"].item()) + prediction = cast(int, sample['prediction'].item()) prediction_class = self.classes[prediction] fig, ax = plt.subplots(figsize=(4, 4)) ax.imshow(image) - ax.axis("off") + ax.axis('off') if show_titles: - title = f"Label: {label_class}" + title = f'Label: {label_class}' if showing_predictions: - title += f"\nPrediction: {prediction_class}" + title += f'\nPrediction: {prediction_class}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index 5b27c2e8606..fcbaaa27ed0 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -48,45 +48,45 @@ class USAVars(NonGeoDataset): .. versionadded:: 0.3 """ - data_url = "https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}" # noqa: E501 - dirname = "uar" + data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' # noqa: E501 + dirname = 'uar' - md5 = "677e89fd20e5dd0fe4d29b61827c2456" + md5 = '677e89fd20e5dd0fe4d29b61827c2456' label_urls = { - "housing": data_url.format("housing.csv"), - "income": data_url.format("income.csv"), - "roads": data_url.format("roads.csv"), - "nightlights": data_url.format("nightlights.csv"), - "population": data_url.format("population.csv"), - "elevation": data_url.format("elevation.csv"), - "treecover": data_url.format("treecover.csv"), + 'housing': data_url.format('housing.csv'), + 'income': data_url.format('income.csv'), + 'roads': data_url.format('roads.csv'), + 'nightlights': data_url.format('nightlights.csv'), + 'population': data_url.format('population.csv'), + 'elevation': data_url.format('elevation.csv'), + 'treecover': data_url.format('treecover.csv'), } split_metadata = { - "train": { - "url": data_url.format("train_split.txt"), - "filename": "train_split.txt", - "md5": "3f58fffbf5fe177611112550297200e7", + 'train': { + 'url': data_url.format('train_split.txt'), + 'filename': 'train_split.txt', + 'md5': '3f58fffbf5fe177611112550297200e7', }, - "val": { - "url": data_url.format("val_split.txt"), - "filename": "val_split.txt", - "md5": "bca7183b132b919dec0fc24fb11662a0", + 'val': { + 'url': data_url.format('val_split.txt'), + 'filename': 'val_split.txt', + 'md5': 'bca7183b132b919dec0fc24fb11662a0', }, - "test": { - "url": data_url.format("test_split.txt"), - "filename": "test_split.txt", - "md5": "97bb36bc003ae0bf556a8d6e8f77141a", + 'test': { + 'url': data_url.format('test_split.txt'), + 'filename': 'test_split.txt', + 'md5': '97bb36bc003ae0bf556a8d6e8f77141a', }, } - ALL_LABELS = ["treecover", "elevation", "population"] + ALL_LABELS = ['treecover', 'elevation', 'population'] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', labels: Sequence[str] = ALL_LABELS, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -125,7 +125,7 @@ def __init__( self.files = self._load_files() 
self.label_dfs = { - lab: pd.read_csv(os.path.join(self.root, lab + ".csv"), index_col="ID") + lab: pd.read_csv(os.path.join(self.root, lab + '.csv'), index_col='ID') for lab in self.labels } @@ -142,12 +142,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: id_ = tif_file[5:-4] sample = { - "labels": Tensor( + 'labels': Tensor( [self.label_dfs[lab].loc[id_][lab] for lab in self.labels] ), - "image": self._load_image(os.path.join(self.root, "uar", tif_file)), - "centroid_lat": Tensor([self.label_dfs[self.labels[0]].loc[id_]["lat"]]), - "centroid_lon": Tensor([self.label_dfs[self.labels[0]].loc[id_]["lon"]]), + 'image': self._load_image(os.path.join(self.root, 'uar', tif_file)), + 'centroid_lat': Tensor([self.label_dfs[self.labels[0]].loc[id_]['lat']]), + 'centroid_lon': Tensor([self.label_dfs[self.labels[0]].loc[id_]['lon']]), } if self.transforms is not None: @@ -165,7 +165,7 @@ def __len__(self) -> int: def _load_files(self) -> list[str]: """Loads file names.""" - with open(os.path.join(self.root, f"{self.split}_split.txt")) as f: + with open(os.path.join(self.root, f'{self.split}_split.txt')) as f: files = f.read().splitlines() return files @@ -179,23 +179,23 @@ def _load_image(self, path: str) -> Tensor: the image """ with rasterio.open(path) as f: - array: "np.typing.NDArray[np.int_]" = f.read() + array: 'np.typing.NDArray[np.int_]' = f.read() tensor = torch.from_numpy(array).float() return tensor def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - pathname = os.path.join(self.root, "uar") - csv_pathname = os.path.join(self.root, "*.csv") - split_pathname = os.path.join(self.root, "*_split.txt") + pathname = os.path.join(self.root, 'uar') + csv_pathname = os.path.join(self.root, '*.csv') + split_pathname = os.path.join(self.root, '*_split.txt') csv_split_count = (len(glob.glob(csv_pathname)), len(glob.glob(split_pathname))) if glob.glob(pathname) and csv_split_count == (7, 3): return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.dirname + ".zip") + pathname = os.path.join(self.root, self.dirname + '.zip') if glob.glob(pathname) and csv_split_count == (7, 3): self._extract() return @@ -210,20 +210,20 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" for f_name in self.label_urls: - download_url(self.label_urls[f_name], self.root, filename=f_name + ".csv") + download_url(self.label_urls[f_name], self.root, filename=f_name + '.csv') download_url(self.data_url, self.root, md5=self.md5 if self.checksum else None) for metadata in self.split_metadata.values(): download_url( - metadata["url"], + metadata['url'], self.root, - md5=metadata["md5"] if self.checksum else None, + md5=metadata['md5'] if self.checksum else None, ) def _extract(self) -> None: """Extract the dataset.""" - extract_archive(os.path.join(self.root, self.dirname + ".zip")) + extract_archive(os.path.join(self.root, self.dirname + '.zip')) def plot( self, @@ -241,18 +241,18 @@ def plot( Returns: a matplotlib Figure with the rendered sample """ - image = sample["image"][:3].numpy() # get RGB inds + image = sample['image'][:3].numpy() # get RGB inds image = np.moveaxis(image, 0, 2) fig, axs = plt.subplots(figsize=(10, 10)) axs.imshow(image) - axs.axis("off") + axs.axis('off') if show_labels: - labels = [(lab, val) for lab, val in sample.items() if lab != "image"] - label_string = "" + labels = [(lab, val) for lab, val in sample.items() if lab != 'image'] + 
label_string = '' for lab, val in labels: - label_string += f"{lab}={round(val[0].item(), 2)} " + label_string += f'{lab}={round(val[0].item(), 2)} ' axs.set_title(label_string) if suptitle is not None: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 2ef37dfd64d..c76fabeb697 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -28,24 +28,24 @@ from torchvision.utils import draw_segmentation_masks __all__ = ( - "check_integrity", - "DatasetNotFoundError", - "RGBBandsMissingError", - "download_url", - "download_and_extract_archive", - "extract_archive", - "BoundingBox", - "disambiguate_timestamp", - "working_dir", - "stack_samples", - "concat_samples", - "merge_samples", - "unbind_samples", - "rasterio_loader", - "sort_sentinel2_bands", - "draw_semantic_segmentation_masks", - "rgb_to_mask", - "percentile_normalization", + 'check_integrity', + 'DatasetNotFoundError', + 'RGBBandsMissingError', + 'download_url', + 'download_and_extract_archive', + 'extract_archive', + 'BoundingBox', + 'disambiguate_timestamp', + 'working_dir', + 'stack_samples', + 'concat_samples', + 'merge_samples', + 'unbind_samples', + 'rasterio_loader', + 'sort_sentinel2_bands', + 'draw_semantic_segmentation_masks', + 'rgb_to_mask', + 'percentile_normalization', ) @@ -61,33 +61,33 @@ def __init__(self, dataset: Dataset[object]) -> None: Args: dataset: The dataset that was requested. """ - msg = "Dataset not found" + msg = 'Dataset not found' - if hasattr(dataset, "root"): - var = "root" + if hasattr(dataset, 'root'): + var = 'root' val = dataset.root - elif hasattr(dataset, "paths"): - var = "paths" + elif hasattr(dataset, 'paths'): + var = 'paths' val = dataset.paths else: - super().__init__(f"{msg}.") + super().__init__(f'{msg}.') return - msg += f" in `{var}={val!r}` and " + msg += f' in `{var}={val!r}` and ' - if hasattr(dataset, "download") and not dataset.download: - msg += "`download=False`" + if hasattr(dataset, 'download') and not dataset.download: + msg += '`download=False`' else: - msg += "cannot be automatically downloaded" + msg += 'cannot be automatically downloaded' - msg += f", either specify a different `{var}` or " + msg += f', either specify a different `{var}` or ' - if hasattr(dataset, "download") and not dataset.download: - msg += "use `download=True` to automatically" + if hasattr(dataset, 'download') and not dataset.download: + msg += 'use `download=True` to automatically' else: - msg += "manually" + msg += 'manually' - msg += " download the dataset." + msg += ' download the dataset.' 
super().__init__(msg) @@ -100,7 +100,7 @@ class RGBBandsMissingError(ValueError): def __init__(self) -> None: """Instantiate a new RGBBandsMissingError instance.""" - msg = "Dataset does not contain some of the RGB bands" + msg = 'Dataset does not contain some of the RGB bands' super().__init__(msg) @@ -115,7 +115,7 @@ def __enter__(self) -> Any: import rarfile except ImportError: raise ImportError( - "rarfile is not installed and is required to extract this dataset" + 'rarfile is not installed and is required to extract this dataset' ) # TODO: catch exception for when rarfile is installed but not @@ -161,34 +161,34 @@ def extract_archive(src: str, dst: str | None = None) -> None: dst = os.path.dirname(src) suffix_and_extractor: list[tuple[str | tuple[str, ...], Any]] = [ - (".rar", _rarfile.RarFile), + ('.rar', _rarfile.RarFile), ( - (".tar", ".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tbz2", ".tbz", ".txz"), + ('.tar', '.tar.gz', '.tar.bz2', '.tar.xz', '.tgz', '.tbz2', '.tbz', '.txz'), tarfile.open, ), - (".zip", _zipfile.ZipFile), + ('.zip', _zipfile.ZipFile), ] for suffix, extractor in suffix_and_extractor: if src.endswith(suffix): - with extractor(src, "r") as f: + with extractor(src, 'r') as f: f.extractall(dst) return suffix_and_decompressor: list[tuple[str, Any]] = [ - (".bz2", bz2.open), - (".gz", gzip.open), - (".xz", lzma.open), + ('.bz2', bz2.open), + ('.gz', gzip.open), + ('.xz', lzma.open), ] for suffix, decompressor in suffix_and_decompressor: if src.endswith(suffix): - dst = os.path.join(dst, os.path.basename(src).replace(suffix, "")) - with decompressor(src, "rb") as sf, open(dst, "wb") as df: + dst = os.path.join(dst, os.path.basename(src).replace(suffix, '')) + with decompressor(src, 'rb') as sf, open(dst, 'wb') as df: df.write(sf.read()) return - raise RuntimeError("src file has unknown archival/compression scheme") + raise RuntimeError('src file has unknown archival/compression scheme') def download_and_extract_archive( @@ -216,7 +216,7 @@ def download_and_extract_archive( download_url(url, download_root, filename, md5) archive = os.path.join(download_root, filename) - print(f"Extracting {archive} to {extract_root}") + print(f'Extracting {archive} to {extract_root}') extract_archive(archive, extract_root) @@ -236,7 +236,7 @@ def download_radiant_mlhub_dataset( import radiant_mlhub except ImportError: raise ImportError( - "radiant_mlhub is not installed and is required to download this dataset" + 'radiant_mlhub is not installed and is required to download this dataset' ) dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key) @@ -259,7 +259,7 @@ def download_radiant_mlhub_collection( import radiant_mlhub except ImportError: raise ImportError( - "radiant_mlhub is not installed and is required to download this collection" + 'radiant_mlhub is not installed and is required to download this collection' ) collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key) @@ -400,7 +400,7 @@ def __and__(self, other: BoundingBox) -> BoundingBox: min(self.maxt, other.maxt), ) except ValueError: - raise ValueError(f"Bounding boxes {self} and {other} do not overlap") + raise ValueError(f'Bounding boxes {self} and {other} do not overlap') @property def area(self) -> float: @@ -461,7 +461,7 @@ def split( .. 
versionadded:: 0.5 """ if not (0.0 < proportion < 1.0): - raise ValueError("Input proportion must be between 0 and 1.") + raise ValueError('Input proportion must be between 0 and 1.') if horizontal: w = self.maxx - self.minx @@ -507,28 +507,28 @@ def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: # TODO: May have issues with time zones, UTC vs. local time, and DST # TODO: This is really tedious, is there a better way to do this? - if not any([f"%{c}" in format for c in "yYcxG"]): + if not any([f'%{c}' in format for c in 'yYcxG']): # No temporal info return 0, sys.maxsize - elif not any([f"%{c}" in format for c in "bBmjUWcxV"]): + elif not any([f'%{c}' in format for c in 'bBmjUWcxV']): # Year resolution maxt = datetime(mint.year + 1, 1, 1) - elif not any([f"%{c}" in format for c in "aAwdjcxV"]): + elif not any([f'%{c}' in format for c in 'aAwdjcxV']): # Month resolution if mint.month == 12: maxt = datetime(mint.year + 1, 1, 1) else: maxt = datetime(mint.year, mint.month + 1, 1) - elif not any([f"%{c}" in format for c in "HIcX"]): + elif not any([f'%{c}' in format for c in 'HIcX']): # Day resolution maxt = mint + timedelta(days=1) - elif not any([f"%{c}" in format for c in "McX"]): + elif not any([f'%{c}' in format for c in 'McX']): # Hour resolution maxt = mint + timedelta(hours=1) - elif not any([f"%{c}" in format for c in "ScX"]): + elif not any([f'%{c}' in format for c in 'ScX']): # Minute resolution maxt = mint + timedelta(minutes=1) - elif not any([f"%{c}" in format for c in "f"]): + elif not any([f'%{c}' in format for c in 'f']): # Second resolution maxt = mint + timedelta(seconds=1) else: @@ -704,10 +704,10 @@ def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]: def sort_sentinel2_bands(x: str) -> str: """Sort Sentinel-2 band files in the correct order.""" - x = os.path.basename(x).split("_")[-1] + x = os.path.basename(x).split('_')[-1] x = os.path.splitext(x)[0] - if x == "B8A": - x = "B08A" + if x == 'B8A': + x = 'B08A' return x @@ -736,7 +736,7 @@ def draw_semantic_segmentation_masks( image=image.byte(), masks=class_masks, alpha=alpha, colors=colors ) img = img.permute((1, 2, 0)).numpy().astype(np.uint8) - return cast("np.typing.NDArray[np.uint8]", img) + return cast('np.typing.NDArray[np.uint8]', img) def rgb_to_mask( @@ -818,7 +818,7 @@ def path_is_vsi(path: str) -> bool: .. 
versionadded:: 0.6 """ - return "://" in path or path.startswith("/vsi") + return '://' in path or path.startswith('/vsi') def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor: diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 913dad794a0..053fc6e4c18 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -57,57 +57,57 @@ class Vaihingen2D(NonGeoDataset): """ # noqa: E501 filenames = [ - "ISPRS_semantic_labeling_Vaihingen.zip", - "ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip", + 'ISPRS_semantic_labeling_Vaihingen.zip', + 'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip', ] - md5s = ["462b8dca7b6fa9eaf729840f0cdfc7f3", "4802dd6326e2727a352fb735be450277"] - image_root = "top" + md5s = ['462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277'] + image_root = 'top' splits = { - "train": [ - "top_mosaic_09cm_area1.tif", - "top_mosaic_09cm_area11.tif", - "top_mosaic_09cm_area13.tif", - "top_mosaic_09cm_area15.tif", - "top_mosaic_09cm_area17.tif", - "top_mosaic_09cm_area21.tif", - "top_mosaic_09cm_area23.tif", - "top_mosaic_09cm_area26.tif", - "top_mosaic_09cm_area28.tif", - "top_mosaic_09cm_area3.tif", - "top_mosaic_09cm_area30.tif", - "top_mosaic_09cm_area32.tif", - "top_mosaic_09cm_area34.tif", - "top_mosaic_09cm_area37.tif", - "top_mosaic_09cm_area5.tif", - "top_mosaic_09cm_area7.tif", + 'train': [ + 'top_mosaic_09cm_area1.tif', + 'top_mosaic_09cm_area11.tif', + 'top_mosaic_09cm_area13.tif', + 'top_mosaic_09cm_area15.tif', + 'top_mosaic_09cm_area17.tif', + 'top_mosaic_09cm_area21.tif', + 'top_mosaic_09cm_area23.tif', + 'top_mosaic_09cm_area26.tif', + 'top_mosaic_09cm_area28.tif', + 'top_mosaic_09cm_area3.tif', + 'top_mosaic_09cm_area30.tif', + 'top_mosaic_09cm_area32.tif', + 'top_mosaic_09cm_area34.tif', + 'top_mosaic_09cm_area37.tif', + 'top_mosaic_09cm_area5.tif', + 'top_mosaic_09cm_area7.tif', ], - "test": [ - "top_mosaic_09cm_area6.tif", - "top_mosaic_09cm_area24.tif", - "top_mosaic_09cm_area35.tif", - "top_mosaic_09cm_area16.tif", - "top_mosaic_09cm_area14.tif", - "top_mosaic_09cm_area22.tif", - "top_mosaic_09cm_area10.tif", - "top_mosaic_09cm_area4.tif", - "top_mosaic_09cm_area2.tif", - "top_mosaic_09cm_area20.tif", - "top_mosaic_09cm_area8.tif", - "top_mosaic_09cm_area31.tif", - "top_mosaic_09cm_area33.tif", - "top_mosaic_09cm_area27.tif", - "top_mosaic_09cm_area38.tif", - "top_mosaic_09cm_area12.tif", - "top_mosaic_09cm_area29.tif", + 'test': [ + 'top_mosaic_09cm_area6.tif', + 'top_mosaic_09cm_area24.tif', + 'top_mosaic_09cm_area35.tif', + 'top_mosaic_09cm_area16.tif', + 'top_mosaic_09cm_area14.tif', + 'top_mosaic_09cm_area22.tif', + 'top_mosaic_09cm_area10.tif', + 'top_mosaic_09cm_area4.tif', + 'top_mosaic_09cm_area2.tif', + 'top_mosaic_09cm_area20.tif', + 'top_mosaic_09cm_area8.tif', + 'top_mosaic_09cm_area31.tif', + 'top_mosaic_09cm_area33.tif', + 'top_mosaic_09cm_area27.tif', + 'top_mosaic_09cm_area38.tif', + 'top_mosaic_09cm_area12.tif', + 'top_mosaic_09cm_area29.tif', ], } classes = [ - "Clutter/background", - "Impervious surfaces", - "Building", - "Low Vegetation", - "Tree", - "Car", + 'Clutter/background', + 'Impervious surfaces', + 'Building', + 'Low Vegetation', + 'Tree', + 'Car', ] colormap = [ (255, 0, 0), @@ -120,8 +120,8 @@ class Vaihingen2D(NonGeoDataset): def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = 
False, ) -> None: @@ -164,7 +164,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ image = self._load_image(index) mask = self._load_target(index) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -188,9 +188,9 @@ def _load_image(self, index: int) -> Tensor: Returns: the image """ - path = self.files[index]["image"] + path = self.files[index]['image'] with Image.open(path) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)).float() @@ -205,9 +205,9 @@ def _load_target(self, index: int) -> Tensor: Returns: the target mask """ - path = self.files[index]["mask"] + path = self.files[index]['mask'] with Image.open(path) as img: - array: "np.typing.NDArray[np.uint8]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.uint8]' = np.array(img.convert('RGB')) array = rgb_to_mask(array, self.colormap) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW @@ -226,7 +226,7 @@ def _verify(self) -> None: filepath = os.path.join(self.root, filename) if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError("Dataset found, but corrupted.") + raise RuntimeError('Dataset found, but corrupted.') exists.append(True) extract_archive(filepath) else: @@ -257,13 +257,13 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample["image"][:3], sample["mask"], alpha=alpha, colors=self.colormap + sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap ) - if "prediction" in sample: + if 'prediction' in sample: ncols += 1 image2 = draw_semantic_segmentation_masks( - sample["image"][:3], - sample["prediction"], + sample['image'][:3], + sample['prediction'], alpha=alpha, colors=self.colormap, ) @@ -275,15 +275,15 @@ def plot( ax0 = axs ax0.imshow(image1) - ax0.axis("off") + ax0.axis('off') if ncols > 1: ax1.imshow(image2) - ax1.axis("off") + ax1.axis('off') if show_titles: - ax0.set_title("Ground Truth") + ax0.set_title('Ground Truth') if ncols > 1: - ax1.set_title("Predictions") + ax1.set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 93eabdf72bc..dd20b70a71d 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -66,28 +66,28 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: Returns: Processed sample """ - image = sample["image"] + image = sample['image'] _, h, w = image.size() - target = sample["label"] + target = sample['label'] - image_id = target["image_id"] + image_id = target['image_id'] image_id = torch.tensor([image_id]) - anno = target["annotations"] + anno = target['annotations'] - anno = [obj for obj in anno if obj["iscrowd"] == 0] + anno = [obj for obj in anno if obj['iscrowd'] == 0] - bboxes = [obj["bbox"] for obj in anno] + bboxes = [obj['bbox'] for obj in anno] # guard against no boxes via resizing boxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4) boxes[:, 2:] += boxes[:, :2] boxes[:, 0::2].clamp_(min=0, max=w) boxes[:, 1::2].clamp_(min=0, max=h) - categories = [obj["category_id"] for obj in anno] + categories = [obj['category_id'] for obj in anno] classes = torch.tensor(categories, dtype=torch.int64) - segmentations = [obj["segmentation"] for obj in 
anno] + segmentations = [obj['segmentation'] for obj in anno] masks = convert_coco_poly_to_mask(segmentations, h, w) @@ -95,17 +95,17 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: boxes = boxes[keep] classes = classes[keep] - target = {"boxes": boxes, "labels": classes, "image_id": image_id} + target = {'boxes': boxes, 'labels': classes, 'image_id': image_id} if masks.nelement() > 0: masks = masks[keep] - target["masks"] = masks + target['masks'] = masks # for conversion to coco api - area = torch.tensor([obj["area"] for obj in anno]) - iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) - target["area"] = area - target["iscrowd"] = iscrowd - return {"image": image, "label": target} + area = torch.tensor([obj['area'] for obj in anno]) + iscrowd = torch.tensor([obj['iscrowd'] for obj in anno]) + target['area'] = area + target['iscrowd'] = iscrowd + return {'image': image, 'label': target} class VHR10(NonGeoDataset): @@ -156,34 +156,34 @@ class VHR10(NonGeoDataset): """ image_meta = { - "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE", - "filename": "NWPU VHR-10 dataset.rar", - "md5": "d30a7ff99d92123ebb0b3a14d9102081", + 'url': 'https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE', + 'filename': 'NWPU VHR-10 dataset.rar', + 'md5': 'd30a7ff99d92123ebb0b3a14d9102081', } target_meta = { - "url": "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/ce0ba0f5f6a0737031f1cbe05e785ddd5ef05bd7/NWPU%20VHR-10_dataset_coco/annotations.json", # noqa: E501 - "filename": "annotations.json", - "md5": "7c76ec50c17a61bb0514050d20f22c08", + 'url': 'https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/ce0ba0f5f6a0737031f1cbe05e785ddd5ef05bd7/NWPU%20VHR-10_dataset_coco/annotations.json', # noqa: E501 + 'filename': 'annotations.json', + 'md5': '7c76ec50c17a61bb0514050d20f22c08', } categories = [ - "background", - "airplane", - "ships", - "storage tank", - "baseball diamond", - "tennis court", - "basketball court", - "ground track field", - "harbor", - "bridge", - "vehicle", + 'background', + 'airplane', + 'ships', + 'storage tank', + 'baseball diamond', + 'tennis court', + 'basketball court', + 'ground track field', + 'harbor', + 'bridge', + 'vehicle', ] def __init__( self, - root: str = "data", - split: str = "positive", + root: str = 'data', + split: str = 'positive', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, @@ -203,7 +203,7 @@ def __init__( ImportError: if ``split="positive"`` and pycocotools is not installed DatasetNotFoundError: If dataset is not found and *download* is False. 
""" - assert split in ["positive", "negative"] + assert split in ['positive', 'negative'] self.root = root self.split = split @@ -216,18 +216,18 @@ def __init__( if not self._check_integrity(): raise DatasetNotFoundError(self) - if split == "positive": + if split == 'positive': # Must be installed to parse annotations file try: from pycocotools.coco import COCO # noqa: F401 except ImportError: raise ImportError( - "pycocotools is not installed and is required to use this dataset" + 'pycocotools is not installed and is required to use this dataset' ) self.coco = COCO( os.path.join( - self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] + self.root, 'NWPU VHR-10 dataset', self.target_meta['filename'] ) ) @@ -246,16 +246,16 @@ def __getitem__(self, index: int) -> dict[str, Any]: id_ = index % len(self) + 1 sample: dict[str, Any] = { - "image": self._load_image(id_), - "label": self._load_target(id_), + 'image': self._load_image(id_), + 'label': self._load_target(id_), } - if sample["label"]["annotations"]: + if sample['label']['annotations']: sample = self.coco_convert(sample) - sample["labels"] = sample["label"]["labels"] - sample["boxes"] = sample["label"]["boxes"] - sample["masks"] = sample["label"]["masks"] - del sample["label"] + sample['labels'] = sample['label']['labels'] + sample['boxes'] = sample['label']['boxes'] + sample['masks'] = sample['label']['masks'] + del sample['label'] if self.transforms is not None: sample = self.transforms(sample) @@ -268,7 +268,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - if self.split == "positive": + if self.split == 'positive': return len(self.ids) else: return 150 @@ -284,12 +284,12 @@ def _load_image(self, id_: int) -> Tensor: """ filename = os.path.join( self.root, - "NWPU VHR-10 dataset", - self.split + " image set", - f"{id_:03d}.jpg", + 'NWPU VHR-10 dataset', + self.split + ' image set', + f'{id_:03d}.jpg', ) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img) + array: 'np.typing.NDArray[np.int_]' = np.array(img) tensor = torch.from_numpy(array) tensor = tensor.float() # Convert from HxWxC to CxHxW @@ -307,7 +307,7 @@ def _load_target(self, id_: int) -> dict[str, Any]: """ # Images in the "negative" image set have no annotations annot = [] - if self.split == "positive": + if self.split == 'positive': annot = self.coco.loadAnns(self.coco.getAnnIds(id_ - 1)) target = dict(image_id=id_, annotations=annot) @@ -321,18 +321,18 @@ def _check_integrity(self) -> bool: True if dataset files are found and/or MD5s match, else False """ image: bool = check_integrity( - os.path.join(self.root, self.image_meta["filename"]), - self.image_meta["md5"] if self.checksum else None, + os.path.join(self.root, self.image_meta['filename']), + self.image_meta['md5'] if self.checksum else None, ) # Annotations only needed for "positive" image set target = True - if self.split == "positive": + if self.split == 'positive': target = check_integrity( os.path.join( - self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] + self.root, 'NWPU VHR-10 dataset', self.target_meta['filename'] ), - self.target_meta["md5"] if self.checksum else None, + self.target_meta['md5'] if self.checksum else None, ) return image and target @@ -340,25 +340,25 @@ def _check_integrity(self) -> bool: def _download(self) -> None: """Download the dataset and extract it.""" if self._check_integrity(): - print("Files already downloaded and verified") + print('Files already downloaded and verified') return # Download images 
download_and_extract_archive( - self.image_meta["url"], + self.image_meta['url'], self.root, - filename=self.image_meta["filename"], - md5=self.image_meta["md5"] if self.checksum else None, + filename=self.image_meta['filename'], + md5=self.image_meta['md5'] if self.checksum else None, ) # Annotations only needed for "positive" image set - if self.split == "positive": + if self.split == 'positive': # Download annotations download_url( - self.target_meta["url"], - os.path.join(self.root, "NWPU VHR-10 dataset"), - self.target_meta["filename"], - self.target_meta["md5"] if self.checksum else None, + self.target_meta['url'], + os.path.join(self.root, 'NWPU VHR-10 dataset'), + self.target_meta['filename'], + self.target_meta['md5'] if self.checksum else None, ) def plot( @@ -366,7 +366,7 @@ def plot( sample: dict[str, Tensor], show_titles: bool = True, suptitle: str | None = None, - show_feats: str | None = "both", + show_feats: str | None = 'both', box_alpha: float = 0.7, mask_alpha: float = 0.7, ) -> Figure: @@ -389,46 +389,46 @@ def plot( .. versionadded:: 0.4 """ - assert show_feats in {"boxes", "masks", "both"} + assert show_feats in {'boxes', 'masks', 'both'} - if self.split == "negative": + if self.split == 'negative': fig, axs = plt.subplots(squeeze=False) - axs[0, 0].imshow(sample["image"].permute(1, 2, 0)) - axs[0, 0].axis("off") + axs[0, 0].imshow(sample['image'].permute(1, 2, 0)) + axs[0, 0].axis('off') if suptitle is not None: plt.suptitle(suptitle) return fig - if show_feats != "boxes": + if show_feats != 'boxes': try: from skimage.measure import find_contours # noqa: F401 except ImportError: raise ImportError( - "scikit-image is not installed and is required to plot masks." + 'scikit-image is not installed and is required to plot masks.' ) - image = sample["image"].permute(1, 2, 0).numpy() - boxes = sample["boxes"].cpu().numpy() - labels = sample["labels"].cpu().numpy() - if "masks" in sample: - masks = [mask.squeeze().cpu().numpy() for mask in sample["masks"]] + image = sample['image'].permute(1, 2, 0).numpy() + boxes = sample['boxes'].cpu().numpy() + labels = sample['labels'].cpu().numpy() + if 'masks' in sample: + masks = [mask.squeeze().cpu().numpy() for mask in sample['masks']] n_gt = len(boxes) ncols = 1 - show_predictions = "prediction_labels" in sample + show_predictions = 'prediction_labels' in sample if show_predictions: show_pred_boxes = False show_pred_masks = False - prediction_labels = sample["prediction_labels"].numpy() - prediction_scores = sample["prediction_scores"].numpy() - if "prediction_boxes" in sample: - prediction_boxes = sample["prediction_boxes"].numpy() + prediction_labels = sample['prediction_labels'].numpy() + prediction_scores = sample['prediction_scores'].numpy() + if 'prediction_boxes' in sample: + prediction_boxes = sample['prediction_boxes'].numpy() show_pred_boxes = True - if "prediction_masks" in sample: - prediction_masks = sample["prediction_masks"].numpy() + if 'prediction_masks' in sample: + prediction_masks = sample['prediction_masks'].numpy() show_pred_masks = True n_pred = len(prediction_labels) @@ -437,25 +437,25 @@ def plot( # Display image fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 13)) axs[0, 0].imshow(image) - axs[0, 0].axis("off") + axs[0, 0].axis('off') - cm = plt.get_cmap("gist_rainbow") + cm = plt.get_cmap('gist_rainbow') for i in range(n_gt): class_num = labels[i] color = cm(class_num / len(self.categories)) # Add bounding boxes x1, y1, x2, y2 = boxes[i] - if show_feats in {"boxes", "both"}: + if 
show_feats in {'boxes', 'both'}: r = patches.Rectangle( (x1, y1), x2 - x1, y2 - y1, linewidth=2, alpha=box_alpha, - linestyle="dashed", + linestyle='dashed', edgecolor=color, - facecolor="none", + facecolor='none', ) axs[0, 0].add_patch(r) @@ -463,26 +463,26 @@ def plot( label = self.categories[class_num] caption = label axs[0, 0].text( - x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none" + x1, y1 - 8, caption, color='white', size=11, backgroundcolor='none' ) # Add masks - if show_feats in {"masks", "both"} and "masks" in sample: + if show_feats in {'masks', 'both'} and 'masks' in sample: mask = masks[i] contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call] for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( - verts, facecolor=color, alpha=mask_alpha, edgecolor="white" + verts, facecolor=color, alpha=mask_alpha, edgecolor='white' ) axs[0, 0].add_patch(p) if show_titles: - axs[0, 0].set_title("Ground Truth") + axs[0, 0].set_title('Ground Truth') if show_predictions: axs[0, 1].imshow(image) - axs[0, 1].axis("off") + axs[0, 1].axis('off') for i in range(n_pred): score = prediction_scores[i] if score < 0.5: @@ -500,22 +500,22 @@ def plot( y2 - y1, linewidth=2, alpha=box_alpha, - linestyle="dashed", + linestyle='dashed', edgecolor=color, - facecolor="none", + facecolor='none', ) axs[0, 1].add_patch(r) # Add labels label = self.categories[class_num] - caption = f"{label} {score:.3f}" + caption = f'{label} {score:.3f}' axs[0, 1].text( x1, y1 - 8, caption, - color="white", + color='white', size=11, - backgroundcolor="none", + backgroundcolor='none', ) # Add masks @@ -525,12 +525,12 @@ def plot( for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( - verts, facecolor=color, alpha=mask_alpha, edgecolor="white" + verts, facecolor=color, alpha=mask_alpha, edgecolor='white' ) axs[0, 1].add_patch(p) if show_titles: - axs[0, 1].set_title("Prediction") + axs[0, 1].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index ec41088d3b8..1a12ad4af85 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -53,157 +53,157 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): .. 
versionadded:: 0.5 """ - collection_id = "su_sar_moisture_content" + collection_id = 'su_sar_moisture_content' - md5 = "a6c0721f06a3a0110b7d1243b18614f0" + md5 = 'a6c0721f06a3a0110b7d1243b18614f0' - label_name = "percent(t)" + label_name = 'percent(t)' all_variable_names = [ # "date", - "slope(t)", - "elevation(t)", - "canopy_height(t)", - "forest_cover(t)", - "silt(t)", - "sand(t)", - "clay(t)", - "vv(t)", - "vh(t)", - "red(t)", - "green(t)", - "blue(t)", - "swir(t)", - "nir(t)", - "ndvi(t)", - "ndwi(t)", - "nirv(t)", - "vv_red(t)", - "vv_green(t)", - "vv_blue(t)", - "vv_swir(t)", - "vv_nir(t)", - "vv_ndvi(t)", - "vv_ndwi(t)", - "vv_nirv(t)", - "vh_red(t)", - "vh_green(t)", - "vh_blue(t)", - "vh_swir(t)", - "vh_nir(t)", - "vh_ndvi(t)", - "vh_ndwi(t)", - "vh_nirv(t)", - "vh_vv(t)", - "slope(t-1)", - "elevation(t-1)", - "canopy_height(t-1)", - "forest_cover(t-1)", - "silt(t-1)", - "sand(t-1)", - "clay(t-1)", - "vv(t-1)", - "vh(t-1)", - "red(t-1)", - "green(t-1)", - "blue(t-1)", - "swir(t-1)", - "nir(t-1)", - "ndvi(t-1)", - "ndwi(t-1)", - "nirv(t-1)", - "vv_red(t-1)", - "vv_green(t-1)", - "vv_blue(t-1)", - "vv_swir(t-1)", - "vv_nir(t-1)", - "vv_ndvi(t-1)", - "vv_ndwi(t-1)", - "vv_nirv(t-1)", - "vh_red(t-1)", - "vh_green(t-1)", - "vh_blue(t-1)", - "vh_swir(t-1)", - "vh_nir(t-1)", - "vh_ndvi(t-1)", - "vh_ndwi(t-1)", - "vh_nirv(t-1)", - "vh_vv(t-1)", - "slope(t-2)", - "elevation(t-2)", - "canopy_height(t-2)", - "forest_cover(t-2)", - "silt(t-2)", - "sand(t-2)", - "clay(t-2)", - "vv(t-2)", - "vh(t-2)", - "red(t-2)", - "green(t-2)", - "blue(t-2)", - "swir(t-2)", - "nir(t-2)", - "ndvi(t-2)", - "ndwi(t-2)", - "nirv(t-2)", - "vv_red(t-2)", - "vv_green(t-2)", - "vv_blue(t-2)", - "vv_swir(t-2)", - "vv_nir(t-2)", - "vv_ndvi(t-2)", - "vv_ndwi(t-2)", - "vv_nirv(t-2)", - "vh_red(t-2)", - "vh_green(t-2)", - "vh_blue(t-2)", - "vh_swir(t-2)", - "vh_nir(t-2)", - "vh_ndvi(t-2)", - "vh_ndwi(t-2)", - "vh_nirv(t-2)", - "vh_vv(t-2)", - "slope(t-3)", - "elevation(t-3)", - "canopy_height(t-3)", - "forest_cover(t-3)", - "silt(t-3)", - "sand(t-3)", - "clay(t-3)", - "vv(t-3)", - "vh(t-3)", - "red(t-3)", - "green(t-3)", - "blue(t-3)", - "swir(t-3)", - "nir(t-3)", - "ndvi(t-3)", - "ndwi(t-3)", - "nirv(t-3)", - "vv_red(t-3)", - "vv_green(t-3)", - "vv_blue(t-3)", - "vv_swir(t-3)", - "vv_nir(t-3)", - "vv_ndvi(t-3)", - "vv_ndwi(t-3)", - "vv_nirv(t-3)", - "vh_red(t-3)", - "vh_green(t-3)", - "vh_blue(t-3)", - "vh_swir(t-3)", - "vh_nir(t-3)", - "vh_ndvi(t-3)", - "vh_ndwi(t-3)", - "vh_nirv(t-3)", - "vh_vv(t-3)", - "lat", - "lon", + 'slope(t)', + 'elevation(t)', + 'canopy_height(t)', + 'forest_cover(t)', + 'silt(t)', + 'sand(t)', + 'clay(t)', + 'vv(t)', + 'vh(t)', + 'red(t)', + 'green(t)', + 'blue(t)', + 'swir(t)', + 'nir(t)', + 'ndvi(t)', + 'ndwi(t)', + 'nirv(t)', + 'vv_red(t)', + 'vv_green(t)', + 'vv_blue(t)', + 'vv_swir(t)', + 'vv_nir(t)', + 'vv_ndvi(t)', + 'vv_ndwi(t)', + 'vv_nirv(t)', + 'vh_red(t)', + 'vh_green(t)', + 'vh_blue(t)', + 'vh_swir(t)', + 'vh_nir(t)', + 'vh_ndvi(t)', + 'vh_ndwi(t)', + 'vh_nirv(t)', + 'vh_vv(t)', + 'slope(t-1)', + 'elevation(t-1)', + 'canopy_height(t-1)', + 'forest_cover(t-1)', + 'silt(t-1)', + 'sand(t-1)', + 'clay(t-1)', + 'vv(t-1)', + 'vh(t-1)', + 'red(t-1)', + 'green(t-1)', + 'blue(t-1)', + 'swir(t-1)', + 'nir(t-1)', + 'ndvi(t-1)', + 'ndwi(t-1)', + 'nirv(t-1)', + 'vv_red(t-1)', + 'vv_green(t-1)', + 'vv_blue(t-1)', + 'vv_swir(t-1)', + 'vv_nir(t-1)', + 'vv_ndvi(t-1)', + 'vv_ndwi(t-1)', + 'vv_nirv(t-1)', + 'vh_red(t-1)', + 'vh_green(t-1)', + 'vh_blue(t-1)', + 'vh_swir(t-1)', + 'vh_nir(t-1)', + 
'vh_ndvi(t-1)', + 'vh_ndwi(t-1)', + 'vh_nirv(t-1)', + 'vh_vv(t-1)', + 'slope(t-2)', + 'elevation(t-2)', + 'canopy_height(t-2)', + 'forest_cover(t-2)', + 'silt(t-2)', + 'sand(t-2)', + 'clay(t-2)', + 'vv(t-2)', + 'vh(t-2)', + 'red(t-2)', + 'green(t-2)', + 'blue(t-2)', + 'swir(t-2)', + 'nir(t-2)', + 'ndvi(t-2)', + 'ndwi(t-2)', + 'nirv(t-2)', + 'vv_red(t-2)', + 'vv_green(t-2)', + 'vv_blue(t-2)', + 'vv_swir(t-2)', + 'vv_nir(t-2)', + 'vv_ndvi(t-2)', + 'vv_ndwi(t-2)', + 'vv_nirv(t-2)', + 'vh_red(t-2)', + 'vh_green(t-2)', + 'vh_blue(t-2)', + 'vh_swir(t-2)', + 'vh_nir(t-2)', + 'vh_ndvi(t-2)', + 'vh_ndwi(t-2)', + 'vh_nirv(t-2)', + 'vh_vv(t-2)', + 'slope(t-3)', + 'elevation(t-3)', + 'canopy_height(t-3)', + 'forest_cover(t-3)', + 'silt(t-3)', + 'sand(t-3)', + 'clay(t-3)', + 'vv(t-3)', + 'vh(t-3)', + 'red(t-3)', + 'green(t-3)', + 'blue(t-3)', + 'swir(t-3)', + 'nir(t-3)', + 'ndvi(t-3)', + 'ndwi(t-3)', + 'nirv(t-3)', + 'vv_red(t-3)', + 'vv_green(t-3)', + 'vv_blue(t-3)', + 'vv_swir(t-3)', + 'vv_nir(t-3)', + 'vv_ndvi(t-3)', + 'vv_ndwi(t-3)', + 'vv_nirv(t-3)', + 'vh_red(t-3)', + 'vh_green(t-3)', + 'vh_blue(t-3)', + 'vh_swir(t-3)', + 'vh_nir(t-3)', + 'vh_ndvi(t-3)', + 'vh_ndwi(t-3)', + 'vh_nirv(t-3)', + 'vh_vv(t-3)', + 'lat', + 'lon', ] def __init__( self, - root: str = "data", + root: str = 'data', input_features: list[str] = all_variable_names, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, @@ -237,7 +237,7 @@ def __init__( assert all( input in self.all_variable_names for input in input_features - ), "Invalid input variable name." + ), 'Invalid input variable name.' self.input_features = input_features self.collection = self._retrieve_collection() @@ -251,7 +251,7 @@ def _retrieve_collection(self) -> list[str]: list of sample paths """ return glob.glob( - os.path.join(self.root, self.collection_id, "**", "labels.geojson") + os.path.join(self.root, self.collection_id, '**', 'labels.geojson') ) def __len__(self) -> int: @@ -274,10 +274,10 @@ def __getitem__(self, index: int) -> dict[str, Any]: data = self.dataframe.iloc[index, :] sample: dict[str, Tensor] = { - "input": torch.tensor( + 'input': torch.tensor( data.drop([self.label_name]).values, dtype=torch.float32 ), - "label": torch.tensor(data[self.label_name], dtype=torch.float32), + 'label': torch.tensor(data[self.label_name], dtype=torch.float32), } if self.transforms is not None: @@ -295,9 +295,9 @@ def _load_data(self) -> pd.DataFrame: for path in self.collection: with open(path) as f: content = json.load(f) - data_dict = content["properties"] - data_dict["lon"] = content["geometry"]["coordinates"][0] - data_dict["lat"] = content["geometry"]["coordinates"][1] + data_dict = content['properties'] + data_dict['lon'] = content['geometry']['coordinates'][0] + data_dict['lat'] = content['geometry']['coordinates'][1] data_rows.append(data_dict) df: pd.DataFrame = pd.DataFrame(data_rows) @@ -312,7 +312,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.collection_id) + ".tar.gz" + pathname = os.path.join(self.root, self.collection_id) + '.tar.gz' if os.path.exists(pathname): self._extract() return @@ -327,7 +327,7 @@ def _verify(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - pathname = os.path.join(self.root, self.collection_id) + ".tar.gz" + pathname = os.path.join(self.root, self.collection_id) + '.tar.gz' extract_archive(pathname, self.root) def _download(self, api_key: str | None = None) -> None: @@ 
-337,5 +337,5 @@ def _download(self, api_key: str | None = None) -> None: api_key: a RadiantEarth MLHub API key to use for downloading the dataset """ download_radiant_mlhub_collection(self.collection_id, self.root, api_key) - filename = os.path.join(self.root, self.collection_id) + ".tar.gz" + filename = os.path.join(self.root, self.collection_id) + '.tar.gz' extract_archive(filename, self.root) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 7df948af3ef..bfdb5eb8eda 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -54,24 +54,24 @@ class XView2(NonGeoDataset): """ metadata = { - "train": { - "filename": "train_images_labels_targets.tar.gz", - "md5": "a20ebbfb7eb3452785b63ad02ffd1e16", - "directory": "train", + 'train': { + 'filename': 'train_images_labels_targets.tar.gz', + 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', + 'directory': 'train', }, - "test": { - "filename": "test_images_labels_targets.tar.gz", - "md5": "1b39c47e05d1319c17cc8763cee6fe0c", - "directory": "test", + 'test': { + 'filename': 'test_images_labels_targets.tar.gz', + 'md5': '1b39c47e05d1319c17cc8763cee6fe0c', + 'directory': 'test', }, } - classes = ["background", "no-damage", "minor-damage", "major-damage", "destroyed"] - colormap = ["green", "blue", "orange", "red"] + classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed'] + colormap = ['green', 'blue', 'orange', 'red'] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -109,14 +109,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image1 = self._load_image(files["image1"]) - image2 = self._load_image(files["image2"]) - mask1 = self._load_target(files["mask1"]) - mask2 = self._load_target(files["mask2"]) + image1 = self._load_image(files['image1']) + image2 = self._load_image(files['image2']) + mask1 = self._load_target(files['mask1']) + mask2 = self._load_target(files['mask2']) image = torch.stack(tensors=[image1, image2], dim=0) mask = torch.stack(tensors=[mask1, mask2], dim=0) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -142,17 +142,17 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: list of dicts containing paths for each pair of images and masks """ files = [] - directory = self.metadata[split]["directory"] - image_root = os.path.join(root, directory, "images") - mask_root = os.path.join(root, directory, "targets") - images = glob.glob(os.path.join(image_root, "*.png")) + directory = self.metadata[split]['directory'] + image_root = os.path.join(root, directory, 'images') + mask_root = os.path.join(root, directory, 'targets') + images = glob.glob(os.path.join(image_root, '*.png')) basenames = [os.path.basename(f) for f in images] - basenames = ["_".join(f.split("_")[:-2]) for f in basenames] + basenames = ['_'.join(f.split('_')[:-2]) for f in basenames] for name in sorted(set(basenames)): - image1 = os.path.join(image_root, f"{name}_pre_disaster.png") - image2 = os.path.join(image_root, f"{name}_post_disaster.png") - mask1 = os.path.join(mask_root, f"{name}_pre_disaster_target.png") - mask2 = os.path.join(mask_root, f"{name}_post_disaster_target.png") + image1 = os.path.join(image_root, f'{name}_pre_disaster.png') 
+ image2 = os.path.join(image_root, f'{name}_post_disaster.png') + mask1 = os.path.join(mask_root, f'{name}_pre_disaster_target.png') + mask2 = os.path.join(mask_root, f'{name}_post_disaster_target.png') files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2)) return files @@ -167,7 +167,7 @@ def _load_image(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -184,7 +184,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) + array: 'np.typing.NDArray[np.int_]' = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) return tensor @@ -194,10 +194,10 @@ def _verify(self) -> None: # Check if the files already exist exists = [] for split_info in self.metadata.values(): - for directory in ["images", "targets"]: + for directory in ['images', 'targets']: exists.append( os.path.exists( - os.path.join(self.root, split_info["directory"], directory) + os.path.join(self.root, split_info['directory'], directory) ) ) @@ -207,10 +207,10 @@ def _verify(self) -> None: # Check if .tar.gz files already exists (if so then extract) exists = [] for split_info in self.metadata.values(): - filepath = os.path.join(self.root, split_info["filename"]) + filepath = os.path.join(self.root, split_info['filename']) if os.path.isfile(filepath): - if self.checksum and not check_integrity(filepath, split_info["md5"]): - raise RuntimeError("Dataset found, but corrupted.") + if self.checksum and not check_integrity(filepath, split_info['md5']): + raise RuntimeError('Dataset found, but corrupted.') exists.append(True) extract_archive(filepath) else: @@ -241,34 +241,34 @@ def plot( """ ncols = 2 image1 = draw_semantic_segmentation_masks( - sample["image"][0], sample["mask"][0], alpha=alpha, colors=self.colormap + sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap ) image2 = draw_semantic_segmentation_masks( - sample["image"][1], sample["mask"][1], alpha=alpha, colors=self.colormap + sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap ) - if "prediction" in sample: # NOTE: this assumes predictions are made for post + if 'prediction' in sample: # NOTE: this assumes predictions are made for post ncols += 1 image3 = draw_semantic_segmentation_masks( - sample["image"][1], - sample["prediction"], + sample['image'][1], + sample['prediction'], alpha=alpha, colors=self.colormap, ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(image2) - axs[1].axis("off") + axs[1].axis('off') if ncols > 2: axs[2].imshow(image3) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[0].set_title("Pre disaster") - axs[1].set_title("Post disaster") + axs[0].set_title('Pre disaster') + axs[1].set_title('Post disaster') if ncols > 2: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index d1f5029733e..4620c826eaf 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -57,18 +57,18 @@ 
class ZueriCrop(NonGeoDataset): """ urls = [ - "https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download", - "https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv", # noqa: E501 + 'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download', + 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', # noqa: E501 ] - md5s = ["1635231df67f3d25f4f1e62c98e221a4", "5118398c7a5bbc246f5f6bb35d8d529b"] - filenames = ["ZueriCrop.hdf5", "labels.csv"] + md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b'] + filenames = ['ZueriCrop.hdf5', 'labels.csv'] - band_names = ("NIR", "B03", "B02", "B04", "B05", "B06", "B07", "B11", "B12") - rgb_bands = ["B04", "B03", "B02"] + band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12') + rgb_bands = ['B04', 'B03', 'B02'] def __init__( self, - root: str = "data", + root: str = 'data', bands: Sequence[str] = band_names, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -97,7 +97,7 @@ def __init__( self.transforms = transforms self.download = download self.checksum = checksum - self.filepath = os.path.join(root, "ZueriCrop.hdf5") + self.filepath = os.path.join(root, 'ZueriCrop.hdf5') self._verify() @@ -105,7 +105,7 @@ def __init__( import h5py # noqa: F401 except ImportError: raise ImportError( - "h5py is not installed and is required to use this dataset" + 'h5py is not installed and is required to use this dataset' ) def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -120,7 +120,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = self._load_image(index) mask, boxes, label = self._load_target(index) - sample = {"image": image, "mask": mask, "boxes": boxes, "label": label} + sample = {'image': image, 'mask': mask, 'boxes': boxes, 'label': label} if self.transforms is not None: sample = self.transforms(sample) @@ -135,8 +135,8 @@ def __len__(self) -> int: """ import h5py - with h5py.File(self.filepath, "r") as f: - length: int = f["data"].shape[0] + with h5py.File(self.filepath, 'r') as f: + length: int = f['data'].shape[0] return length def _load_image(self, index: int) -> Tensor: @@ -150,8 +150,8 @@ def _load_image(self, index: int) -> Tensor: """ import h5py - with h5py.File(self.filepath, "r") as f: - array = f["data"][index, ...] + with h5py.File(self.filepath, 'r') as f: + array = f['data'][index, ...] tensor = torch.from_numpy(array) # Convert from TxHxWxC to TxCxHxW @@ -170,9 +170,9 @@ def _load_target(self, index: int) -> tuple[Tensor, Tensor, Tensor]: """ import h5py - with h5py.File(self.filepath, "r") as f: - mask_array = f["gt"][index, ...] - instance_array = f["gt_instance"][index, ...] + with h5py.File(self.filepath, 'r') as f: + mask_array = f['gt'][index, ...] + instance_array = f['gt_instance'][index, ...] 
mask_tensor = torch.from_numpy(mask_array) instance_tensor = torch.from_numpy(instance_array) @@ -289,7 +289,7 @@ def plot( raise RGBBandsMissingError() ncols = 2 - image, mask = sample["image"][time_step, rgb_indices], sample["mask"] + image, mask = sample['image'][time_step, rgb_indices], sample['mask'] image = torch.tensor( percentile_normalization(image.numpy()) * 255, dtype=torch.uint8 @@ -297,26 +297,26 @@ def plot( mask = torch.argmax(mask, dim=0) - if "prediction" in sample: + if 'prediction' in sample: ncols += 1 - preds = torch.argmax(sample["prediction"], dim=0) + preds = torch.argmax(sample['prediction'], dim=0) fig, axs = plt.subplots(ncols=ncols, figsize=(10, 10 * ncols)) axs[0].imshow(image.permute(1, 2, 0)) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(mask) - axs[1].axis("off") + axs[1].axis('off') if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") + axs[0].set_title('Image') + axs[1].set_title('Mask') - if "prediction" in sample: + if 'prediction' in sample: axs[2].imshow(preds) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[2].set_title("Prediction") + axs[2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) diff --git a/torchgeo/losses/__init__.py b/torchgeo/losses/__init__.py index 3b1a89b3eef..d30807a4bd6 100644 --- a/torchgeo/losses/__init__.py +++ b/torchgeo/losses/__init__.py @@ -5,4 +5,4 @@ from .qr import QRLoss, RQLoss -__all__ = ("QRLoss", "RQLoss") +__all__ = ('QRLoss', 'RQLoss') diff --git a/torchgeo/losses/qr.py b/torchgeo/losses/qr.py index ecffa33e9fa..ea1c2a11196 100644 --- a/torchgeo/losses/qr.py +++ b/torchgeo/losses/qr.py @@ -31,7 +31,7 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: q_bar = q.mean(dim=(0, 2, 3)) qbar_log_S = (q_bar * torch.log(q_bar)).sum() - q_log_p = torch.einsum("bcxy,bcxy->bxy", q, torch.log(target)).mean() + q_log_p = torch.einsum('bcxy,bcxy->bxy', q, torch.log(target)).mean() loss = qbar_log_S - q_log_p return loss @@ -62,6 +62,6 @@ def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: z = q / q.norm(p=1, dim=(0, 2, 3), keepdim=True).clamp_min(1e-12).expand_as(q) r = F.normalize(z * target, p=1, dim=1) - loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q)).mean() + loss = torch.einsum('bcxy,bcxy->bxy', r, torch.log(r) - torch.log(q)).mean() return loss diff --git a/torchgeo/main.py b/torchgeo/main.py index 0b002cdc201..b403d4fa50c 100644 --- a/torchgeo/main.py +++ b/torchgeo/main.py @@ -18,11 +18,11 @@ def main(args: ArgsType = None) -> None: """Command-line interface to TorchGeo.""" # Taken from https://github.com/pangeo-data/cog-best-practices rasterio_best_practices = { - "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR", - "AWS_NO_SIGN_REQUEST": "YES", - "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000", - "GDAL_SWATH_SIZE": "200000000", - "VSI_CURL_CACHE_SIZE": "200000000", + 'GDAL_DISABLE_READDIR_ON_OPEN': 'EMPTY_DIR', + 'AWS_NO_SIGN_REQUEST': 'YES', + 'GDAL_MAX_RAW_BLOCK_CACHE_SIZE': '200000000', + 'GDAL_SWATH_SIZE': '200000000', + 'VSI_CURL_CACHE_SIZE': '200000000', } os.environ.update(rasterio_best_practices) @@ -32,6 +32,6 @@ def main(args: ArgsType = None) -> None: seed_everything_default=0, subclass_mode_model=True, subclass_mode_data=True, - save_config_kwargs={"overwrite": True}, + save_config_kwargs={'overwrite': True}, args=args, ) diff --git a/torchgeo/models/__init__.py b/torchgeo/models/__init__.py index 9a3431708c8..327a343bde3 100644 --- a/torchgeo/models/__init__.py +++ 
b/torchgeo/models/__init__.py @@ -24,33 +24,33 @@ __all__ = ( # models - "ChangeMixin", - "ChangeStar", - "ChangeStarFarSeg", - "DOFA", - "dofa_small_patch16_224", - "dofa_base_patch16_224", - "dofa_large_patch16_224", - "dofa_huge_patch16_224", - "FarSeg", - "FCN", - "FCSiamConc", - "FCSiamDiff", - "RCF", - "resnet18", - "resnet50", - "swin_v2_b", - "vit_small_patch16_224", + 'ChangeMixin', + 'ChangeStar', + 'ChangeStarFarSeg', + 'DOFA', + 'dofa_small_patch16_224', + 'dofa_base_patch16_224', + 'dofa_large_patch16_224', + 'dofa_huge_patch16_224', + 'FarSeg', + 'FCN', + 'FCSiamConc', + 'FCSiamDiff', + 'RCF', + 'resnet18', + 'resnet50', + 'swin_v2_b', + 'vit_small_patch16_224', # weights - "DOFABase16_Weights", - "DOFALarge16_Weights", - "ResNet50_Weights", - "ResNet18_Weights", - "Swin_V2_B_Weights", - "ViTSmall16_Weights", + 'DOFABase16_Weights', + 'DOFALarge16_Weights', + 'ResNet50_Weights', + 'ResNet18_Weights', + 'Swin_V2_B_Weights', + 'ViTSmall16_Weights', # utilities - "get_model", - "get_model_weights", - "get_weight", - "list_models", + 'get_model', + 'get_model_weights', + 'get_weight', + 'list_models', ) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 50c138a8116..696caac69fb 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -21,10 +21,10 @@ from .vit import ViTSmall16_Weights, vit_small_patch16_224 _model = { - "resnet18": resnet18, - "resnet50": resnet50, - "vit_small_patch16_224": vit_small_patch16_224, - "swin_v2_b": swin_v2_b, + 'resnet18': resnet18, + 'resnet50': resnet50, + 'vit_small_patch16_224': vit_small_patch16_224, + 'swin_v2_b': swin_v2_b, } _model_weights = { @@ -32,10 +32,10 @@ resnet50: ResNet50_Weights, vit_small_patch16_224: ViTSmall16_Weights, swin_v2_b: Swin_V2_B_Weights, - "resnet18": ResNet18_Weights, - "resnet50": ResNet50_Weights, - "vit_small_patch16_224": ViTSmall16_Weights, - "swin_v2_b": Swin_V2_B_Weights, + 'resnet18': ResNet18_Weights, + 'resnet50': ResNet50_Weights, + 'vit_small_patch16_224': ViTSmall16_Weights, + 'swin_v2_b': Swin_V2_B_Weights, } diff --git a/torchgeo/models/changestar.py b/torchgeo/models/changestar.py index 172c36c9c81..9d8da16e793 100644 --- a/torchgeo/models/changestar.py +++ b/torchgeo/models/changestar.py @@ -103,7 +103,7 @@ def __init__( dense_feature_extractor: Module, seg_classifier: Module, changemixin: ChangeMixin, - inference_mode: str = "t1t2", + inference_mode: str = 't1t2', ) -> None: """Initializes a new ChangeStar model. 
@@ -123,8 +123,8 @@ def __init__( self.seg_classifier = seg_classifier self.changemixin = changemixin - if inference_mode not in ["t1t2", "t2t1", "mean"]: - raise ValueError(f"Unknown inference_mode: {inference_mode}") + if inference_mode not in ['t1t2', 't2t1', 'mean']: + raise ValueError(f'Unknown inference_mode: {inference_mode}') self.inference_mode = inference_mode def forward(self, x: Tensor) -> dict[str, Tensor]: @@ -138,28 +138,28 @@ def forward(self, x: Tensor) -> dict[str, Tensor]: change detection logit/probability """ b, t, c, h, w = x.shape - x = rearrange(x, "b t c h w -> (b t) c h w") + x = rearrange(x, 'b t c h w -> (b t) c h w') # feature extraction bi_feature = self.dense_feature_extractor(x) # semantic segmentation bi_seg_logit = self.seg_classifier(bi_feature) - bi_seg_logit = rearrange(bi_seg_logit, "(b t) c h w -> b t c h w", t=t) + bi_seg_logit = rearrange(bi_seg_logit, '(b t) c h w -> b t c h w', t=t) - bi_feature = rearrange(bi_feature, "(b t) c h w -> b t c h w", t=t) + bi_feature = rearrange(bi_feature, '(b t) c h w -> b t c h w', t=t) # change detection c12, c21 = self.changemixin(bi_feature) results: dict[str, Tensor] = {} if not self.training: - results.update({"bi_seg_logit": bi_seg_logit}) - if self.inference_mode == "t1t2": - results.update({"change_prob": c12.sigmoid()}) - elif self.inference_mode == "t2t1": - results.update({"change_prob": c21.sigmoid()}) - elif self.inference_mode == "mean": + results.update({'bi_seg_logit': bi_seg_logit}) + if self.inference_mode == 't1t2': + results.update({'change_prob': c12.sigmoid()}) + elif self.inference_mode == 't2t1': + results.update({'change_prob': c21.sigmoid()}) + elif self.inference_mode == 'mean': results.update( { - "change_prob": torch.stack([c12, c21], dim=0) + 'change_prob': torch.stack([c12, c21], dim=0) .sigmoid_() .mean(dim=0) } @@ -167,8 +167,8 @@ def forward(self, x: Tensor) -> dict[str, Tensor]: else: results.update( { - "bi_seg_logit": bi_seg_logit, - "bi_change_logit": torch.stack([c12, c21], dim=1), + 'bi_seg_logit': bi_seg_logit, + 'bi_change_logit': torch.stack([c12, c21], dim=1), } ) return results @@ -186,7 +186,7 @@ class ChangeStarFarSeg(ChangeStar): def __init__( self, - backbone: str = "resnet50", + backbone: str = 'resnet50', classes: int = 1, backbone_pretrained: bool = True, ) -> None: @@ -209,5 +209,5 @@ def __init__( changemixin=ChangeMixin( in_channels=128 * 2, inner_channels=16, num_convs=4, scale_factor=4.0 ), - inference_mode="t1t2", + inference_mode='t1t2', ) diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index f63d82ea146..12faf1c70d7 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -15,7 +15,7 @@ from torch import Tensor from torchvision.models._api import Weights, WeightsEnum -__all__ = ["DOFABase16_Weights", "DOFALarge16_Weights"] +__all__ = ['DOFABase16_Weights', 'DOFALarge16_Weights'] def position_embedding(embed_dim: int, pos: Tensor) -> Tensor: @@ -37,7 +37,7 @@ def position_embedding(embed_dim: int, pos: Tensor) -> Tensor: omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = torch.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = torch.sin(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2) @@ -70,7 +70,7 @@ def __init__( encoder_layer = nn.TransformerEncoderLayer( d_model=input_dim, nhead=num_heads, - activation="gelu", + activation='gelu', norm_first=False, batch_first=False, dropout=False, @@ -386,14 +386,14 @@ class 
DOFABase16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url="https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth", # noqa: E501 + url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', # noqa: E501 transforms=_dofa_transforms, meta={ - "dataset": "SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k", - "model": "dofa_base_patch16_224", - "publication": "https://arxiv.org/abs/2403.15356", - "repo": "https://github.com/zhu-xlab/DOFA", - "ssl_method": "mae", + 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', + 'model': 'dofa_base_patch16_224', + 'publication': 'https://arxiv.org/abs/2403.15356', + 'repo': 'https://github.com/zhu-xlab/DOFA', + 'ssl_method': 'mae', }, ) @@ -405,14 +405,14 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url="https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth", # noqa: E501 + url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', # noqa: E501 transforms=_dofa_transforms, meta={ - "dataset": "SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k", - "model": "dofa_large_patch16_224", - "publication": "https://arxiv.org/abs/2403.15356", - "repo": "https://github.com/zhu-xlab/DOFA", - "ssl_method": "mae", + 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', + 'model': 'dofa_large_patch16_224', + 'publication': 'https://arxiv.org/abs/2403.15356', + 'repo': 'https://github.com/zhu-xlab/DOFA', + 'ssl_method': 'mae', }, ) @@ -462,10 +462,10 @@ def dofa_base_patch16_224( ) # Both fc_norm and head are generated dynamically assert set(missing_keys) <= { - "fc_norm.weight", - "fc_norm.bias", - "head.weight", - "head.bias", + 'fc_norm.weight', + 'fc_norm.bias', + 'head.weight', + 'head.bias', } assert not unexpected_keys @@ -498,10 +498,10 @@ def dofa_large_patch16_224( ) # Both fc_norm and head are generated dynamically assert set(missing_keys) <= { - "fc_norm.weight", - "fc_norm.bias", - "head.weight", - "head.bias", + 'fc_norm.weight', + 'fc_norm.bias', + 'head.weight', + 'head.bias', } assert not unexpected_keys diff --git a/torchgeo/models/farseg.py b/torchgeo/models/farseg.py index 2dfc3c01894..f57f59cde5f 100644 --- a/torchgeo/models/farseg.py +++ b/torchgeo/models/farseg.py @@ -41,7 +41,7 @@ class FarSeg(Module): def __init__( self, - backbone: str = "resnet50", + backbone: str = 'resnet50', classes: int = 16, backbone_pretrained: bool = True, ) -> None: @@ -54,21 +54,21 @@ def __init__( backbone_pretrained: whether to use pretrained weight for backbone """ super().__init__() - if backbone in ["resnet18", "resnet34"]: + if backbone in ['resnet18', 'resnet34']: max_channels = 512 - elif backbone in ["resnet50", "resnet101"]: + elif backbone in ['resnet50', 'resnet101']: max_channels = 2048 else: - raise ValueError(f"unknown backbone: {backbone}.") + raise ValueError(f'unknown backbone: {backbone}.') kwargs = {} if backbone_pretrained: kwargs = { - "weights": getattr( - torchvision.models, f"ResNet{backbone[6:]}_Weights" + 'weights': getattr( + torchvision.models, f'ResNet{backbone[6:]}_Weights' ).DEFAULT } else: - kwargs = {"weights": None} + kwargs = {'weights': None} self.backbone = getattr(resnet, backbone)(**kwargs) @@ -102,7 +102,7 @@ def forward(self, x: Tensor) -> Tensor: coarsest_features = features[-1] scene_embedding = 
F.adaptive_avg_pool2d(coarsest_features, 1) fpn_features = self.fpn( - OrderedDict({f"c{i + 2}": features[i] for i in range(4)}) + OrderedDict({f'c{i + 2}': features[i] for i in range(4)}) ) features = [v for k, v in fpn_features.items()] features = self.fsr(scene_embedding, features) diff --git a/torchgeo/models/fcsiam.py b/torchgeo/models/fcsiam.py index 2603e7f4c92..e1fa9a3a69e 100644 --- a/torchgeo/models/fcsiam.py +++ b/torchgeo/models/fcsiam.py @@ -23,9 +23,9 @@ class FCSiamConc(SegmentationModel): # type: ignore[misc] def __init__( self, - encoder_name: str = "resnet34", + encoder_name: str = 'resnet34', encoder_depth: int = 5, - encoder_weights: str | None = "imagenet", + encoder_weights: str | None = 'imagenet', decoder_use_batchnorm: bool = True, decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: str | None = None, @@ -86,7 +86,7 @@ def __init__( decoder_channels=decoder_channels, n_blocks=encoder_depth, use_batchnorm=decoder_use_batchnorm, - center=True if encoder_name.startswith("vgg") else False, + center=True if encoder_name.startswith('vgg') else False, attention_type=decoder_attention_type, ) @@ -97,7 +97,7 @@ def __init__( kernel_size=3, ) self.classification_head = None - self.name = f"u-{encoder_name}" + self.name = f'u-{encoder_name}' self.initialize() def forward(self, x: Tensor) -> Tensor: @@ -139,7 +139,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: **kwargs: Additional keyword arguments passed to :class:`~segmentation_models_pytorch.Unet` """ - kwargs["aux_params"] = None + kwargs['aux_params'] = None super().__init__(*args, **kwargs) def forward(self, x: Tensor) -> Tensor: diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py index 4f7dac85106..99fd1ae5c3d 100644 --- a/torchgeo/models/rcf.py +++ b/torchgeo/models/rcf.py @@ -42,7 +42,7 @@ def __init__( kernel_size: int = 3, bias: float = -1.0, seed: int | None = None, - mode: str = "gaussian", + mode: str = 'gaussian', dataset: NonGeoDataset | None = None, ) -> None: """Initializes the RCF model. @@ -66,8 +66,8 @@ def __init__( dataset: a NonGeoDataset to sample from when mode is "empirical" """ super().__init__() - assert mode in ["empirical", "gaussian"] - if mode == "empirical" and dataset is None: + assert mode in ['empirical', 'gaussian'] + if mode == 'empirical' and dataset is None: raise ValueError("dataset must be provided when mode is 'empirical'") assert features % 2 == 0 num_patches = features // 2 @@ -81,7 +81,7 @@ def __init__( # them explicitely _not_ Parameters of the model (which might get updated) if # a user tries to train with this model. 
self.register_buffer( - "weights", + 'weights', torch.randn( num_patches, in_channels, @@ -92,12 +92,12 @@ def __init__( ), ) self.register_buffer( - "biases", torch.zeros(num_patches, requires_grad=False) + bias + 'biases', torch.zeros(num_patches, requires_grad=False) + bias ) - if mode == "empirical": + if mode == 'empirical': assert dataset is not None - num_channels, height, width = dataset[0]["image"].shape + num_channels, height, width = dataset[0]['image'].shape assert num_channels == in_channels patches = np.zeros( (num_patches, num_channels, kernel_size, kernel_size), dtype=np.float32 @@ -113,7 +113,7 @@ def __init__( ).numpy() for i in range(num_patches): - img = dataset[idxs[i]]["image"] + img = dataset[idxs[i]]['image'] patches[i] = img[ :, ys[i] : ys[i] + kernel_size, xs[i] : xs[i] + kernel_size ] @@ -123,10 +123,10 @@ def __init__( def _normalize( self, - patches: "np.typing.NDArray[np.float32]", + patches: 'np.typing.NDArray[np.float32]', min_divisor: float = 1e-8, zca_bias: float = 0.001, - ) -> "np.typing.NDArray[np.float32]": + ) -> 'np.typing.NDArray[np.float32]': """Does ZCA whitening on a set of input patches. Copied from https://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120 @@ -165,11 +165,11 @@ def _normalize( sqrt_zca_eigs = np.sqrt(E) inv_sqrt_zca_eigs = np.diag(np.power(sqrt_zca_eigs, -1)) global_ZCA = V.dot(inv_sqrt_zca_eigs).dot(V.T) - patches_normalized: "np.typing.NDArray[np.float32]" = ( + patches_normalized: 'np.typing.NDArray[np.float32]' = ( (patches).dot(global_ZCA).dot(global_ZCA.T) ) - return patches_normalized.reshape(orig_shape).astype("float32") + return patches_normalized.reshape(orig_shape).astype('float32') def forward(self, x: Tensor) -> Tensor: """Forward pass of the RCF model. 
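For reference, a minimal usage sketch of the pretrained-weight enums touched in the next file (torchgeo/models/resnet.py). This is not part of the patch; it only assumes torchgeo is installed and that the checkpoint referenced by the enum can be downloaded.

    # Hedged sketch: load one of the ResNet18 weight enums defined in the diff below.
    # ResNet18_Weights.SENTINEL2_ALL_MOCO and the resnet18() factory both appear in
    # this patch; everything else (the dummy batch, the print) is illustrative only.
    import torch

    from torchgeo.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
    model = resnet18(weights=weights)  # downloads and loads the checkpoint

    # The weight metadata records the expected number of input channels
    # (13 Sentinel-2 bands for this enum), so the dummy batch matches it.
    x = torch.zeros(1, weights.meta['in_chans'], 224, 224)
    with torch.no_grad():
        y = model(x)
    print(y.shape)
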
diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 4c3c4ec54c3..7bbf3d6b588 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -11,7 +11,7 @@ from timm.models import ResNet from torchvision.models._api import Weights, WeightsEnum -__all__ = ["ResNet50_Weights", "ResNet18_Weights"] +__all__ = ['ResNet50_Weights', 'ResNet18_Weights'] # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 @@ -73,171 +73,171 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] """ LANDSAT_TM_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_TM_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_ETM_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 9, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 9, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 9, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 9, + 
'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_ETM_SR_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 6, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 6, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_ETM_SR_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 6, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 6, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 11, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 11, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 11, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 11, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_OLI_SR_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth", # noqa: E501 + 
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_OLI_SR_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) SENTINEL2_ALL_MOCO = Weights( - url="https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 13, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "moco", + 'dataset': 'SSL4EO-S12', + 'in_chans': 13, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'moco', }, ) SENTINEL2_RGB_MOCO = Weights( - url="https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 3, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "moco", + 'dataset': 'SSL4EO-S12', + 'in_chans': 3, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'moco', }, ) SENTINEL2_RGB_SECO = Weights( - url="https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', # noqa: E501 transforms=_seco_transforms, meta={ - "dataset": "SeCo Dataset", - "in_chans": 3, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2103.16607", - "repo": "https://github.com/ServiceNow/seasonal-contrast", - "ssl_method": "seco", + 'dataset': 
'SeCo Dataset', + 'in_chans': 3, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2103.16607', + 'repo': 'https://github.com/ServiceNow/seasonal-contrast', + 'ssl_method': 'seco', }, ) @@ -252,210 +252,210 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] """ FMOW_RGB_GASSL = Weights( - url="https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', # noqa: E501 transforms=_gassl_transforms, meta={ - "dataset": "fMoW Dataset", - "in_chans": 3, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2011.09980", - "repo": "https://github.com/sustainlab-group/geography-aware-ssl", - "ssl_method": "gassl", + 'dataset': 'fMoW Dataset', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2011.09980', + 'repo': 'https://github.com/sustainlab-group/geography-aware-ssl', + 'ssl_method': 'gassl', }, ) LANDSAT_TM_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_TM_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_ETM_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 9, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_ETM_TOA_SIMCLR = Weights( - 
url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 9, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_ETM_SR_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 6, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 6, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_ETM_SR_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 6, - "model": "resnet18", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 6, + 'model': 'resnet18', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 11, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 11, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 11, - "model": "resnet50", - "publication": 
"https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 11, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_OLI_SR_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_OLI_SR_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) SENTINEL1_ALL_MOCO = Weights( - url="https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 2, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "moco", + 'dataset': 'SSL4EO-S12', + 'in_chans': 2, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'moco', }, ) SENTINEL2_ALL_DINO = Weights( - url="https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 13, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "dino", + 'dataset': 'SSL4EO-S12', + 'in_chans': 13, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'dino', }, ) SENTINEL2_ALL_MOCO = Weights( - 
url="https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 13, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "moco", + 'dataset': 'SSL4EO-S12', + 'in_chans': 13, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'moco', }, ) SENTINEL2_RGB_MOCO = Weights( - url="https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 3, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "moco", + 'dataset': 'SSL4EO-S12', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'moco', }, ) SENTINEL2_RGB_SECO = Weights( - url="https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth", # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', # noqa: E501 transforms=_seco_transforms, meta={ - "dataset": "SeCo Dataset", - "in_chans": 3, - "model": "resnet50", - "publication": "https://arxiv.org/abs/2103.16607", - "repo": "https://github.com/ServiceNow/seasonal-contrast", - "ssl_method": "seco", + 'dataset': 'SeCo Dataset', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2103.16607', + 'repo': 'https://github.com/ServiceNow/seasonal-contrast', + 'ssl_method': 'seco', }, ) @@ -480,15 +480,15 @@ def resnet18( A ResNet-18 model. """ if weights: - kwargs["in_chans"] = weights.meta["in_chans"] + kwargs['in_chans'] = weights.meta['in_chans'] - model: ResNet = timm.create_model("resnet18", *args, **kwargs) + model: ResNet = timm.create_model('resnet18', *args, **kwargs) if weights: missing_keys, unexpected_keys = model.load_state_dict( weights.get_state_dict(progress=True), strict=False ) - assert set(missing_keys) <= {"fc.weight", "fc.bias"} + assert set(missing_keys) <= {'fc.weight', 'fc.bias'} assert not unexpected_keys return model @@ -515,15 +515,15 @@ def resnet50( A ResNet-50 model. 
""" if weights: - kwargs["in_chans"] = weights.meta["in_chans"] + kwargs['in_chans'] = weights.meta['in_chans'] - model: ResNet = timm.create_model("resnet50", *args, **kwargs) + model: ResNet = timm.create_model('resnet50', *args, **kwargs) if weights: missing_keys, unexpected_keys = model.load_state_dict( weights.get_state_dict(progress=True), strict=False ) - assert set(missing_keys) <= {"fc.weight", "fc.bias"} + assert set(missing_keys) <= {'fc.weight', 'fc.bias'} assert not unexpected_keys return model diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index a44e70ab2c6..4f27cf89ac6 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -12,7 +12,7 @@ from torchvision.models import SwinTransformer from torchvision.models._api import Weights, WeightsEnum -__all__ = ["Swin_V2_B_Weights"] +__all__ = ['Swin_V2_B_Weights'] # https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501 # Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). @@ -58,76 +58,76 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] """ NAIP_RGB_SI_SATLAS = Weights( - url="https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth", # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', # noqa: E501 transforms=_satlas_transforms, meta={ - "dataset": "Satlas", - "in_chans": 3, - "model": "swin_v2_b", - "publication": "https://arxiv.org/abs/2211.15660", - "repo": "https://github.com/allenai/satlas", + 'dataset': 'Satlas', + 'in_chans': 3, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', }, ) SENTINEL2_RGB_SI_SATLAS = Weights( - url="https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth", # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', # noqa: E501 transforms=_satlas_transforms, meta={ - "dataset": "Satlas", - "in_chans": 3, - "model": "swin_v2_b", - "publication": "https://arxiv.org/abs/2211.15660", - "repo": "https://github.com/allenai/satlas", + 'dataset': 'Satlas', + 'in_chans': 3, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', }, ) SENTINEL2_MS_SI_SATLAS = Weights( - url="https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth", # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', # noqa: E501 transforms=_sentinel2_ms_satlas_transforms, meta={ - "dataset": "Satlas", - "in_chans": 9, - "model": "swin_v2_b", - "publication": "https://arxiv.org/abs/2211.15660", - "repo": "https://github.com/allenai/satlas", - "bands": ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B11", "B12"], + 'dataset': 'Satlas', + 'in_chans': 9, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12'], }, ) SENTINEL1_SI_SATLAS = Weights( - url="https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth", # noqa: E501 + 
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', # noqa: E501 transforms=_satlas_transforms, meta={ - "dataset": "Satlas", - "in_chans": 2, - "model": "swin_v2_b", - "publication": "https://arxiv.org/abs/2211.15660", - "repo": "https://github.com/allenai/satlas", - "bands": ["VH", "VV"], + 'dataset': 'Satlas', + 'in_chans': 2, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': ['VH', 'VV'], }, ) LANDSAT_SI_SATLAS = Weights( - url="https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth", # noqa: E501 + url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', # noqa: E501 transforms=_landsat_satlas_transforms, meta={ - "dataset": "Satlas", - "in_chans": 11, - "model": "swin_v2_b", - "publication": "https://arxiv.org/abs/2211.15660", - "repo": "https://github.com/allenai/satlas", - "bands": [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B09", - "B10", - "B11", + 'dataset': 'Satlas', + 'in_chans': 11, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': [ + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B09', + 'B10', + 'B11', ], # noqa: E501 }, ) diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py index 60c875c203c..49e7d4d35c2 100644 --- a/torchgeo/models/vit.py +++ b/torchgeo/models/vit.py @@ -11,7 +11,7 @@ from timm.models.vision_transformer import VisionTransformer from torchvision.models._api import Weights, WeightsEnum -__all__ = ["ViTSmall16_Weights"] +__all__ = ['ViTSmall16_Weights'] # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 @@ -46,158 +46,158 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] """ LANDSAT_TM_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_TM_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": 
"https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_ETM_TOA_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 9, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 9, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 9, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 9, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_ETM_SR_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 6, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 6, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_ETM_SR_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 6, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 6, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - 
url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 11, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 11, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 11, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 11, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) LANDSAT_OLI_SR_MOCO = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "moco", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'moco', }, ) LANDSAT_OLI_SR_SIMCLR = Weights( - url="https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth", # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', # noqa: E501 transforms=_ssl4eo_l_transforms, meta={ - "dataset": "SSL4EO-L", - "in_chans": 7, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2306.09424", - "repo": "https://github.com/microsoft/torchgeo", - "ssl_method": "simclr", + 'dataset': 'SSL4EO-L', + 'in_chans': 7, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2306.09424', + 'repo': 'https://github.com/microsoft/torchgeo', + 'ssl_method': 'simclr', }, ) SENTINEL2_ALL_DINO = Weights( - url="https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth", # noqa: E501 + 
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 13, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "dino", + 'dataset': 'SSL4EO-S12', + 'in_chans': 13, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'dino', }, ) SENTINEL2_ALL_MOCO = Weights( - url="https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth", # noqa: E501 + url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', # noqa: E501 transforms=_zhu_xlab_transforms, meta={ - "dataset": "SSL4EO-S12", - "in_chans": 13, - "model": "vit_small_patch16_224", - "publication": "https://arxiv.org/abs/2211.07044", - "repo": "https://github.com/zhu-xlab/SSL4EO-S12", - "ssl_method": "moco", + 'dataset': 'SSL4EO-S12', + 'in_chans': 13, + 'model': 'vit_small_patch16_224', + 'publication': 'https://arxiv.org/abs/2211.07044', + 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', + 'ssl_method': 'moco', }, ) @@ -222,17 +222,17 @@ def vit_small_patch16_224( A ViT small 16 model. """ if weights: - kwargs["in_chans"] = weights.meta["in_chans"] + kwargs['in_chans'] = weights.meta['in_chans'] model: VisionTransformer = timm.create_model( - "vit_small_patch16_224", *args, **kwargs + 'vit_small_patch16_224', *args, **kwargs ) if weights: missing_keys, unexpected_keys = model.load_state_dict( weights.get_state_dict(progress=True), strict=False ) - assert set(missing_keys) <= {"head.weight", "head.bias"} + assert set(missing_keys) <= {'head.weight', 'head.bias'} assert not unexpected_keys return model diff --git a/torchgeo/samplers/__init__.py b/torchgeo/samplers/__init__.py index ae449228171..ba995c9c782 100644 --- a/torchgeo/samplers/__init__.py +++ b/torchgeo/samplers/__init__.py @@ -10,17 +10,17 @@ __all__ = ( # Samplers - "GridGeoSampler", - "PreChippedGeoSampler", - "RandomGeoSampler", + 'GridGeoSampler', + 'PreChippedGeoSampler', + 'RandomGeoSampler', # Batch samplers - "RandomBatchGeoSampler", + 'RandomBatchGeoSampler', # Base classes - "GeoSampler", - "BatchGeoSampler", + 'GeoSampler', + 'BatchGeoSampler', # Utilities - "get_random_bounding_box", - "tile_to_chips", + 'get_random_bounding_box', + 'tile_to_chips', # Constants - "Units", + 'Units', ) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index ec8d916a012..2b936401afe 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -14,16 +14,16 @@ __all__ = ( # Supervised - "ClassificationTask", - "MultiLabelClassificationTask", - "ObjectDetectionTask", - "PixelwiseRegressionTask", - "RegressionTask", - "SemanticSegmentationTask", + 'ClassificationTask', + 'MultiLabelClassificationTask', + 'ObjectDetectionTask', + 'PixelwiseRegressionTask', + 'RegressionTask', + 'SemanticSegmentationTask', # Self-supervised - "BYOLTask", - "MoCoTask", - "SimCLRTask", + 'BYOLTask', + 'MoCoTask', + 'SimCLRTask', # Base classes - "BaseTask", + 'BaseTask', ) diff --git a/torchgeo/trainers/base.py 
b/torchgeo/trainers/base.py index 338961e3cfd..1f50ad0ab58 100644 --- a/torchgeo/trainers/base.py +++ b/torchgeo/trainers/base.py @@ -23,10 +23,10 @@ class BaseTask(LightningModule, ABC): model: Any #: Performance metric to monitor in learning rate scheduler and callbacks. - monitor = "val_loss" + monitor = 'val_loss' #: Whether the goal is to minimize or maximize the performance metric to monitor. - mode = "min" + mode = 'min' def __init__(self, ignore: Sequence[str] | str | None = None) -> None: """Initialize a new BaseTask instance. @@ -52,17 +52,17 @@ def configure_metrics(self) -> None: def configure_optimizers( self, - ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig": + ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig': """Initialize the optimizer and learning rate scheduler. Returns: Optimizer and learning rate scheduler. """ - optimizer = AdamW(self.parameters(), lr=self.hparams["lr"]) - scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams["patience"]) + optimizer = AdamW(self.parameters(), lr=self.hparams['lr']) + scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams['patience']) return { - "optimizer": optimizer, - "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor}, + 'optimizer': optimizer, + 'lr_scheduler': {'scheduler': scheduler, 'monitor': self.monitor}, } def forward(self, *args: Any, **kwargs: Any) -> Any: diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 2588ee3d1c5..35243eaa545 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -293,11 +293,11 @@ class BYOLTask(BaseTask): * https://arxiv.org/abs/2006.07733 """ - monitor = "train_loss" + monitor = 'train_loss' def __init__( self, - model: str = "resnet50", + model: str = 'resnet50', weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, lr: float = 1e-3, @@ -324,16 +324,16 @@ def __init__( renamed to *model*, *lr*, and *patience*. """ self.weights = weights - super().__init__(ignore="weights") + super().__init__(ignore='weights') def configure_models(self) -> None: """Initialize the model.""" weights = self.weights - in_channels: int = self.hparams["in_channels"] + in_channels: int = self.hparams['in_channels'] # Create backbone backbone = timm.create_model( - self.hparams["model"], in_chans=in_channels, pretrained=weights is True + self.hparams['model'], in_chans=in_channels, pretrained=weights is True ) # Load weights @@ -364,10 +364,10 @@ def training_step( Raises: AssertionError: If channel dimensions are incorrect. 
""" - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] - in_channels = self.hparams["in_channels"] + in_channels = self.hparams['in_channels'] assert x.size(1) == in_channels or x.size(1) == 2 * in_channels if x.size(1) == in_channels: @@ -389,7 +389,7 @@ def training_step( loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) - self.log("train_loss", loss, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size) self.model.update_target() return loss diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 9f6c828b665..cc293099519 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -34,11 +34,11 @@ class ClassificationTask(BaseTask): def __init__( self, - model: str = "resnet50", + model: str = 'resnet50', weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_classes: int = 1000, - loss: str = "ce", + loss: str = 'ce', class_weights: Tensor | None = None, lr: float = 1e-3, patience: int = 10, @@ -73,7 +73,7 @@ class and used with 'ce' loss. *lr* and *patience*. """ self.weights = weights - super().__init__(ignore="weights") + super().__init__(ignore='weights') def configure_models(self) -> None: """Initialize the model.""" @@ -81,9 +81,9 @@ def configure_models(self) -> None: # Create model self.model = timm.create_model( - self.hparams["model"], - num_classes=self.hparams["num_classes"], - in_chans=self.hparams["in_channels"], + self.hparams['model'], + num_classes=self.hparams['num_classes'], + in_chans=self.hparams['in_channels'], pretrained=weights is True, ) @@ -98,7 +98,7 @@ def configure_models(self) -> None: utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head - if self.hparams["freeze_backbone"]: + if self.hparams['freeze_backbone']: for param in self.model.parameters(): param.requires_grad = False for param in self.model.get_classifier().parameters(): @@ -110,17 +110,17 @@ def configure_losses(self) -> None: Raises: ValueError: If *loss* is invalid. 
""" - loss: str = self.hparams["loss"] - if loss == "ce": + loss: str = self.hparams['loss'] + if loss == 'ce': self.criterion: nn.Module = nn.CrossEntropyLoss( - weight=self.hparams["class_weights"] + weight=self.hparams['class_weights'] ) - elif loss == "bce": + elif loss == 'bce': self.criterion = nn.BCEWithLogitsLoss() - elif loss == "jaccard": - self.criterion = JaccardLoss(mode="multiclass") - elif loss == "focal": - self.criterion = FocalLoss(mode="multiclass", normalized=True) + elif loss == 'jaccard': + self.criterion = JaccardLoss(mode='multiclass') + elif loss == 'focal': + self.criterion = FocalLoss(mode='multiclass', normalized=True) else: raise ValueError(f"Loss type '{loss}' is not valid.") @@ -145,23 +145,23 @@ def configure_metrics(self) -> None: """ metrics = MetricCollection( { - "OverallAccuracy": MulticlassAccuracy( - num_classes=self.hparams["num_classes"], average="micro" + 'OverallAccuracy': MulticlassAccuracy( + num_classes=self.hparams['num_classes'], average='micro' ), - "AverageAccuracy": MulticlassAccuracy( - num_classes=self.hparams["num_classes"], average="macro" + 'AverageAccuracy': MulticlassAccuracy( + num_classes=self.hparams['num_classes'], average='macro' ), - "JaccardIndex": MulticlassJaccardIndex( - num_classes=self.hparams["num_classes"] + 'JaccardIndex': MulticlassJaccardIndex( + num_classes=self.hparams['num_classes'] ), - "F1Score": MulticlassFBetaScore( - num_classes=self.hparams["num_classes"], beta=1.0, average="micro" + 'F1Score': MulticlassFBetaScore( + num_classes=self.hparams['num_classes'], beta=1.0, average='micro' ), } ) - self.train_metrics = metrics.clone(prefix="train_") - self.val_metrics = metrics.clone(prefix="val_") - self.test_metrics = metrics.clone(prefix="test_") + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -176,12 +176,12 @@ def training_step( Returns: The loss tensor. """ - x = batch["image"] - y = batch["label"] + x = batch['image'] + y = batch['label'] batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log("train_loss", loss, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat, y) self.log_dict(self.train_metrics, batch_size=batch_size) @@ -197,26 +197,26 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. 
""" - x = batch["image"] - y = batch["label"] + x = batch['image'] + y = batch['label'] batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log("val_loss", loss, batch_size=batch_size) + self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat, y) self.log_dict(self.val_metrics, batch_size=batch_size) if ( batch_idx < 10 - and hasattr(self.trainer, "datamodule") - and hasattr(self.trainer.datamodule, "plot") + and hasattr(self.trainer, 'datamodule') + and hasattr(self.trainer.datamodule, 'plot') and self.logger - and hasattr(self.logger, "experiment") - and hasattr(self.logger.experiment, "add_figure") + and hasattr(self.logger, 'experiment') + and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule - batch["prediction"] = y_hat.argmax(dim=-1) - for key in ["image", "label", "prediction"]: + batch['prediction'] = y_hat.argmax(dim=-1) + for key in ['image', 'label', 'prediction']: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] @@ -229,7 +229,7 @@ def validation_step( if fig: summary_writer = self.logger.experiment summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step + f'image/{batch_idx}', fig, global_step=self.global_step ) plt.close() @@ -241,12 +241,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] - y = batch["label"] + x = batch['image'] + y = batch['label'] batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log("test_loss", loss, batch_size=batch_size) + self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat, y) self.log_dict(self.test_metrics, batch_size=batch_size) @@ -263,7 +263,7 @@ def predict_step( Returns: Output predicted probabilities. """ - x = batch["image"] + x = batch['image'] y_hat: Tensor = self(x).softmax(dim=-1) return y_hat @@ -290,20 +290,20 @@ def configure_metrics(self) -> None: """ metrics = MetricCollection( { - "OverallAccuracy": MultilabelAccuracy( - num_labels=self.hparams["num_classes"], average="micro" + 'OverallAccuracy': MultilabelAccuracy( + num_labels=self.hparams['num_classes'], average='micro' ), - "AverageAccuracy": MultilabelAccuracy( - num_labels=self.hparams["num_classes"], average="macro" + 'AverageAccuracy': MultilabelAccuracy( + num_labels=self.hparams['num_classes'], average='macro' ), - "F1Score": MultilabelFBetaScore( - num_labels=self.hparams["num_classes"], beta=1.0, average="micro" + 'F1Score': MultilabelFBetaScore( + num_labels=self.hparams['num_classes'], beta=1.0, average='micro' ), } ) - self.train_metrics = metrics.clone(prefix="train_") - self.val_metrics = metrics.clone(prefix="val_") - self.test_metrics = metrics.clone(prefix="test_") + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -318,13 +318,13 @@ def training_step( Returns: The loss tensor. 
""" - x = batch["image"] - y = batch["label"] + x = batch['image'] + y = batch['label'] batch_size = x.shape[0] y_hat = self(x) y_hat_hard = torch.sigmoid(y_hat) loss: Tensor = self.criterion(y_hat, y.to(torch.float)) - self.log("train_loss", loss, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat_hard, y) self.log_dict(self.train_metrics) @@ -340,27 +340,27 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] - y = batch["label"] + x = batch['image'] + y = batch['label'] batch_size = x.shape[0] y_hat = self(x) y_hat_hard = torch.sigmoid(y_hat) loss = self.criterion(y_hat, y.to(torch.float)) - self.log("val_loss", loss, batch_size=batch_size) + self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat_hard, y) self.log_dict(self.val_metrics, batch_size=batch_size) if ( batch_idx < 10 - and hasattr(self.trainer, "datamodule") - and hasattr(self.trainer.datamodule, "plot") + and hasattr(self.trainer, 'datamodule') + and hasattr(self.trainer.datamodule, 'plot') and self.logger - and hasattr(self.logger, "experiment") - and hasattr(self.logger.experiment, "add_figure") + and hasattr(self.logger, 'experiment') + and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule - batch["prediction"] = y_hat_hard - for key in ["image", "label", "prediction"]: + batch['prediction'] = y_hat_hard + for key in ['image', 'label', 'prediction']: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] @@ -373,7 +373,7 @@ def validation_step( if fig: summary_writer = self.logger.experiment summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step + f'image/{batch_idx}', fig, global_step=self.global_step ) def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: @@ -384,13 +384,13 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] - y = batch["label"] + x = batch['image'] + y = batch['label'] batch_size = x.shape[0] y_hat = self(x) y_hat_hard = torch.sigmoid(y_hat) loss = self.criterion(y_hat, y.to(torch.float)) - self.log("test_loss", loss, batch_size=batch_size) + self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat_hard, y) self.log_dict(self.test_metrics, batch_size=batch_size) @@ -407,6 +407,6 @@ def predict_step( Returns: Output predicted probabilities. 
""" - x = batch["image"] + x = batch['image'] y_hat = torch.sigmoid(self(x)) return y_hat diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index c74fd7156ef..30916b5df98 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -23,27 +23,27 @@ from .base import BaseTask BACKBONE_LAT_DIM_MAP = { - "resnet18": 512, - "resnet34": 512, - "resnet50": 2048, - "resnet101": 2048, - "resnet152": 2048, - "resnext50_32x4d": 2048, - "resnext101_32x8d": 2048, - "wide_resnet50_2": 2048, - "wide_resnet101_2": 2048, + 'resnet18': 512, + 'resnet34': 512, + 'resnet50': 2048, + 'resnet101': 2048, + 'resnet152': 2048, + 'resnext50_32x4d': 2048, + 'resnext101_32x8d': 2048, + 'wide_resnet50_2': 2048, + 'wide_resnet101_2': 2048, } BACKBONE_WEIGHT_MAP = { - "resnet18": R.ResNet18_Weights.DEFAULT, - "resnet34": R.ResNet34_Weights.DEFAULT, - "resnet50": R.ResNet50_Weights.DEFAULT, - "resnet101": R.ResNet101_Weights.DEFAULT, - "resnet152": R.ResNet152_Weights.DEFAULT, - "resnext50_32x4d": R.ResNeXt50_32X4D_Weights.DEFAULT, - "resnext101_32x8d": R.ResNeXt101_32X8D_Weights.DEFAULT, - "wide_resnet50_2": R.Wide_ResNet50_2_Weights.DEFAULT, - "wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT, + 'resnet18': R.ResNet18_Weights.DEFAULT, + 'resnet34': R.ResNet34_Weights.DEFAULT, + 'resnet50': R.ResNet50_Weights.DEFAULT, + 'resnet101': R.ResNet101_Weights.DEFAULT, + 'resnet152': R.ResNet152_Weights.DEFAULT, + 'resnext50_32x4d': R.ResNeXt50_32X4D_Weights.DEFAULT, + 'resnext101_32x8d': R.ResNeXt101_32X8D_Weights.DEFAULT, + 'wide_resnet50_2': R.Wide_ResNet50_2_Weights.DEFAULT, + 'wide_resnet101_2': R.Wide_ResNet101_2_Weights.DEFAULT, } @@ -53,13 +53,13 @@ class ObjectDetectionTask(BaseTask): .. versionadded:: 0.4 """ - monitor = "val_map" - mode = "max" + monitor = 'val_map' + mode = 'max' def __init__( self, - model: str = "faster-rcnn", - backbone: str = "resnet50", + model: str = 'faster-rcnn', + backbone: str = 'resnet50', weights: bool | None = None, in_channels: int = 3, num_classes: int = 1000, @@ -107,34 +107,34 @@ def configure_models(self) -> None: Raises: ValueError: If *model* or *backbone* are invalid. 
""" - backbone: str = self.hparams["backbone"] - model: str = self.hparams["model"] - weights: bool | None = self.hparams["weights"] - num_classes: int = self.hparams["num_classes"] - freeze_backbone: bool = self.hparams["freeze_backbone"] + backbone: str = self.hparams['backbone'] + model: str = self.hparams['model'] + weights: bool | None = self.hparams['weights'] + num_classes: int = self.hparams['num_classes'] + freeze_backbone: bool = self.hparams['freeze_backbone'] if backbone in BACKBONE_LAT_DIM_MAP: kwargs = { - "backbone_name": backbone, - "trainable_layers": self.hparams["trainable_layers"], + 'backbone_name': backbone, + 'trainable_layers': self.hparams['trainable_layers'], } if weights: - kwargs["weights"] = BACKBONE_WEIGHT_MAP[backbone] + kwargs['weights'] = BACKBONE_WEIGHT_MAP[backbone] else: - kwargs["weights"] = None + kwargs['weights'] = None latent_dim = BACKBONE_LAT_DIM_MAP[backbone] else: raise ValueError(f"Backbone type '{backbone}' is not valid.") - if model == "faster-rcnn": + if model == 'faster-rcnn': model_backbone = resnet_fpn_backbone(**kwargs) anchor_generator = AnchorGenerator( sizes=((32), (64), (128), (256), (512)), aspect_ratios=((0.5, 1.0, 2.0)) ) roi_pooler = MultiScaleRoIAlign( - featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2 + featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2 ) if freeze_backbone: @@ -147,9 +147,9 @@ def configure_models(self) -> None: rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, ) - elif model == "fcos": - kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(256, 256) - kwargs["norm_layer"] = ( + elif model == 'fcos': + kwargs['extra_blocks'] = feature_pyramid_network.LastLevelP6P7(256, 256) + kwargs['norm_layer'] = ( misc.FrozenBatchNorm2d if weights else torch.nn.BatchNorm2d ) @@ -166,8 +166,8 @@ def configure_models(self) -> None: self.model = torchvision.models.detection.FCOS( model_backbone, num_classes, anchor_generator=anchor_generator ) - elif model == "retinanet": - kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7( + elif model == 'retinanet': + kwargs['extra_blocks'] = feature_pyramid_network.LastLevelP6P7( latent_dim, 256 ) model_backbone = resnet_fpn_backbone(**kwargs) @@ -218,9 +218,9 @@ def configure_metrics(self) -> None: * 'Macro' averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes. """ - metrics = MetricCollection([MeanAveragePrecision(average="macro")]) - self.val_metrics = metrics.clone(prefix="val_") - self.test_metrics = metrics.clone(prefix="test_") + metrics = MetricCollection([MeanAveragePrecision(average='macro')]) + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -235,10 +235,10 @@ def training_step( Returns: The loss tensor. """ - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] y = [ - {"boxes": batch["boxes"][i], "labels": batch["labels"][i]} + {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]} for i in range(batch_size) ] loss_dict = self(x, y) @@ -256,38 +256,38 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. 
""" - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] y = [ - {"boxes": batch["boxes"][i], "labels": batch["labels"][i]} + {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]} for i in range(batch_size) ] y_hat = self(x) metrics = self.val_metrics(y_hat, y) # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 - metrics.pop("val_classes", None) + metrics.pop('val_classes', None) self.log_dict(metrics, batch_size=batch_size) if ( batch_idx < 10 - and hasattr(self.trainer, "datamodule") - and hasattr(self.trainer.datamodule, "plot") + and hasattr(self.trainer, 'datamodule') + and hasattr(self.trainer.datamodule, 'plot') and self.logger - and hasattr(self.logger, "experiment") - and hasattr(self.logger.experiment, "add_figure") + and hasattr(self.logger, 'experiment') + and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule - batch["prediction_boxes"] = [b["boxes"].cpu() for b in y_hat] - batch["prediction_labels"] = [b["labels"].cpu() for b in y_hat] - batch["prediction_scores"] = [b["scores"].cpu() for b in y_hat] - batch["image"] = batch["image"].cpu() + batch['prediction_boxes'] = [b['boxes'].cpu() for b in y_hat] + batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat] + batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat] + batch['image'] = batch['image'].cpu() sample = unbind_samples(batch)[0] # Convert image to uint8 for plotting - if torch.is_floating_point(sample["image"]): - sample["image"] *= 255 - sample["image"] = sample["image"].to(torch.uint8) + if torch.is_floating_point(sample['image']): + sample['image'] *= 255 + sample['image'] = sample['image'].to(torch.uint8) fig: Figure | None = None try: @@ -298,7 +298,7 @@ def validation_step( if fig: summary_writer = self.logger.experiment summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step + f'image/{batch_idx}', fig, global_step=self.global_step ) plt.close() @@ -310,17 +310,17 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] y = [ - {"boxes": batch["boxes"][i], "labels": batch["labels"][i]} + {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]} for i in range(batch_size) ] y_hat = self(x) metrics = self.test_metrics(y_hat, y) # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 - metrics.pop("test_classes", None) + metrics.pop('test_classes', None) self.log_dict(metrics, batch_size=batch_size) @@ -337,6 +337,6 @@ def predict_step( Returns: Output predicted probabilities. 
""" - x = batch["image"] + x = batch['image'] y_hat: list[dict[str, Tensor]] = self(x) return y_hat diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index e06aecdf7b4..73646c3868a 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -67,7 +67,7 @@ def moco_augmentations( K.RandomContrast(contrast=(0.6, 1.4), p=1.0), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added - data_keys=["input"], + data_keys=['input'], ) elif version == 2: # Similar to SimCLR: https://arxiv.org/abs/2002.05709 @@ -83,7 +83,7 @@ def moco_augmentations( K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.5), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added - data_keys=["input"], + data_keys=['input'], ) else: # Same as BYOL: https://arxiv.org/abs/2006.07733 @@ -99,7 +99,7 @@ def moco_augmentations( K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=1), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added - data_keys=["input"], + data_keys=['input'], ) aug2 = K.AugmentationSequential( K.RandomResizedCrop(size=(size, size), scale=(0.08, 1)), @@ -114,7 +114,7 @@ def moco_augmentations( K.RandomSolarize(p=0.2), K.RandomHorizontalFlip(), K.RandomVerticalFlip(), # added - data_keys=["input"], + data_keys=['input'], ) return aug1, aug2 @@ -136,11 +136,11 @@ class MoCoTask(BaseTask): .. versionadded:: 0.5 """ - monitor = "train_loss" + monitor = 'train_loss' def __init__( self, - model: str = "resnet50", + model: str = 'resnet50', weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, version: int = 3, @@ -206,20 +206,20 @@ def __init__( assert version in range(1, 4) if version == 1: if memory_bank_size == 0: - warnings.warn("MoCo v1 uses a memory bank") + warnings.warn('MoCo v1 uses a memory bank') elif version == 2: if layers > 2: - warnings.warn("MoCo v2 only uses 2 layers in its projection head") + warnings.warn('MoCo v2 only uses 2 layers in its projection head') if memory_bank_size == 0: - warnings.warn("MoCo v2 uses a memory bank") + warnings.warn('MoCo v2 uses a memory bank') elif version == 3: if layers == 2: - warnings.warn("MoCo v3 uses 3 layers in its projection head") + warnings.warn('MoCo v3 uses 3 layers in its projection head') if memory_bank_size > 0: - warnings.warn("MoCo v3 does not use a memory bank") + warnings.warn('MoCo v3 does not use a memory bank') self.weights = weights - super().__init__(ignore=["weights", "augmentation1", "augmentation2"]) + super().__init__(ignore=['weights', 'augmentation1', 'augmentation2']) grayscale_weights = grayscale_weights or torch.ones(in_channels) aug1, aug2 = moco_augmentations(version, size, grayscale_weights) @@ -228,13 +228,13 @@ def __init__( def configure_models(self) -> None: """Initialize the model.""" - model: str = self.hparams["model"] + model: str = self.hparams['model'] weights = self.weights - in_channels: int = self.hparams["in_channels"] - version: int = self.hparams["version"] - layers: int = self.hparams["layers"] - hidden_dim: int = self.hparams["hidden_dim"] - output_dim: int = self.hparams["output_dim"] + in_channels: int = self.hparams['in_channels'] + version: int = self.hparams['version'] + layers: int = self.hparams['layers'] + hidden_dim: int = self.hparams['hidden_dim'] + output_dim: int = self.hparams['output_dim'] # Create backbone self.backbone = timm.create_model( @@ -278,31 +278,31 @@ def configure_losses(self) -> None: """Initialize the loss criterion.""" try: self.criterion = NTXentLoss( - self.hparams["temperature"], - 
(self.hparams["memory_bank_size"], self.hparams["output_dim"]), - self.hparams["gather_distributed"], + self.hparams['temperature'], + (self.hparams['memory_bank_size'], self.hparams['output_dim']), + self.hparams['gather_distributed'], ) except TypeError: # lightly 1.4.24 and older self.criterion = NTXentLoss( - self.hparams["temperature"], - self.hparams["memory_bank_size"], - self.hparams["gather_distributed"], + self.hparams['temperature'], + self.hparams['memory_bank_size'], + self.hparams['gather_distributed'], ) def configure_optimizers( self, - ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig": + ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig': """Initialize the optimizer and learning rate scheduler. Returns: Optimizer and learning rate scheduler. """ - if self.hparams["version"] == 3: + if self.hparams['version'] == 3: optimizer: Optimizer = AdamW( params=self.parameters(), - lr=self.hparams["lr"], - weight_decay=self.hparams["weight_decay"], + lr=self.hparams['lr'], + weight_decay=self.hparams['weight_decay'], ) warmup_epochs = 40 max_epochs = 200 @@ -323,16 +323,16 @@ def configure_optimizers( else: optimizer = SGD( params=self.parameters(), - lr=self.hparams["lr"], - momentum=self.hparams["momentum"], - weight_decay=self.hparams["weight_decay"], + lr=self.hparams['lr'], + momentum=self.hparams['momentum'], + weight_decay=self.hparams['weight_decay'], ) scheduler = MultiStepLR( - optimizer=optimizer, milestones=self.hparams["schedule"] + optimizer=optimizer, milestones=self.hparams['schedule'] ) return { - "optimizer": optimizer, - "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor}, + 'optimizer': optimizer, + 'lr_scheduler': {'scheduler': scheduler, 'monitor': self.monitor}, } def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: @@ -346,9 +346,9 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: """ h: Tensor = self.backbone(x) q = h - if self.hparams["version"] > 1: + if self.hparams['version'] > 1: q = self.projection_head(q) - if self.hparams["version"] == 3: + if self.hparams['version'] == 3: q = self.prediction_head(q) return q, h @@ -362,7 +362,7 @@ def forward_momentum(self, x: Tensor) -> Tensor: Output from the momentum model. """ k: Tensor = self.backbone_momentum(x) - if self.hparams["version"] > 1: + if self.hparams['version'] > 1: k = self.projection_head_momentum(k) return k @@ -379,10 +379,10 @@ def training_step( Returns: The loss tensor. 
""" - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] - in_channels = self.hparams["in_channels"] + in_channels = self.hparams['in_channels'] assert x.size(1) == in_channels or x.size(1) == 2 * in_channels if x.size(1) == in_channels: @@ -396,21 +396,21 @@ def training_step( x1 = self.augmentation1(x1) x2 = self.augmentation2(x2) - m = self.hparams["moco_momentum"] - if self.hparams["version"] == 1: + m = self.hparams['moco_momentum'] + if self.hparams['version'] == 1: q, h1 = self.forward(x1) with torch.no_grad(): update_momentum(self.backbone, self.backbone_momentum, m) k = self.forward_momentum(x2) loss: Tensor = self.criterion(q, k) - elif self.hparams["version"] == 2: + elif self.hparams['version'] == 2: q, h1 = self.forward(x1) with torch.no_grad(): update_momentum(self.backbone, self.backbone_momentum, m) update_momentum(self.projection_head, self.projection_head_momentum, m) k = self.forward_momentum(x2) loss = self.criterion(q, k) - if self.hparams["version"] == 3: + if self.hparams['version'] == 3: m = cosine_schedule(self.current_epoch, self.trainer.max_epochs, m, 1) q1, h1 = self.forward(x1) q2, h2 = self.forward(x2) @@ -429,8 +429,8 @@ def training_step( output_std = torch.mean(output_std, dim=0) self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item() - self.log("train_ssl_std", self.avg_output_std, batch_size=batch_size) - self.log("train_loss", loss, batch_size=batch_size) + self.log('train_ssl_std', self.avg_output_std, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size) return loss diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 4fb3a9209fd..86c3423c656 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -25,17 +25,17 @@ class RegressionTask(BaseTask): """Regression.""" - target_key = "label" + target_key = 'label' def __init__( self, - model: str = "resnet50", - backbone: str = "resnet50", + model: str = 'resnet50', + backbone: str = 'resnet50', weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_outputs: int = 1, num_filters: int = 3, - loss: str = "mse", + loss: str = 'mse', lr: float = 1e-3, patience: int = 10, freeze_backbone: bool = False, @@ -77,16 +77,16 @@ def __init__( *lr* and *patience*. """ self.weights = weights - super().__init__(ignore="weights") + super().__init__(ignore='weights') def configure_models(self) -> None: """Initialize the model.""" # Create model weights = self.weights self.model = timm.create_model( - self.hparams["model"], - num_classes=self.hparams["num_outputs"], - in_chans=self.hparams["in_channels"], + self.hparams['model'], + num_classes=self.hparams['num_outputs'], + in_chans=self.hparams['in_channels'], pretrained=weights is True, ) @@ -101,7 +101,7 @@ def configure_models(self) -> None: utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head - if self.hparams["freeze_backbone"]: + if self.hparams['freeze_backbone']: for param in self.model.parameters(): param.requires_grad = False for param in self.model.get_classifier().parameters(): @@ -113,10 +113,10 @@ def configure_losses(self) -> None: Raises: ValueError: If *loss* is invalid. 
""" - loss: str = self.hparams["loss"] - if loss == "mse": + loss: str = self.hparams['loss'] + if loss == 'mse': self.criterion: nn.Module = nn.MSELoss() - elif loss == "mae": + elif loss == 'mae': self.criterion = nn.L1Loss() else: raise ValueError( @@ -136,14 +136,14 @@ def configure_metrics(self) -> None: """ metrics = MetricCollection( { - "RMSE": MeanSquaredError(squared=False), - "MSE": MeanSquaredError(squared=True), - "MAE": MeanAbsoluteError(), + 'RMSE': MeanSquaredError(squared=False), + 'MSE': MeanSquaredError(squared=True), + 'MAE': MeanAbsoluteError(), } ) - self.train_metrics = metrics.clone(prefix="train_") - self.val_metrics = metrics.clone(prefix="val_") - self.test_metrics = metrics.clone(prefix="test_") + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -158,7 +158,7 @@ def training_step( Returns: The loss tensor. """ - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] # TODO: remove .to(...) once we have a real pixelwise regression dataset y = batch[self.target_key].to(torch.float) @@ -166,7 +166,7 @@ def training_step( if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) loss: Tensor = self.criterion(y_hat, y) - self.log("train_loss", loss, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat, y) self.log_dict(self.train_metrics, batch_size=batch_size) @@ -182,7 +182,7 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] # TODO: remove .to(...) once we have a real pixelwise regression dataset y = batch[self.target_key].to(torch.float) @@ -190,24 +190,24 @@ def validation_step( if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) loss = self.criterion(y_hat, y) - self.log("val_loss", loss, batch_size=batch_size) + self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat, y) self.log_dict(self.val_metrics, batch_size=batch_size) if ( batch_idx < 10 - and hasattr(self.trainer, "datamodule") - and hasattr(self.trainer.datamodule, "plot") + and hasattr(self.trainer, 'datamodule') + and hasattr(self.trainer.datamodule, 'plot') and self.logger - and hasattr(self.logger, "experiment") - and hasattr(self.logger.experiment, "add_figure") + and hasattr(self.logger, 'experiment') + and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule - if self.target_key == "mask": + if self.target_key == 'mask': y = y.squeeze(dim=1) y_hat = y_hat.squeeze(dim=1) - batch["prediction"] = y_hat - for key in ["image", self.target_key, "prediction"]: + batch['prediction'] = y_hat + for key in ['image', self.target_key, 'prediction']: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] @@ -220,7 +220,7 @@ def validation_step( if fig: summary_writer = self.logger.experiment summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step + f'image/{batch_idx}', fig, global_step=self.global_step ) plt.close() @@ -232,7 +232,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] # TODO: remove .to(...) 
once we have a real pixelwise regression dataset y = batch[self.target_key].to(torch.float) @@ -240,7 +240,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) loss = self.criterion(y_hat, y) - self.log("test_loss", loss, batch_size=batch_size) + self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat, y) self.log_dict(self.test_metrics, batch_size=batch_size) @@ -257,7 +257,7 @@ def predict_step( Returns: Output predicted probabilities. """ - x = batch["image"] + x = batch['image'] y_hat: Tensor = self(x) return y_hat @@ -268,31 +268,31 @@ class PixelwiseRegressionTask(RegressionTask): .. versionadded:: 0.5 """ - target_key = "mask" + target_key = 'mask' def configure_models(self) -> None: """Initialize the model.""" weights = self.weights - if self.hparams["model"] == "unet": + if self.hparams['model'] == 'unet': self.model = smp.Unet( - encoder_name=self.hparams["backbone"], - encoder_weights="imagenet" if weights is True else None, - in_channels=self.hparams["in_channels"], + encoder_name=self.hparams['backbone'], + encoder_weights='imagenet' if weights is True else None, + in_channels=self.hparams['in_channels'], classes=1, ) - elif self.hparams["model"] == "deeplabv3+": + elif self.hparams['model'] == 'deeplabv3+': self.model = smp.DeepLabV3Plus( - encoder_name=self.hparams["backbone"], - encoder_weights="imagenet" if weights is True else None, - in_channels=self.hparams["in_channels"], + encoder_name=self.hparams['backbone'], + encoder_weights='imagenet' if weights is True else None, + in_channels=self.hparams['in_channels'], classes=1, ) - elif self.hparams["model"] == "fcn": + elif self.hparams['model'] == 'fcn': self.model = FCN( - in_channels=self.hparams["in_channels"], + in_channels=self.hparams['in_channels'], classes=1, - num_filters=self.hparams["num_filters"], + num_filters=self.hparams['num_filters'], ) else: raise ValueError( @@ -300,7 +300,7 @@ def configure_models(self) -> None: "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." 
) - if self.hparams["model"] != "fcn": + if self.hparams['model'] != 'fcn': if weights and weights is not True: if isinstance(weights, WeightsEnum): state_dict = weights.get_state_dict(progress=True) @@ -311,17 +311,17 @@ def configure_models(self) -> None: self.model.encoder.load_state_dict(state_dict) # Freeze backbone - if self.hparams.get("freeze_backbone", False) and self.hparams["model"] in [ - "unet", - "deeplabv3+", + if self.hparams.get('freeze_backbone', False) and self.hparams['model'] in [ + 'unet', + 'deeplabv3+', ]: for param in self.model.encoder.parameters(): param.requires_grad = False # Freeze decoder - if self.hparams.get("freeze_decoder", False) and self.hparams["model"] in [ - "unet", - "deeplabv3+", + if self.hparams.get('freeze_decoder', False) and self.hparams['model'] in [ + 'unet', + 'deeplabv3+', ]: for param in self.model.decoder.parameters(): param.requires_grad = False diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 502826db31f..103a5d0e860 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -26,13 +26,13 @@ class SemanticSegmentationTask(BaseTask): def __init__( self, - model: str = "unet", - backbone: str = "resnet50", + model: str = 'unet', + backbone: str = 'resnet50', weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_classes: int = 1000, num_filters: int = 3, - loss: str = "ce", + loss: str = 'ce', class_weights: Tensor | None = None, ignore_index: int | None = None, lr: float = 1e-3, @@ -88,7 +88,7 @@ class and used with 'ce' loss. The *ignore_index* parameter now works for jaccard loss. """ self.weights = weights - super().__init__(ignore="weights") + super().__init__(ignore='weights') def configure_models(self) -> None: """Initialize the model. @@ -96,28 +96,28 @@ def configure_models(self) -> None: Raises: ValueError: If *model* is invalid. """ - model: str = self.hparams["model"] - backbone: str = self.hparams["backbone"] + model: str = self.hparams['model'] + backbone: str = self.hparams['backbone'] weights = self.weights - in_channels: int = self.hparams["in_channels"] - num_classes: int = self.hparams["num_classes"] - num_filters: int = self.hparams["num_filters"] + in_channels: int = self.hparams['in_channels'] + num_classes: int = self.hparams['num_classes'] + num_filters: int = self.hparams['num_filters'] - if model == "unet": + if model == 'unet': self.model = smp.Unet( encoder_name=backbone, - encoder_weights="imagenet" if weights is True else None, + encoder_weights='imagenet' if weights is True else None, in_channels=in_channels, classes=num_classes, ) - elif model == "deeplabv3+": + elif model == 'deeplabv3+': self.model = smp.DeepLabV3Plus( encoder_name=backbone, - encoder_weights="imagenet" if weights is True else None, + encoder_weights='imagenet' if weights is True else None, in_channels=in_channels, classes=num_classes, ) - elif model == "fcn": + elif model == 'fcn': self.model = FCN( in_channels=in_channels, classes=num_classes, num_filters=num_filters ) @@ -127,7 +127,7 @@ def configure_models(self) -> None: "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'." 
) - if model != "fcn": + if model != 'fcn': if weights and weights is not True: if isinstance(weights, WeightsEnum): state_dict = weights.get_state_dict(progress=True) @@ -138,12 +138,12 @@ def configure_models(self) -> None: self.model.encoder.load_state_dict(state_dict) # Freeze backbone - if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]: + if self.hparams['freeze_backbone'] and model in ['unet', 'deeplabv3+']: for param in self.model.encoder.parameters(): param.requires_grad = False # Freeze decoder - if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]: + if self.hparams['freeze_decoder'] and model in ['unet', 'deeplabv3+']: for param in self.model.decoder.parameters(): param.requires_grad = False @@ -153,24 +153,24 @@ def configure_losses(self) -> None: Raises: ValueError: If *loss* is invalid. """ - loss: str = self.hparams["loss"] - ignore_index = self.hparams["ignore_index"] - if loss == "ce": + loss: str = self.hparams['loss'] + ignore_index = self.hparams['ignore_index'] + if loss == 'ce': ignore_value = -1000 if ignore_index is None else ignore_index self.criterion = nn.CrossEntropyLoss( - ignore_index=ignore_value, weight=self.hparams["class_weights"] + ignore_index=ignore_value, weight=self.hparams['class_weights'] ) - elif loss == "jaccard": + elif loss == 'jaccard': # JaccardLoss requires a list of classes to use instead of a class # index to ignore. classes = [ - i for i in range(self.hparams["num_classes"]) if i != ignore_index + i for i in range(self.hparams['num_classes']) if i != ignore_index ] - self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=classes) - elif loss == "focal": + self.criterion = smp.losses.JaccardLoss(mode='multiclass', classes=classes) + elif loss == 'focal': self.criterion = smp.losses.FocalLoss( - "multiclass", ignore_index=ignore_index, normalized=True + 'multiclass', ignore_index=ignore_index, normalized=True ) else: raise ValueError( @@ -193,24 +193,24 @@ def configure_metrics(self) -> None: * 'Macro' averaging, not used here, gives equal weight to each class, useful for balanced performance assessment across imbalanced classes. """ - num_classes: int = self.hparams["num_classes"] - ignore_index: int | None = self.hparams["ignore_index"] + num_classes: int = self.hparams['num_classes'] + ignore_index: int | None = self.hparams['ignore_index'] metrics = MetricCollection( [ MulticlassAccuracy( num_classes=num_classes, ignore_index=ignore_index, - multidim_average="global", - average="micro", + multidim_average='global', + average='micro', ), MulticlassJaccardIndex( - num_classes=num_classes, ignore_index=ignore_index, average="micro" + num_classes=num_classes, ignore_index=ignore_index, average='micro' ), ] ) - self.train_metrics = metrics.clone(prefix="train_") - self.val_metrics = metrics.clone(prefix="val_") - self.test_metrics = metrics.clone(prefix="test_") + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + self.test_metrics = metrics.clone(prefix='test_') def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -225,12 +225,12 @@ def training_step( Returns: The loss tensor. 
""" - x = batch["image"] - y = batch["mask"] + x = batch['image'] + y = batch['mask'] batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log("train_loss", loss, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat, y) self.log_dict(self.train_metrics, batch_size=batch_size) return loss @@ -245,26 +245,26 @@ def validation_step( batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] - y = batch["mask"] + x = batch['image'] + y = batch['mask'] batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log("val_loss", loss, batch_size=batch_size) + self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat, y) self.log_dict(self.val_metrics, batch_size=batch_size) if ( batch_idx < 10 - and hasattr(self.trainer, "datamodule") - and hasattr(self.trainer.datamodule, "plot") + and hasattr(self.trainer, 'datamodule') + and hasattr(self.trainer.datamodule, 'plot') and self.logger - and hasattr(self.logger, "experiment") - and hasattr(self.logger.experiment, "add_figure") + and hasattr(self.logger, 'experiment') + and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule - batch["prediction"] = y_hat.argmax(dim=1) - for key in ["image", "mask", "prediction"]: + batch['prediction'] = y_hat.argmax(dim=1) + for key in ['image', 'mask', 'prediction']: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] @@ -277,7 +277,7 @@ def validation_step( if fig: summary_writer = self.logger.experiment summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step + f'image/{batch_idx}', fig, global_step=self.global_step ) plt.close() @@ -289,12 +289,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ - x = batch["image"] - y = batch["mask"] + x = batch['image'] + y = batch['mask'] batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log("test_loss", loss, batch_size=batch_size) + self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat, y) self.log_dict(self.test_metrics, batch_size=batch_size) @@ -311,6 +311,6 @@ def predict_step( Returns: Output predicted probabilities. """ - x = batch["image"] + x = batch['image'] y_hat: Tensor = self(x).softmax(dim=1) return y_hat diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index 44e517731e2..ba9443e9191 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -49,7 +49,7 @@ def simclr_augmentations(size: int, weights: Tensor) -> nn.Module: K.RandomContrast(contrast=(0.2, 1.8), p=0.8), T.RandomGrayscale(weights=weights, p=0.2), K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2)), - data_keys=["input"], + data_keys=['input'], ) @@ -68,11 +68,11 @@ class SimCLRTask(BaseTask): .. 
diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py
index 44e517731e2..ba9443e9191 100644
--- a/torchgeo/trainers/simclr.py
+++ b/torchgeo/trainers/simclr.py
@@ -49,7 +49,7 @@ def simclr_augmentations(size: int, weights: Tensor) -> nn.Module:
         K.RandomContrast(contrast=(0.2, 1.8), p=0.8),
         T.RandomGrayscale(weights=weights, p=0.2),
         K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2)),
-        data_keys=["input"],
+        data_keys=['input'],
     )


@@ -68,11 +68,11 @@ class SimCLRTask(BaseTask):
     .. versionadded:: 0.5
     """

-    monitor = "train_loss"
+    monitor = 'train_loss'

     def __init__(
         self,
-        model: str = "resnet50",
+        model: str = 'resnet50',
         weights: WeightsEnum | str | bool | None = None,
         in_channels: int = 3,
         version: int = 2,
@@ -125,17 +125,17 @@ def __init__(
         assert version in range(1, 3)
         if version == 1:
             if layers > 2:
-                warnings.warn("SimCLR v1 only uses 2 layers in its projection head")
+                warnings.warn('SimCLR v1 only uses 2 layers in its projection head')
             if memory_bank_size > 0:
-                warnings.warn("SimCLR v1 does not use a memory bank")
+                warnings.warn('SimCLR v1 does not use a memory bank')
         elif version == 2:
             if layers == 2:
-                warnings.warn("SimCLR v2 uses 3+ layers in its projection head")
+                warnings.warn('SimCLR v2 uses 3+ layers in its projection head')
             if memory_bank_size == 0:
-                warnings.warn("SimCLR v2 uses a memory bank")
+                warnings.warn('SimCLR v2 uses a memory bank')

         self.weights = weights
-        super().__init__(ignore=["weights", "augmentations"])
+        super().__init__(ignore=['weights', 'augmentations'])

         grayscale_weights = grayscale_weights or torch.ones(in_channels)
         self.augmentations = augmentations or simclr_augmentations(
@@ -148,8 +148,8 @@ def configure_models(self) -> None:

         # Create backbone
         self.backbone = timm.create_model(
-            self.hparams["model"],
-            in_chans=self.hparams["in_channels"],
+            self.hparams['model'],
+            in_chans=self.hparams['in_channels'],
             num_classes=0,
             pretrained=weights is True,
         )
@@ -166,16 +166,16 @@ def configure_models(self) -> None:

         # Create projection head
         input_dim = self.backbone.num_features
-        if self.hparams["hidden_dim"] is None:
-            self.hparams["hidden_dim"] = input_dim
-        if self.hparams["output_dim"] is None:
-            self.hparams["output_dim"] = input_dim
+        if self.hparams['hidden_dim'] is None:
+            self.hparams['hidden_dim'] = input_dim
+        if self.hparams['output_dim'] is None:
+            self.hparams['output_dim'] = input_dim

         self.projection_head = SimCLRProjectionHead(
             input_dim,
-            self.hparams["hidden_dim"],
-            self.hparams["output_dim"],
-            self.hparams["layers"],
+            self.hparams['hidden_dim'],
+            self.hparams['output_dim'],
+            self.hparams['layers'],
         )

         # Initialize moving average of output
@@ -189,16 +189,16 @@ def configure_losses(self) -> None:
         """Initialize the loss criterion."""
         try:
             self.criterion = NTXentLoss(
-                self.hparams["temperature"],
-                (self.hparams["memory_bank_size"], self.hparams["output_dim"]),
-                self.hparams["gather_distributed"],
+                self.hparams['temperature'],
+                (self.hparams['memory_bank_size'], self.hparams['output_dim']),
+                self.hparams['gather_distributed'],
             )
         except TypeError:
             # lightly 1.4.24 and older
             self.criterion = NTXentLoss(
-                self.hparams["temperature"],
-                self.hparams["memory_bank_size"],
-                self.hparams["gather_distributed"],
+                self.hparams['temperature'],
+                self.hparams['memory_bank_size'],
+                self.hparams['gather_distributed'],
             )

     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
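
When *hidden_dim* and *output_dim* are left as None they fall back to the backbone's feature dimension, as shown above. A rough sketch under that assumption (a resnet50 backbone exposes 2048 features; parameter names follow the hparams accessed above, and the remaining defaults are assumed to be usable as-is):

task = SimCLRTask(model='resnet50', in_channels=3, version=2)
# projection head becomes SimCLRProjectionHead(2048, 2048, 2048, layers)
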
""" - x = batch["image"] + x = batch['image'] batch_size = x.shape[0] - in_channels: int = self.hparams["in_channels"] + in_channels: int = self.hparams['in_channels'] assert x.size(1) == in_channels or x.size(1) == 2 * in_channels if x.size(1) == in_channels: @@ -260,8 +260,8 @@ def training_step( output_std = torch.mean(output_std, dim=0) self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item() - self.log("train_ssl_std", self.avg_output_std, batch_size=batch_size) - self.log("train_loss", loss, batch_size=batch_size) + self.log('train_ssl_std', self.avg_output_std, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size) return loss @@ -280,7 +280,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> N def configure_optimizers( self, - ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig": + ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig': """Initialize the optimizer and learning rate scheduler. Returns: @@ -289,13 +289,13 @@ def configure_optimizers( # Original paper uses LARS optimizer, but this is not defined in PyTorch optimizer = Adam( self.parameters(), - lr=self.hparams["lr"], - weight_decay=self.hparams["weight_decay"], + lr=self.hparams['lr'], + weight_decay=self.hparams['weight_decay'], ) max_epochs = 200 if self.trainer and self.trainer.max_epochs: max_epochs = self.trainer.max_epochs - if self.hparams["version"] == 1: + if self.hparams['version'] == 1: warmup_epochs = 10 else: warmup_epochs = int(max_epochs * 0.05) @@ -308,6 +308,6 @@ def configure_optimizers( milestones=[warmup_epochs], ) return { - "optimizer": optimizer, - "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor}, + 'optimizer': optimizer, + 'lr_scheduler': {'scheduler': scheduler, 'monitor': self.monitor}, } diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 7332cd1d119..10da4ba452f 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -13,7 +13,7 @@ from torch.nn.modules import Conv2d, Module -def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]: +def extract_backbone(path: str) -> tuple[str, 'OrderedDict[str, Tensor]']: """Extracts a backbone from a lightning checkpoint file. Args: @@ -29,26 +29,26 @@ def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]: .. versionchanged:: 0.4 Renamed from *extract_encoder* to *extract_backbone* """ - checkpoint = torch.load(path, map_location=torch.device("cpu")) - if "model" in checkpoint["hyper_parameters"]: - name = checkpoint["hyper_parameters"]["model"] - state_dict = checkpoint["state_dict"] - state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k}) + checkpoint = torch.load(path, map_location=torch.device('cpu')) + if 'model' in checkpoint['hyper_parameters']: + name = checkpoint['hyper_parameters']['model'] + state_dict = checkpoint['state_dict'] + state_dict = OrderedDict({k: v for k, v in state_dict.items() if 'model.' 
diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py
index 7332cd1d119..10da4ba452f 100644
--- a/torchgeo/trainers/utils.py
+++ b/torchgeo/trainers/utils.py
@@ -13,7 +13,7 @@
 from torch.nn.modules import Conv2d, Module


-def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]:
+def extract_backbone(path: str) -> tuple[str, 'OrderedDict[str, Tensor]']:
     """Extracts a backbone from a lightning checkpoint file.

     Args:
@@ -29,26 +29,26 @@ def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]:
     .. versionchanged:: 0.4
        Renamed from *extract_encoder* to *extract_backbone*
     """
-    checkpoint = torch.load(path, map_location=torch.device("cpu"))
-    if "model" in checkpoint["hyper_parameters"]:
-        name = checkpoint["hyper_parameters"]["model"]
-        state_dict = checkpoint["state_dict"]
-        state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
+    checkpoint = torch.load(path, map_location=torch.device('cpu'))
+    if 'model' in checkpoint['hyper_parameters']:
+        name = checkpoint['hyper_parameters']['model']
+        state_dict = checkpoint['state_dict']
+        state_dict = OrderedDict({k: v for k, v in state_dict.items() if 'model.' in k})
         state_dict = OrderedDict(
-            {k.replace("model.", ""): v for k, v in state_dict.items()}
+            {k.replace('model.', ''): v for k, v in state_dict.items()}
         )
-    elif "backbone" in checkpoint["hyper_parameters"]:
-        name = checkpoint["hyper_parameters"]["backbone"]
-        state_dict = checkpoint["state_dict"]
+    elif 'backbone' in checkpoint['hyper_parameters']:
+        name = checkpoint['hyper_parameters']['backbone']
+        state_dict = checkpoint['state_dict']
         state_dict = OrderedDict(
-            {k: v for k, v in state_dict.items() if "model.backbone.model" in k}
+            {k: v for k, v in state_dict.items() if 'model.backbone.model' in k}
         )
         state_dict = OrderedDict(
-            {k.replace("model.backbone.model.", ""): v for k, v in state_dict.items()}
+            {k.replace('model.backbone.model.', ''): v for k, v in state_dict.items()}
         )
     else:
         raise ValueError(
-            "Unknown checkpoint task. Only backbone or model extraction is supported"
+            'Unknown checkpoint task. Only backbone or model extraction is supported'
         )

     return name, state_dict
@@ -67,12 +67,12 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]:
         keys.append(name)
         children = list(module.named_children())

-    key = ".".join(keys)
+    key = '.'.join(keys)
     return key, module


 def load_state_dict(
-    model: Module, state_dict: "OrderedDict[str, Tensor]"
+    model: Module, state_dict: 'OrderedDict[str, Tensor]'
 ) -> tuple[list[str], list[str]]:
     """Load pretrained resnet weights to a model.

@@ -89,7 +89,7 @@ def load_state_dict(
     """
     input_module_key, input_module = _get_input_layer_name_and_module(model)
     in_channels = input_module.in_channels
-    expected_in_channels = state_dict[input_module_key + ".weight"].shape[1]
+    expected_in_channels = state_dict[input_module_key + '.weight'].shape[1]

     output_module_key, output_module = list(model.named_children())[-1]
     if isinstance(output_module, nn.Identity):
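
Together, extract_backbone and load_state_dict let a backbone saved by one of these trainers be reused elsewhere. A rough sketch; the checkpoint path and channel count are illustrative:

import timm

name, state_dict = extract_backbone('path/to/ssl_checkpoint.ckpt')  # illustrative path
model = timm.create_model(name, in_chans=3, num_classes=0)
missing_keys, unexpected_keys = load_state_dict(model, state_dict)
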
@@ -97,24 +97,24 @@ def load_state_dict(
     else:
         num_classes = output_module.out_features
     expected_num_classes = None
-    if output_module_key + ".weight" in state_dict:
-        expected_num_classes = state_dict[output_module_key + ".weight"].shape[0]
+    if output_module_key + '.weight' in state_dict:
+        expected_num_classes = state_dict[output_module_key + '.weight'].shape[0]

     if in_channels != expected_in_channels:
         warnings.warn(
-            f"input channels {in_channels} != input channels in pretrained"
-            f" model {expected_in_channels}. Overriding with new input channels"
+            f'input channels {in_channels} != input channels in pretrained'
+            f' model {expected_in_channels}. Overriding with new input channels'
         )
-        del state_dict[input_module_key + ".weight"]
+        del state_dict[input_module_key + '.weight']

     if expected_num_classes and num_classes != expected_num_classes:
         warnings.warn(
-            f"num classes {num_classes} != num classes in pretrained model"
-            f" {expected_num_classes}. Overriding with new num classes"
+            f'num classes {num_classes} != num classes in pretrained model'
+            f' {expected_num_classes}. Overriding with new num classes'
         )
         del (
-            state_dict[output_module_key + ".weight"],
-            state_dict[output_module_key + ".bias"],
+            state_dict[output_module_key + '.weight'],
+            state_dict[output_module_key + '.bias'],
         )

     missing_keys: list[str]
@@ -167,7 +167,7 @@ def reinit_initial_conv_layer(
         bias=use_bias,
         padding_mode=layer.padding_mode,
     )
-    nn.init.kaiming_normal_(new_layer.weight, mode="fan_out", nonlinearity="relu")
+    nn.init.kaiming_normal_(new_layer.weight, mode='fan_out', nonlinearity='relu')

     if keep_rgb_weights:
         new_layer.weight.data[:, :3, :, :] = w_old
diff --git a/torchgeo/transforms/__init__.py b/torchgeo/transforms/__init__.py
index 92ed1ed075e..5a0f9ee3392 100644
--- a/torchgeo/transforms/__init__.py
+++ b/torchgeo/transforms/__init__.py
@@ -23,20 +23,20 @@
 from .transforms import AugmentationSequential

 __all__ = (
-    "AppendBNDVI",
-    "AppendGBNDVI",
-    "AppendGNDVI",
-    "AppendGRNDVI",
-    "AppendNBR",
-    "AppendNDBI",
-    "AppendNDRE",
-    "AppendNDSI",
-    "AppendNDVI",
-    "AppendNDWI",
-    "AppendNormalizedDifferenceIndex",
-    "AppendRBNDVI",
-    "AppendSWI",
-    "AppendTriBandNormalizedDifferenceIndex",
-    "AugmentationSequential",
-    "RandomGrayscale",
+    'AppendBNDVI',
+    'AppendGBNDVI',
+    'AppendGNDVI',
+    'AppendGRNDVI',
+    'AppendNBR',
+    'AppendNDBI',
+    'AppendNDRE',
+    'AppendNDSI',
+    'AppendNDVI',
+    'AppendNDWI',
+    'AppendNormalizedDifferenceIndex',
+    'AppendRBNDVI',
+    'AppendSWI',
+    'AppendTriBandNormalizedDifferenceIndex',
+    'AugmentationSequential',
+    'RandomGrayscale',
 )
diff --git a/torchgeo/transforms/color.py b/torchgeo/transforms/color.py
index efc7055feba..72d4f4483f4 100644
--- a/torchgeo/transforms/color.py
+++ b/torchgeo/transforms/color.py
@@ -48,7 +48,7 @@ def __init__(
         # Rescale to sum to 1
         weights /= weights.sum()

-        self.flags = {"weights": weights}
+        self.flags = {'weights': weights}

     def apply_transform(
         self,
@@ -68,7 +68,7 @@ def apply_transform(
         Returns:
             The augmented input.
""" - weights = flags["weights"][..., :, None, None].to(input.device) + weights = flags['weights'][..., :, None, None].to(input.device) out = input * weights out = out.sum(dim=-3) out = out.unsqueeze(-3).expand(input.shape) diff --git a/torchgeo/transforms/indices.py b/torchgeo/transforms/indices.py index 3f52fd79630..d04385d6ad1 100644 --- a/torchgeo/transforms/indices.py +++ b/torchgeo/transforms/indices.py @@ -35,7 +35,7 @@ def __init__(self, index_a: int, index_b: int) -> None: index_b: difference band channel index """ super().__init__(p=1) - self.flags = {"index_a": index_a, "index_b": index_b} + self.flags = {'index_a': index_a, 'index_b': index_b} def apply_transform( self, @@ -55,8 +55,8 @@ def apply_transform( Returns: the augmented input """ - band_a = input[..., flags["index_a"], :, :] - band_b = input[..., flags["index_b"], :, :] + band_a = input[..., flags['index_a'], :, :] + band_b = input[..., flags['index_b'], :, :] ndi = (band_a - band_b) / (band_a + band_b + _EPSILON) ndi = torch.unsqueeze(ndi, -3) input = torch.cat((input, ndi), dim=-3) @@ -310,7 +310,7 @@ def __init__(self, index_a: int, index_b: int, index_c: int) -> None: index_c: difference band channel index of component 2 """ super().__init__(p=1) - self.flags = {"index_a": index_a, "index_b": index_b, "index_c": index_c} + self.flags = {'index_a': index_a, 'index_b': index_b, 'index_c': index_c} def apply_transform( self, @@ -330,9 +330,9 @@ def apply_transform( Returns: the augmented input """ - band_a = input[..., flags["index_a"], :, :] - band_b = input[..., flags["index_b"], :, :] - band_c = input[..., flags["index_c"], :, :] + band_a = input[..., flags['index_a'], :, :] + band_b = input[..., flags['index_b'], :, :] + band_c = input[..., flags['index_c'], :, :] band_d = band_b + band_c tbndi = (band_a - band_d) / (band_a + band_d + _EPSILON) tbndi = torch.unsqueeze(tbndi, -3) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 3a4cf44d60b..87484e75730 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -44,12 +44,12 @@ def __init__( keys: list[str] = [] for key in data_keys: - if key.startswith("image"): - keys.append("input") - elif key == "boxes": - keys.append("bbox") - elif key == "masks": - keys.append("mask") + if key.startswith('image'): + keys.append('input') + elif key == 'boxes': + keys.append('bbox') + elif key == 'masks': + keys.append('mask') else: keys.append(key) @@ -71,17 +71,17 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch[key] = batch[key].float() # Convert shape of boxes from [N, 4] to [N, 4, 2] - if "boxes" in batch and ( - isinstance(batch["boxes"], list) or batch["boxes"].ndim == 2 + if 'boxes' in batch and ( + isinstance(batch['boxes'], list) or batch['boxes'].ndim == 2 ): - batch["boxes"] = Boxes.from_tensor(batch["boxes"]).data + batch['boxes'] = Boxes.from_tensor(batch['boxes']).data # Kornia requires masks to have a channel dimension - if "mask" in batch and batch["mask"].ndim == 3: - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") + if 'mask' in batch and batch['mask'].ndim == 3: + batch['mask'] = rearrange(batch['mask'], 'b h w -> b () h w') - if "masks" in batch and batch["masks"].ndim == 3: - batch["masks"] = rearrange(batch["masks"], "c h w -> () c h w") + if 'masks' in batch and batch['masks'].ndim == 3: + batch['masks'] = rearrange(batch['masks'], 'c h w -> () c h w') inputs = [batch[k] for k in self.data_keys] outputs_list: Tensor | list[Tensor] = 
diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py
index 3a4cf44d60b..87484e75730 100644
--- a/torchgeo/transforms/transforms.py
+++ b/torchgeo/transforms/transforms.py
@@ -44,12 +44,12 @@ def __init__(

         keys: list[str] = []
         for key in data_keys:
-            if key.startswith("image"):
-                keys.append("input")
-            elif key == "boxes":
-                keys.append("bbox")
-            elif key == "masks":
-                keys.append("mask")
+            if key.startswith('image'):
+                keys.append('input')
+            elif key == 'boxes':
+                keys.append('bbox')
+            elif key == 'masks':
+                keys.append('mask')
             else:
                 keys.append(key)

@@ -71,17 +71,17 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
                 batch[key] = batch[key].float()

         # Convert shape of boxes from [N, 4] to [N, 4, 2]
-        if "boxes" in batch and (
-            isinstance(batch["boxes"], list) or batch["boxes"].ndim == 2
+        if 'boxes' in batch and (
+            isinstance(batch['boxes'], list) or batch['boxes'].ndim == 2
         ):
-            batch["boxes"] = Boxes.from_tensor(batch["boxes"]).data
+            batch['boxes'] = Boxes.from_tensor(batch['boxes']).data

         # Kornia requires masks to have a channel dimension
-        if "mask" in batch and batch["mask"].ndim == 3:
-            batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")
+        if 'mask' in batch and batch['mask'].ndim == 3:
+            batch['mask'] = rearrange(batch['mask'], 'b h w -> b () h w')

-        if "masks" in batch and batch["masks"].ndim == 3:
-            batch["masks"] = rearrange(batch["masks"], "c h w -> () c h w")
+        if 'masks' in batch and batch['masks'].ndim == 3:
+            batch['masks'] = rearrange(batch['masks'], 'c h w -> () c h w')

         inputs = [batch[k] for k in self.data_keys]
         outputs_list: Tensor | list[Tensor] = self.augs(*inputs)
@@ -98,14 +98,14 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
             batch[key] = batch[key].to(dtype[key])

         # Convert boxes to default [N, 4]
-        if "boxes" in batch:
-            batch["boxes"] = Boxes(batch["boxes"]).to_tensor(mode="xyxy")  # type:ignore[assignment]
+        if 'boxes' in batch:
+            batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy')  # type:ignore[assignment]

         # Torchmetrics does not support masks with a channel dimension
-        if "mask" in batch and batch["mask"].shape[1] == 1:
-            batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
-        if "masks" in batch and batch["masks"].ndim == 4:
-            batch["masks"] = rearrange(batch["masks"], "() c h w -> c h w")
+        if 'mask' in batch and batch['mask'].shape[1] == 1:
+            batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
+        if 'masks' in batch and batch['masks'].ndim == 4:
+            batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w')

         return batch

@@ -122,7 +122,7 @@ def __init__(self, size: tuple[int, int], num: int) -> None:
         """
         super().__init__(p=1)
         self._param_generator: _NCropGenerator = _NCropGenerator(size, num)
-        self.flags = {"size": size, "num": num}
+        self.flags = {'size': size, 'num': num}

     def compute_transformation(
         self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
@@ -159,8 +159,8 @@ def apply_transform(
             the augmented input
         """
         out = []
-        for i in range(flags["num"]):
-            out.append(crop_by_indices(input, params["src"][i], flags["size"]))
+        for i in range(flags['num']):
+            out.append(crop_by_indices(input, params['src'][i], flags['size']))
         return torch.cat(out)

@@ -193,10 +193,10 @@ def forward(
         for _ in range(self.num):
             out.append(super().forward(batch_shape, same_on_batch))
         return {
-            "src": torch.stack([x["src"] for x in out]),
-            "dst": torch.stack([x["dst"] for x in out]),
-            "input_size": out[0]["input_size"],
-            "output_size": out[0]["output_size"],
+            'src': torch.stack([x['src'] for x in out]),
+            'dst': torch.stack([x['dst'] for x in out]),
+            'input_size': out[0]['input_size'],
+            'output_size': out[0]['output_size'],
         }

@@ -221,10 +221,10 @@ def __init__(
         """
         super().__init__(p=1)
         self.flags = {
-            "window_size": window_size,
-            "stride": stride if stride is not None else window_size,
-            "padding": padding,
-            "keepdim": keepdim,
+            'window_size': window_size,
+            'stride': stride if stride is not None else window_size,
+            'padding': padding,
+            'keepdim': keepdim,
         }

     def compute_transformation(
@@ -263,12 +263,12 @@ def apply_transform(
         """
         out = extract_tensor_patches(
             input,
-            window_size=flags["window_size"],
-            stride=flags["stride"],
-            padding=flags["padding"],
+            window_size=flags['window_size'],
+            stride=flags['stride'],
+            padding=flags['padding'],
         )
-        if flags["keepdim"]:
-            out = rearrange(out, "b t c h w -> (b t) c h w")
+        if flags['keepdim']:
+            out = rearrange(out, 'b t c h w -> (b t) c h w')
         return out
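
AugmentationSequential above maps torchgeo sample keys onto Kornia's data keys and handles the mask channel dimension, so a sample dict can be augmented directly. A minimal sketch; the chosen augmentations and tensor shapes are illustrative:

import kornia.augmentation as K
import torch
from torchgeo.transforms import AugmentationSequential

augs = AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    data_keys=['image', 'mask'],
)
batch = {
    'image': torch.rand(2, 3, 64, 64),
    'mask': torch.randint(0, 5, (2, 64, 64)),  # no channel dim; added and removed internally
}
batch = augs(batch)
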