♻️ Use xarray.merge with join="override" in collate functions
Refactoring the xarray collate functions to use `xr.merge` instead of the dictionary-style way of appending data variables to an xarray.Dataset. Solution adapted from 7787f8e in #62; it is more robust to images being cut off due to rounding issues, as with 6b18934 in #31. The downside is the need to verbosely rename the xarray.DataArray objects and handle some conflicting coordinate labels.
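The rounding issue the commit message refers to can be sketched with toy data (hypothetical coordinates, not from this repository): with the default alignment, two grids whose float coordinate labels differ by a rounding error get outer-joined into a larger, NaN-padded grid, while `join="override"` simply reuses the first object's labels.

```python
# Toy illustration (not from this repo) of why join="override" matters:
# two grids that differ only by a floating point rounding error.
import numpy as np
import xarray as xr

x_a = np.array([0.0, 10.0, 20.0])
x_b = x_a + 1e-9  # nominally the same grid, off by a rounding error

da_a = xr.DataArray(np.arange(3.0), coords={"x": x_a}, dims="x", name="a")
da_b = xr.DataArray(np.arange(3.0), coords={"x": x_b}, dims="x", name="b")

# Default outer join treats the near-identical labels as distinct,
# producing a NaN-padded grid twice the intended size
outer = xr.merge(objects=[da_a, da_b], join="outer")

# join="override" reuses the first object's coordinate labels instead
merged = xr.merge(objects=[da_a, da_b], join="override")
```

Note that `join="override"` requires all objects to have indexes of the same size, which holds here because both grids come from the same datapipe.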
weiji14 committed Oct 2, 2022
1 parent 95b0e20 commit d4c1bf8
Showing 2 changed files with 13 additions and 11 deletions.
16 changes: 8 additions & 8 deletions docs/stacking.md
@@ -322,7 +322,7 @@ dp_sen1_copdem
Next, use {py:class}`torchdata.datapipes.iter.Collator` (functional name:
`collate`) to convert 🤸 the tuple of {py:class}`xarray.DataArray` objects into
an {py:class}`xarray.Dataset` 🧊, similar to what was done in
-{doc}`./object-detection-boxes`.
+{doc}`./vector-segmentation-masks`.

```{code-cell}
def sardem_collate_fn(sar_and_dem: tuple) -> xr.Dataset:
@@ -333,12 +333,13 @@ def sardem_collate_fn(sar_and_dem: tuple) -> xr.Dataset:
# Turn 2 xr.DataArray objects into 1 xr.Dataset with multiple data vars
sar, dem = sar_and_dem
-    # Initialize xr.Dataset with VH and VV channels
-    dataset: xr.Dataset = sar.sel(band="vh").to_dataset(name="vh")
-    dataset["vv"] = sar.sel(band="vv")
-    # Add Copernicus DEM mosaic as another layer
-    dataset["dem"] = dem.squeeze()
+    # Create datacube with VH and VV channels from SAR + Copernicus DEM mosaic
+    da_vh: xr.DataArray = sar.sel(band="vh", drop=True).rename("vh")
+    da_vv: xr.DataArray = sar.sel(band="vv", drop=True).rename("vv")
+    da_dem: xr.DataArray = (
+        dem.sel(band="data").drop_vars(names=["proj:epsg", "platform"]).rename("dem")
+    )
+    dataset: xr.Dataset = xr.merge(objects=[da_vh, da_vv, da_dem], join="override")
return dataset
```
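The `sel(..., drop=True).rename(...)` pattern in the new collate function can be exercised on its own with a fabricated two-band array standing in for the real Sentinel-1 data (no `proj:epsg`/`platform` coords to drop in this toy case):

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in for the Sentinel-1 array: 2 bands on a 4x4 grid
sar = xr.DataArray(
    np.random.rand(2, 4, 4),
    coords={"band": ["vh", "vv"], "y": np.arange(4), "x": np.arange(4)},
    dims=("band", "y", "x"),
)
# sel(..., drop=True) removes the now-scalar band coordinate so it cannot
# conflict during the merge; rename() gives each data variable its name
da_vh = sar.sel(band="vh", drop=True).rename("vh")
da_vv = sar.sel(band="vv", drop=True).rename("vv")
dataset = xr.merge(objects=[da_vh, da_vv], join="override")
```

Without `drop=True`, each DataArray would carry a conflicting scalar `band` coordinate (`"vh"` vs `"vv"`) into the merge.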
@@ -366,7 +367,6 @@ Visualize the DataPipe graph ⛓️ too for good measure.
torchdata.datapipes.utils.to_graph(dp=dp_vhvvdem_dataset)
```


### Rasterize target labels to datacube extent 🏷️

The landslide polygons 🔶 can now be rasterized and added as another layer to
8 changes: 5 additions & 3 deletions docs/vector-segmentation-masks.md
@@ -88,7 +88,7 @@ geographic coordinates by default (OGC:CRS84). To make the pixels more equal
area, we can project it to a 🌏 local projected coordinate system instead.

```{code-cell}
-def reproject_to_local_utm(dataarray: xr.DataArray, resolution: float=100.0) -> xr.DataArray:
+def reproject_to_local_utm(dataarray: xr.DataArray, resolution: float=80.0) -> xr.DataArray:
"""
Reproject an xarray.DataArray grid from OGC:CRS84 to a local UTM coordinate
reference system.
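As a rough sketch of what "local UTM" means here (the real `reproject_to_local_utm` presumably asks rioxarray/pyproj to estimate the CRS, so this helper is a hypothetical illustration, not the document's implementation), the EPSG code of the WGS 84 / UTM zone containing a lon/lat point can be computed directly:

```python
def utm_epsg(longitude: float, latitude: float) -> int:
    """EPSG code of the WGS 84 / UTM zone containing a lon/lat point."""
    zone = int((longitude + 180) // 6) + 1  # sixty 6-degree-wide zones
    return (32600 if latitude >= 0 else 32700) + zone  # 326xx north, 327xx south
```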
@@ -382,8 +382,10 @@ def xr_collate_fn(image_and_mask: tuple) -> xr.Dataset:
"""
# Turn 2 xr.DataArray objects into 1 xr.Dataset with multiple data vars
image, mask = image_and_mask
-    dataset: xr.Dataset = image.isel(band=0).to_dataset(name="image")
-    dataset["mask"] = mask
+    dataset: xr.Dataset = xr.merge(
+        objects=[image.isel(band=0).rename("image"), mask.rename("mask")],
+        join="override",
+    )
# Clip dataset to bounding box extent of where labels are
mask_extent: tuple = mask.where(cond=mask == 1, drop=True).rio.bounds()
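The `mask.where(cond=mask == 1, drop=True)` trick used above to find the labelled extent can be seen in isolation with synthetic data (no rioxarray bounds involved in this sketch):

```python
import numpy as np
import xarray as xr

mask = xr.DataArray(np.zeros((4, 4)), dims=("y", "x"))
mask[1:3, 1:3] = 1.0  # labels only in the 2x2 interior
# where(..., drop=True) turns unlabelled pixels into NaN and drops
# any rows/columns that become all-NaN, cropping to the labelled extent
trimmed = mask.where(cond=mask == 1, drop=True)
```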
