Commit 6041f56

Trying to use a multiprocessing pool for extracting chips (#5)
munshkr committed Nov 14, 2021
1 parent ac4dd6d commit 6041f56
Showing 3 changed files with 156 additions and 70 deletions.
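
In outline, the change moves the per-window body of extract_chips_from_raster into a top-level _extract_chip function, freezes its many keyword arguments with functools.partial, and maps the frozen worker over a pool. A minimal sketch of that shape, using simplified names rather than satproc's actual signatures:

```python
from functools import partial
from multiprocessing.pool import ThreadPool

def process_window(item, *, raster, bands):
    # one unit of work; returning None signals a skipped chip
    idx, window = item
    return (raster, idx, window, len(bands))

# freeze the shared keyword arguments once, leaving only `item` free
worker = partial(process_window, raster="scene.tif", bands=[1, 2, 3])

windows = ["w0", "w1", "w2"]
with ThreadPool(4) as pool:
    chips = [c for c in pool.imap_unordered(worker, enumerate(windows)) if c]
```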
186 changes: 123 additions & 63 deletions src/satproc/chips.py
@@ -1,5 +1,6 @@
import logging
import os
from functools import partial

# Workaround: Load fiona at the end to avoid segfault on box (???)
import fiona
@@ -14,7 +15,12 @@
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from satproc.utils import rescale_intensity, sliding_windows, write_chips_geojson
from satproc.utils import (
map_with_threads,
rescale_intensity,
sliding_windows,
write_chips_geojson,
)

__author__ = "Damián Silvani"
__copyright__ = "Dymaxion Labs"
@@ -222,6 +228,7 @@ def extract_chips(
skip_existing=True,
within=False,
windows_mode="whole_overlap",
num_jobs=1,
*,
size,
step_size,
@@ -264,6 +271,7 @@ def extract_chips(
polys_dict=polys_dict,
windows_mode=windows_mode,
skip_existing=skip_existing,
num_jobs=num_jobs,
)


@@ -285,6 +293,7 @@ def extract_chips_from_raster(
polys_dict=None,
windows_mode="whole_overlap",
boundary_mask=False,
num_jobs=1,
*,
size,
step_size,
@@ -351,75 +360,126 @@ def filter_fn(w, aoi):
# If rescaling, set nodata=0 (will rescale to uint8 1-255)
meta["nodata"] = 0

basename = os.path.basename(raster)
chips = []
for c, ((window, (i, j)), win_shape) in tqdm(
list(enumerate(window_and_shapes)), desc=f"{basename} windows", ascii=True
):
_logger.debug("%s %s", window, (i, j))

img_path = os.path.join(image_folder, f"{basename}_{i}_{j}.{type}")
mask_path = os.path.join(masks_folder, f"{basename}_{i}_{j}.{type}")
boundary_mask_path = os.path.join(
boundary_masks_folder, f"{basename}_{i}_{j}.{type}"
)
worker = partial(
_extract_chip,
raster=raster,
bands=bands,
type=type,
basename=basename,
boundary_mask=boundary_mask,
boundary_masks_folder=boundary_masks_folder,
classes=classes,
image_folder=image_folder,
label_property=label_property,
labels=labels,
mask_type=mask_type,
masks_folder=masks_folder,
meta=meta,
polys_dict=polys_dict,
rescale_mode=rescale_mode,
rescale_range=rescale_range,
skip_existing=skip_existing,
)
chips = map_with_threads(
list(enumerate(window_and_shapes)),
worker=worker,
num_jobs=num_jobs,
desc=f"{basename} windows",
)
chips = [chip for chip in chips if chip]

# Gather list of required files
required_files = {img_path}
if labels:
required_files.add(mask_path)
if boundary_mask:
required_files.add(boundary_mask_path)
if write_footprints:
geojson_path = os.path.join(output_dir, "{}.geojson".format(basename))
_logger.info("Write chips footprints GeoJSON at %s", geojson_path)
write_chips_geojson(
geojson_path, chips, type=type, crs=str(meta["crs"]), basename=basename
)

# If all files already exist and we are skipping existing files, continue
if skip_existing and all(os.path.exists(p) for p in required_files):
continue

img = ds.read(window=window)
img = np.nan_to_num(img)
img = np.array([img[b - 1, :, :] for b in bands])
def _extract_chip(
item,
*,
raster,
bands,
basename,
type,
boundary_mask,
boundary_masks_folder,
classes,
image_folder,
label_property,
labels,
mask_type,
masks_folder,
meta,
polys_dict,
rescale_mode,
rescale_range,
skip_existing,
):
c, ((window, (i, j)), win_shape) = item

_logger.debug("%s %s", window, (i, j))

if rescale_mode:
img = rescale_intensity(img, rescale_mode, rescale_range)
img_path = os.path.join(image_folder, f"{basename}_{i}_{j}.{type}")
mask_path = os.path.join(masks_folder, f"{basename}_{i}_{j}.{type}")
boundary_mask_path = os.path.join(
boundary_masks_folder, f"{basename}_{i}_{j}.{type}"
)

if type == "tif":
image_was_saved = write_tif(
img,
img_path,
# Gather list of required files
required_files = {img_path}
if labels:
required_files.add(mask_path)
if boundary_mask:
required_files.add(boundary_mask_path)

# If all files already exist and we are skipping existing files, continue
if skip_existing and all(os.path.exists(p) for p in required_files):
return

with rasterio.open(raster) as ds:
transform = ds.transform
img = ds.read(window=window)

img = np.nan_to_num(img)
img = np.array([img[b - 1, :, :] for b in bands])

if rescale_mode:
img = rescale_intensity(img, rescale_mode, rescale_range)

if type == "tif":
image_was_saved = write_tif(
img,
img_path,
window=window,
meta=meta.copy(),
transform=transform,
bands=bands,
)
else:
image_was_saved = write_image(img, img_path)

if image_was_saved:
chip = (win_shape, (c, i, j))

if labels:
if mask_type == "class":
keys = classes if classes is not None else polys_dict.keys()
multiband_chip_mask_by_classes(
classes=keys,
transform=transform,
window=window,
meta=meta.copy(),
transform=ds.transform,
bands=bands,
window_shape=win_shape,
polys_dict=polys_dict,
metadata=meta,
mask_path=mask_path,
boundary_mask=boundary_mask,
boundary_mask_path=boundary_mask_path,
label_property=label_property,
)
else:
image_was_saved = write_image(img, img_path)

if image_was_saved:
chip = (win_shape, (c, i, j))
chips.append(chip)

if labels:
if mask_type == "class":
keys = classes if classes is not None else polys_dict.keys()
multiband_chip_mask_by_classes(
classes=keys,
transform=ds.transform,
window=window,
window_shape=win_shape,
polys_dict=polys_dict,
metadata=meta,
mask_path=mask_path,
boundary_mask=boundary_mask,
boundary_mask_path=boundary_mask_path,
label_property=label_property,
)

if write_footprints:
geojson_path = os.path.join(output_dir, "{}.geojson".format(basename))
_logger.info("Write chips footprints GeoJSON at %s", geojson_path)
write_chips_geojson(
geojson_path, chips, type=type, crs=str(meta["crs"]), basename=basename
)

return chip
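
Note that _extract_chip reopens the raster with rasterio.open inside the worker instead of reusing a dataset opened by the caller: GDAL dataset handles are not safe to share across processes (or across concurrent threads). A sketch of that per-worker pattern; read_window and its arguments are illustrative, not part of satproc:

```python
import rasterio
from rasterio.windows import Window

def read_window(path, col_off, row_off, width, height):
    # open a fresh handle on every call so workers never share one
    with rasterio.open(path) as ds:
        return ds.read(window=Window(col_off, row_off, width, height))
```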


def write_image(img, path, percentiles=None):
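Hoisting the loop body to module level also keeps a process pool viable: multiprocessing pickles the callable it ships to workers, and a functools.partial over a module-level function pickles cleanly, whereas a closure defined inside extract_chips_from_raster would not. A quick illustration:

```python
import pickle
from functools import partial

def scaled(item, *, scale):
    # module-level function: picklable, hence usable with a process pool
    return item * scale

pickle.dumps(partial(scaled, scale=2))  # works

def make_closure(scale):
    def inner(item):
        return item * scale
    return inner

# pickle.dumps(make_closure(2))  # AttributeError: local functions can't be pickled
```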
9 changes: 9 additions & 0 deletions src/satproc/console/extract_chips.py
@@ -188,6 +188,14 @@ def parse_args(args):
help="do not skip already existing chips (and masks)",
)

parser.add_argument(
"--num-jobs",
"-j",
type=int,
default=1,
help="number of jobs to run in parallel",
)

parser.add_argument(
"--version", action="version", version="satproc {ver}".format(ver=__version__)
)
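
The flag follows standard argparse conventions: --num-jobs stores into args.num_jobs and -j is its short alias. A standalone sketch of the parsing behavior, with a hypothetical parser mirroring the snippet above:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-jobs", "-j", type=int, default=1,
    help="number of jobs to run in parallel",
)

assert parser.parse_args([]).num_jobs == 1        # default stays sequential
assert parser.parse_args(["-j", "4"]).num_jobs == 4
```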
@@ -284,6 +292,7 @@ def main(args):
step_size=args.step_size,
windows_mode=args.sliding_windows_mode,
output_dir=args.output_dir,
num_jobs=args.num_jobs,
)


31 changes: 24 additions & 7 deletions src/satproc/utils.py
@@ -5,6 +5,7 @@
import os
import subprocess
from itertools import zip_longest
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool

import numpy as np
@@ -287,19 +288,33 @@ def run_command(cmd, quiet=True):
subprocess.run(cmd, shell=True, stderr=stderr, stdout=stdout)


def map_with_threads(items, worker, num_jobs=None, total=None):
"""Map a worker function to an iterable of items, using a thread pool
def map_with_processes(*args, num_jobs=None, **kwargs):
with Pool(num_jobs) as pool:
return _parallel_map(pool, *args, **kwargs)


def map_with_threads(*args, num_jobs=None, **kwargs):
with ThreadPool(num_jobs) as pool:
return _parallel_map(pool, *args, **kwargs)
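
map_with_processes and map_with_threads are deliberately thin: each owns only the pool's lifetime (the with block guarantees shutdown) and delegates the actual mapping to _parallel_map below. A hypothetical side-by-side call, noting that the process variant needs a picklable, module-level worker and that results come back unordered:

```python
def square(x):  # module level, so the process pool can pickle it
    return x * x

if __name__ == "__main__":
    from satproc.utils import map_with_processes, map_with_threads

    # threads suit I/O-bound work; processes suit CPU-bound work
    print(map_with_threads(list(range(8)), square, num_jobs=4))
    print(map_with_processes(list(range(8)), square, num_jobs=4))
```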


def _parallel_map(pool, items, worker, num_jobs=None, total=None, desc=""):
"""Map a worker function to an iterable of items, using a concurrent Pool
Parameters
----------
pool : Pool or ThreadPool
the pool instance to run the mapping on
items : iterable
items to map
worker : Function
worker function to apply to each item
num_jobs : int
num_jobs : Optional[int]
number of jobs to run in parallel (defaults to the CPU count)
total : int (optional)
total number of items (for the progress bar)
total : Optional[int]
total number of items (for progress bar)
desc : str
description to use as prefix (for progress bar)
Returns
-------
@@ -310,8 +325,10 @@ def map_with_threads(items, worker, num_jobs=None, total=None):
total = len(items)
if not num_jobs:
num_jobs = mp.cpu_count()
results = []
with ThreadPool(num_jobs) as pool:
with logging_redirect_tqdm():
with tqdm(total=len(items), ascii=True) as pbar:
for _ in enumerate(pool.imap_unordered(worker, items)):
with tqdm(total=len(items), desc=desc, ascii=True) as pbar:
for result in pool.imap_unordered(worker, items, chunksize=4):
pbar.update()
results.append(result)
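
For reference, a self-contained sketch of the pool-plus-progress-bar pattern that _parallel_map implements: imap_unordered yields results in completion order rather than submission order, and chunksize=4 hands items to workers in batches to cut dispatch overhead. Names here are illustrative:

```python
from multiprocessing.pool import ThreadPool
from tqdm import tqdm

def work(x):
    return x * x

items = list(range(100))
results = []
with ThreadPool(4) as pool:
    with tqdm(total=len(items), ascii=True, desc="demo") as pbar:
        for result in pool.imap_unordered(work, items, chunksize=4):
            results.append(result)
            pbar.update()
```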
