Refactor P2P rechunk validation (#7890)
hendrikmakait authored Jun 12, 2023
1 parent 618f5ac commit 19c8bf9
Showing 2 changed files with 1 addition and 38 deletions.
31 changes: 1 addition & 30 deletions distributed/shuffle/_rechunk.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-import math
 from collections import defaultdict
-from itertools import compress, product
+from itertools import product
 from typing import TYPE_CHECKING, NamedTuple
 
 import dask
@@ -73,34 +72,6 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
         # Special case for empty array, as the algorithm below does not behave correctly
         return da.empty(x.shape, chunks=chunks, dtype=x.dtype)
 
-    old_chunks = x.chunks
-    new_chunks = chunks
-
-    def is_unknown(dim: ChunkedAxis) -> bool:
-        return any(math.isnan(chunk) for chunk in dim)
-
-    old_is_unknown = [is_unknown(dim) for dim in old_chunks]
-    new_is_unknown = [is_unknown(dim) for dim in new_chunks]
-
-    if old_is_unknown != new_is_unknown or any(
-        new != old for new, old in compress(zip(old_chunks, new_chunks), old_is_unknown)
-    ):
-        raise ValueError(
-            "Chunks must be unchanging along dimensions with missing values.\n\n"
-            "A possible solution:\n x.compute_chunk_sizes()"
-        )
-
-    old_known = [dim for dim, unknown in zip(old_chunks, old_is_unknown) if not unknown]
-    new_known = [dim for dim, unknown in zip(new_chunks, new_is_unknown) if not unknown]
-
-    old_sizes = [sum(o) for o in old_known]
-    new_sizes = [sum(n) for n in new_known]
-
-    if old_sizes != new_sizes:
-        raise ValueError(
-            f"Cannot change dimensions from {old_sizes!r} to {new_sizes!r}"
-        )
-
     dsk: dict = {}
     token = tokenize(x, chunks)
     _barrier_key = barrier_key(ShuffleId(token))
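For reference, the block deleted above enforced two invariants before a P2P rechunk: dimensions with unknown (NaN) chunk sizes must be carried over unchanged, and the total extent along known dimensions must be preserved. A minimal standalone sketch of that logic (the `validate_rechunk` helper and the type aliases here are illustrative, not part of the codebase):

import math
from itertools import compress

ChunkedAxis = tuple[float, ...]   # one entry per chunk; NaN marks an unknown size
ChunkedAxes = tuple[ChunkedAxis, ...]

def validate_rechunk(old_chunks: ChunkedAxes, new_chunks: ChunkedAxes) -> None:
    def is_unknown(dim: ChunkedAxis) -> bool:
        return any(math.isnan(chunk) for chunk in dim)

    old_is_unknown = [is_unknown(dim) for dim in old_chunks]
    new_is_unknown = [is_unknown(dim) for dim in new_chunks]

    # Dimensions with unknown chunk sizes must be left untouched.
    if old_is_unknown != new_is_unknown or any(
        old != new
        for old, new in compress(zip(old_chunks, new_chunks), old_is_unknown)
    ):
        raise ValueError(
            "Chunks must be unchanging along dimensions with missing values.\n\n"
            "A possible solution:\n x.compute_chunk_sizes()"
        )

    # Along fully known dimensions, only the chunking may change, not the extent.
    old_sizes = [sum(dim) for dim, unknown in zip(old_chunks, old_is_unknown) if not unknown]
    new_sizes = [sum(dim) for dim, unknown in zip(new_chunks, new_is_unknown) if not unknown]
    if old_sizes != new_sizes:
        raise ValueError(f"Cannot change dimensions from {old_sizes!r} to {new_sizes!r}")

# Re-chunking a known axis of total size 4 is fine:
validate_rechunk(((2, 2),), ((1, 3),))
# Re-chunking an axis with unknown chunk sizes raises; as the error message
# suggests, x.compute_chunk_sizes() materializes the sizes first.
# validate_rechunk(((float("nan"), 2),), ((4,),))   # ValueError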
8 changes: 0 additions & 8 deletions distributed/shuffle/_worker_extension.py
@@ -314,14 +314,6 @@ def __init__(
             memory_limiter_comms=memory_limiter_comms,
             memory_limiter_disk=memory_limiter_disk,
         )
-        from dask.array.core import normalize_chunks
-
-        # We rely on a canonical `np.nan` in `dask.array.rechunk.old_to_new`
-        # that passes an implicit identity check when testing for list equality.
-        # This does not work with (de)serialization, so we have to normalize the chunks
-        # here again to canonicalize `nan`s.
-        old = normalize_chunks(old)
-        new = normalize_chunks(new)
         self.old = old
         self.new = new
         partitions_of = defaultdict(list)
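The workaround deleted here existed because NaN never compares equal to itself by value, while Python's list comparison short-circuits on object identity: two lists holding the *same* NaN object compare equal, but deserialization produces a distinct NaN object and silently breaks that check. A small demonstration in plain Python, independent of dask:

import pickle

nan = float("nan")

# NaN is unequal to itself by value...
assert nan != nan

# ...yet list equality checks identity before value, so lists holding
# the very same NaN object compare equal:
assert [nan] == [nan]

# A round-trip through (de)serialization yields a different NaN object,
# so the implicit identity check no longer fires:
roundtripped = pickle.loads(pickle.dumps(nan))
assert roundtripped != roundtripped  # still NaN by value
assert [roundtripped] != [nan]       # but no longer "equal" as lists

This is why the removed code re-ran normalize_chunks after deserialization: it replaced incoming NaNs with the canonical `np.nan` object so the identity-based comparison in `dask.array.rechunk.old_to_new` kept working.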
