Skip to content

Commit

Permalink
All tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Nov 8, 2023
1 parent dd1fa07 commit 8ec3637
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 329 deletions.
43 changes: 21 additions & 22 deletions pyop3/distarray/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,44 +281,42 @@ def layouts(self):
def data(self):
import warnings

warnings.warn(".data is a deprecated alias for .data_rw", FutureWarning)
warnings.warn(
".data is a deprecated alias for .data_rw and will be removed in future",
FutureWarning,
)
return self.data_rw

@property
def data_rw(self):
if not self._roots_valid:
self.reduce_leaves_to_roots()

# modifying owned values invalidates ghosts
self._leaves_valid = False
return self._data[: self.axes.owned_size]

@property
def data_ro(self):
return readonly(self.data_rw)
if not self._roots_valid:
self.reduce_leaves_to_roots()
return readonly(self._data[: self.axes.owned_size])

@property
def data_wo(self):
# Even for write-only access we must ensure that roots are updated, otherwise
# writing to a subset of values would leave the array in a poorly defined state.
return self.data_rw

@property
def data_rw_with_ghosts(self):
return self._data

@property
def data_ro_with_ghosts(self):
# TODO
return self.data_rw_with_ghosts

@property
def data_wo_with_ghosts(self):
"""
Have to be careful. If not setting all values (i.e. subsets) should call
`reduce_leaves_to_roots` first.
This method sets the leaves as being valid but this is dangerous since
the set values must match those on other processors. In practice this
should only be used for setting constant values.
When this is called we set roots_valid, claiming that any (lazy) 'in-flight' writes
can be dropped.
"""
return self.data_rw_with_ghosts
# pending writes can be dropped (care needed if only doing subsets)
self._roots_valid = True
self._last_write_op = None
# modifying owned values invalidates ghosts
self._leaves_valid = False
return self._data[: self.axes.owned_size]

@functools.cached_property
def datamap(self) -> dict[str:DistributedArray]:
Expand Down Expand Up @@ -411,6 +409,7 @@ def reduce_leaves_to_roots(self):

self._roots_valid = True
self._leaves_valid = False
self._last_write_op = None

def broadcast_roots_to_leaves(self):
sf = self.axes.sf
Expand Down
27 changes: 20 additions & 7 deletions pyop3/sf.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from functools import cached_property

import numpy as np
from mpi4py import MPI
from petsc4py import PETSc

from pyop3.dtypes import get_mpi_dtype
from pyop3.utils import just_one


class StarForest:
Expand All @@ -22,18 +25,26 @@ def from_graph(cls, size: int, nroots: int, ilocal, iremote, comm=None):
def iroot(self):
"""Return the indices of roots on the current process."""
# mark leaves and reduce
buffer = np.full(self.size, False, dtype=bool)
buffer[ilocal] = True
self.reduce(buffer, MPI.REPLACE)
mask = np.full(self.size, False, dtype=bool)
mask[self.ileaf] = True
self.reduce(mask, MPI.REPLACE)

# now clear the leaf indices, the remaining marked indices are roots
buffer[ilocal] = False
return just_one(np.nonzero(buffer))
mask[self.ileaf] = False
return just_one(np.nonzero(mask))

@property
def ileaf(self):
return self.ilocal

@cached_property
def icore(self):
"""Return the indices of points that are not roots or leaves."""
mask = np.full(self.size, True, dtype=bool)
mask[self.iroot] = False
mask[self.ileaf] = False
return just_one(np.nonzero(mask))

@property
def nroots(self):
return self._graph[0]
Expand Down Expand Up @@ -78,8 +89,7 @@ def reduce_end(self, *args):
def _graph(self):
return self.sf.getGraph()

@staticmethod
def _prepare_args(*args):
def _prepare_args(self, *args):
if len(args) == 3:
from_buffer, to_buffer, op = args
elif len(args) == 2:
Expand All @@ -88,6 +98,9 @@ def _prepare_args(*args):
else:
raise ValueError

if any(len(buf) != self.size for buf in [from_buffer, to_buffer]):
raise ValueError

# what about cdim?
dtype, _ = get_mpi_dtype(from_buffer.dtype)
return (dtype, from_buffer, to_buffer, op)
Loading

0 comments on commit 8ec3637

Please sign in to comment.