♻️ Use xarray.merge with join="override" in collate functions (#72)
* ♻️ Use xarray.merge with join="override" in collate functions

Refactor the xarray collate functions to use `xr.merge` instead of dictionary-style assignment of data variables onto an xarray.Dataset. Solution adapted from 7787f8e in #62; it is more robust to images being cut off due to rounding issues, as with 6b18934 in #31. The downside is the need to verbosely rename the xarray.DataArray objects and to handle some conflicting coordinate labels.
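A minimal sketch (illustrative only, not part of the commit) of why `join="override"` helps here: when two layers have coordinates that differ only by floating-point rounding, the default `join="outer"` unions the near-duplicate labels and pads with NaN, while `join="override"` keeps the first object's grid and copies the other variables across by position.

```python
import numpy as np
import xarray as xr

x = np.arange(3, dtype="float64")
da_vh = xr.DataArray(np.ones(3), coords={"x": x}, name="vh")
# Same grid, but coordinates off by a tiny rounding error
da_vv = xr.DataArray(np.zeros(3), coords={"x": x + 1e-9}, name="vv")

# Default join="outer" unions the near-duplicate labels -> 6 rows with NaN gaps
ds_outer = xr.merge(objects=[da_vh, da_vv])
print(ds_outer.sizes["x"])  # 6

# join="override" keeps the first object's coordinates and pastes the other
# variables in by position -> still 3 rows, no data lost
ds_override = xr.merge(objects=[da_vh, da_vv], join="override")
print(ds_override.sizes["x"])  # 3
```

The variable names and the 1e-9 offset are made up for illustration; the real rounding mismatch came from reprojected raster chips.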

* πŸ“ Minor tweaks to vector segmentation mask walkthrough

A few whitespace fixes and corrections to some DataPipe references.
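The "conflicting coordinate labels" downside mentioned in the first bullet can be sketched like this (hypothetical `platform` values, not from the commit): `xr.merge` under its default `compat="no_conflicts"` raises a `MergeError` when non-index coordinates disagree, which is why the DEM layer drops `proj:epsg` and `platform` before merging.

```python
import numpy as np
import xarray as xr

# Two arrays sharing a dimension but carrying conflicting scalar coordinates,
# loosely mimicking differing metadata between the SAR and DEM layers.
da_a = xr.DataArray(np.ones(2), dims="x", coords={"platform": "sentinel-1a"}, name="vh")
da_b = xr.DataArray(np.zeros(2), dims="x", coords={"platform": "cop-dem"}, name="dem")

# join="override" only reconciles indexes; conflicting non-index
# coordinates still raise under the default compat="no_conflicts".
try:
    xr.merge(objects=[da_a, da_b], join="override")
except xr.MergeError as err:
    print("conflict:", err)

# Dropping the conflicting coordinate first lets the merge succeed
ds = xr.merge(objects=[da_a, da_b.drop_vars(names=["platform"])], join="override")
print(list(ds.data_vars))  # ['vh', 'dem']
```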
weiji14 authored Oct 2, 2022
1 parent 95b0e20 commit d83fedf
Showing 2 changed files with 18 additions and 14 deletions.
16 changes: 8 additions & 8 deletions docs/stacking.md
@@ -322,7 +322,7 @@
dp_sen1_copdem
Next, use {py:class}`torchdata.datapipes.iter.Collator` (functional name:
`collate`) to convert 🤸 the tuple of {py:class}`xarray.DataArray` objects into
an {py:class}`xarray.Dataset` 🧊, similar to what was done in
-{doc}`./object-detection-boxes`.
+{doc}`./vector-segmentation-masks`.

```{code-cell}
def sardem_collate_fn(sar_and_dem: tuple) -> xr.Dataset:
@@ -333,12 +333,13 @@ def sardem_collate_fn(sar_and_dem: tuple) -> xr.Dataset:
    # Turn 2 xr.DataArray objects into 1 xr.Dataset with multiple data vars
    sar, dem = sar_and_dem
-    # Initialize xr.Dataset with VH and VV channels
-    dataset: xr.Dataset = sar.sel(band="vh").to_dataset(name="vh")
-    dataset["vv"] = sar.sel(band="vv")
-    # Add Copernicus DEM mosaic as another layer
-    dataset["dem"] = dem.squeeze()
+    # Create datacube with VH and VV channels from SAR + Copernicus DEM mosaic
+    da_vh: xr.DataArray = sar.sel(band="vh", drop=True).rename("vh")
+    da_vv: xr.DataArray = sar.sel(band="vv", drop=True).rename("vv")
+    da_dem: xr.DataArray = (
+        dem.sel(band="data").drop_vars(names=["proj:epsg", "platform"]).rename("dem")
+    )
+    dataset: xr.Dataset = xr.merge(objects=[da_vh, da_vv, da_dem], join="override")
    return dataset
```
@@ -366,7 +367,6 @@
Visualize the DataPipe graph ⛓️ too for good measure.
torchdata.datapipes.utils.to_graph(dp=dp_vhvvdem_dataset)
```


### Rasterize target labels to datacube extent 🏷️

The landslide polygons 🔶 can now be rasterized and added as another layer to
16 changes: 10 additions & 6 deletions docs/vector-segmentation-masks.md
@@ -73,7 +73,8 @@
Malaysia on 15 Dec 2019.
### Load and reproject image data 🔄

To keep things simple, we'll load just the VV channel into a DataPipe via
-{py:class}`zen3geo.datapipes.rioxarray.RioXarrayReaderIterDataPipe` 😀.
+{py:class}`zen3geo.datapipes.RioXarrayReader` (functional name:
+`read_from_rioxarray`) 😀.

```{code-cell}
url = signed_item.assets["vv"].href
@@ -88,7 +89,7 @@
geographic coordinates by default (OGC:CRS84). To make the pixels more equal
area, we can project it to a 🌏 local projected coordinate system instead.

```{code-cell}
-def reproject_to_local_utm(dataarray: xr.DataArray, resolution: float=100.0) -> xr.DataArray:
+def reproject_to_local_utm(dataarray: xr.DataArray, resolution: float=80.0) -> xr.DataArray:
"""
Reproject an xarray.DataArray grid from OGC:CRS84 to a local UTM coordinate
reference system.
@@ -196,6 +197,7 @@
put it into a DataPipe called {py:class}`zen3geo.datapipes.PyogrioReader`
```{code-cell}
dp_shapes = torchdata.datapipes.iter.IterableWrapper(iterable=[shape_url])
dp_pyogrio = dp_shapes.read_from_pyogrio()
+dp_pyogrio
```

This will take care of loading the shapefile into a
@@ -227,7 +229,7 @@
correspond to the zoomed in Sentinel-1 image plotted earlier above.
gdf.plot(figsize=(11.5, 9))
```

```{tip}
Make sure to understand your raster and vector datasets well first! Open the
files up in your favourite 🌐 Geographic Information System (GIS) tool, see how
they actually look like spatially. Then you'll have a better idea to decide on
@@ -382,8 +384,10 @@ def xr_collate_fn(image_and_mask: tuple) -> xr.Dataset:
    """
    # Turn 2 xr.DataArray objects into 1 xr.Dataset with multiple data vars
    image, mask = image_and_mask
-    dataset: xr.Dataset = image.isel(band=0).to_dataset(name="image")
-    dataset["mask"] = mask
+    dataset: xr.Dataset = xr.merge(
+        objects=[image.isel(band=0).rename("image"), mask.rename("mask")],
+        join="override",
+    )
    # Clip dataset to bounding box extent of where labels are
    mask_extent: tuple = mask.where(cond=mask == 1, drop=True).rio.bounds()
@@ -463,7 +467,7 @@
Pass the DataPipe into {py:class}`torch.utils.data.DataLoader` 🤾!
dataloader = torch.utils.data.DataLoader(dataset=dp_map)
for i, batch in enumerate(dataloader):
    image, mask = batch
-    print(f"Batch {i} - image: {image.shape}, mask:{mask.shape}")
+    print(f"Batch {i} - image: {image.shape}, mask: {mask.shape}")
```

Now go train some flood water detection models 🌊🌊🌊
