diff --git a/torchgeo/datasets/stacapidataset.py b/torchgeo/datasets/stacapidataset.py
new file mode 100644
index 00000000000..ea0ae900598
--- /dev/null
+++ b/torchgeo/datasets/stacapidataset.py
@@ -0,0 +1,234 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""STACAPIDataset."""
+
+import sys
+from typing import Any, Callable, Dict, Optional, Sequence
+
+import matplotlib.pyplot as plt
+import planetary_computer as pc
+import stackstac
+import torch
+from pyproj import Transformer
+from pystac_client import Client
+from rasterio.crs import CRS
+from torch import Tensor
+
+from torchgeo.datasets.geo import GeoDataset
+from torchgeo.datasets.utils import BoundingBox
+
+
+class STACAPIDataset(GeoDataset):
+ """STACApiDataset.
+
+ SpatioTemporal Asset Catalogs (`STACs `_) are a way
+ to organize geospatial datasets. STAC APIs let you query huge STAC Catalogs by
+ date, time, and other metadata.
+
+
+ .. versionadded:: 0.3
+ """
+
+ sentinel_bands = [
+ "B01",
+ "B02",
+ "B03",
+ "B04",
+ "B05",
+ "B06",
+ "B07",
+ "B08",
+ "B8A",
+ "B09",
+ "B11",
+ "B12",
+ ]
+
+ def __init__( # type: ignore[no-untyped-def]
+ self,
+ root: str,
+ crs: Optional[CRS] = None,
+ res: Optional[float] = None,
+ bands: Sequence[str] = sentinel_bands,
+ is_image: bool = True,
+ api_endpoint: str = "https://planetarycomputer.microsoft.com/api/stac/v1",
+ transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ **query_parameters,
+ ) -> None:
+ """Initialize a new Dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ crs: :term:`coordinate reference system (CRS)` to warp to
+ (defaults to the CRS of the first file found)
+ res: resolution of the dataset in units of CRS
+ (defaults to the resolution of the first file found)
+ bands: sequence of of stac asset band names
+ is_image: if true, :meth:`__getitem__` uses `image` as sample key, `mask`
+ otherwise
+ api_endpoint: api for pystac Client to access
+ transforms: a function/transform that takes an input sample
+ and returns a transformed versio
+ query_parameters: parameters for the catalog to search, for an idea see
+
+ """
+ self.root = root
+ self.api_endpoint = api_endpoint
+ self.bands = bands
+ self.is_image = is_image
+
+ super().__init__(transforms)
+
+ catalog = Client.open(api_endpoint)
+
+ search = catalog.search(**query_parameters)
+
+ items = list(search.get_items())
+
+ if not items:
+ raise RuntimeError(
+ f"No items returned from search criteria: {query_parameters}"
+ )
+
+ epsg = items[0].properties["proj:epsg"]
+ src_crs = CRS.from_epsg(epsg)
+ if crs is None:
+ crs = src_crs
+
+ for i, item in enumerate(items):
+ minx, miny, maxx, maxy = item.bbox
+
+ transformer = Transformer.from_crs(4326, crs.to_epsg(), always_xy=True)
+ (minx, maxx), (miny, maxy) = transformer.transform(
+ [minx, maxx], [miny, maxy]
+ )
+ mint = 0
+ maxt = sys.maxsize
+ coords = (minx, maxx, miny, maxy, mint, maxt)
+ self.index.insert(i, coords, item)
+
+ self._crs = crs
+ self.res = res
+ self.items = items
+
+ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
+ """Retrieve image/mask and metadata indexed by query.
+
+ Args:
+ query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+ Returns:
+ sample of image/mask and metadata at that index
+
+ Raises:
+ IndexError: if query is not found in the index
+ """
+ hits = self.index.intersection(tuple(query), objects=True)
+ items = [hit.object for hit in hits]
+
+ if not items:
+ raise IndexError(
+ f"query: {query} not found in index with bounds: {self.bounds}"
+ )
+
+ # suggested #
+ signed_items = [pc.sign(item).to_dict() for item in items]
+
+ stack = stackstac.stack(
+ signed_items,
+ assets=self.bands,
+ resolution=self.res,
+ epsg=self._crs.to_epsg(),
+ )
+
+ aoi = stack.loc[
+ ..., query.maxy : query.miny, query.minx : query.maxx # type: ignore[misc]
+ ]
+
+ data = aoi.compute(scheduler="single-threaded").data
+
+ # handle time dimension here
+ image: Tensor = torch.Tensor(data)
+
+ key = "image" if self.is_image else "mask"
+ sample = {key: image, "crs": self.crs, "bbox": query}
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def plot(
+ self,
+ sample: Dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> plt.Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`RasterDataset.__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+ """
+ image = sample["image"].permute(1, 2, 0)
+ image = torch.clip(image / 10000, 0, 1) # type: ignore[attr-defined]
+
+ fig, ax = plt.subplots(1, 1, figsize=(4, 4))
+ ax.imshow(image)
+ ax.axis("off")
+
+ if show_titles:
+ ax.set_title("Image")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+
+ return fig
+
+
+if __name__ == "__main__":
+
+ area_of_interest = {
+ "type": "Polygon",
+ "coordinates": [
+ [
+ [-148.56536865234375, 60.80072385643073],
+ [-147.44338989257812, 60.80072385643073],
+ [-147.44338989257812, 61.18363894915102],
+ [-148.56536865234375, 61.18363894915102],
+ [-148.56536865234375, 60.80072385643073],
+ ]
+ ],
+ }
+
+ time_of_interest = "2019-06-01/2019-08-01"
+
+ collections = (["sentinel-2-l2a"],)
+ intersects = (area_of_interest,)
+ datetime = (time_of_interest,)
+ query = ({"eo:cloud_cover": {"lt": 10}},)
+
+ rgb_bands = ["B04", "B03", "B02"]
+ ds = STACAPIDataset(
+ "./data",
+ bands=rgb_bands,
+ collections=["sentinel-2-l2a"],
+ intersects=area_of_interest,
+ datetime=time_of_interest,
+ query={"eo:cloud_cover": {"lt": 10}},
+ )
+
+ minx = 420688.14962388354
+ maxx = 429392.15007465985
+ miny = 6769145.954634559
+ maxy = 6777492.989499866
+ mint = 0
+ maxt = 100000
+
+ bbox = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
+ sample = ds[bbox]