From 1834ee191ee5b5f76e4a28e15d1b1e82073dfd3e Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Tue, 12 Dec 2023 17:20:57 +0000 Subject: [PATCH] Add subset test --- pyop3/array/petsc.py | 117 ++++++++++++++++++++---------- pyop3/axtree/parallel.py | 12 --- pyop3/axtree/tree.py | 27 ++++++- pyop3/sf.py | 4 + pyop3/utils.py | 13 ++++ tests/integration/test_subsets.py | 93 ++++++++++++++---------- 6 files changed, 176 insertions(+), 90 deletions(-) diff --git a/pyop3/array/petsc.py b/pyop3/array/petsc.py index fda71d73..b240b3b6 100644 --- a/pyop3/array/petsc.py +++ b/pyop3/array/petsc.py @@ -66,7 +66,7 @@ class MatType(enum.Enum): DEFAULT_MAT_TYPE = MatType.AIJ -class PetscMat(PetscObject): +class PetscMat(PetscObject, abc.ABC): prefix = "mat" def __new__(cls, *args, **kwargs): @@ -78,6 +78,13 @@ def __new__(cls, *args, **kwargs): else: raise AssertionError + # like Dat, bad name? handle? + @property + def array(self): + return self.petscmat + + +class MonolithicPetscMat(PetscMat, abc.ABC): def __getitem__(self, indices): if len(indices) != 2: raise ValueError @@ -118,8 +125,8 @@ class ContextSensitiveIndexedPetscMat(ContextSensitive): pass -class PetscMatAIJ(PetscMat): - def __init__(self, raxes, caxes, sparsity, *, comm=None, name: str = None): +class PetscMatAIJ(MonolithicPetscMat): + def __init__(self, raxes, caxes, sparsity, *, name: str = None): raxes = as_axis_tree(raxes) caxes = as_axis_tree(caxes) @@ -132,34 +139,7 @@ def __init__(self, raxes, caxes, sparsity, *, comm=None, name: str = None): # TODO, good exceptions raise RuntimeError - sizes = (raxes.leaf_component.count, caxes.leaf_component.count) - nnz = sparsity.axes.leaf_component.count - mat = PETSc.Mat().createAIJ(sizes, nnz=nnz.data, comm=comm) - - # fill with zeros (this should be cached) - # this could be done as a pyop3 loop (if we get ragged local working) or - # explicitly in cython - raxis, rcpt = raxes.leaf - caxis, ccpt = caxes.leaf - # e.g. - # map_ = Map({pmap({raxis.label: rcpt.label}): [TabulatedMapComponent(caxes.label, ccpt.label, sparsity)]}) - # do_loop(p := raxes.index(), write(zeros, mat[p, map_(p)])) - - # but for now do in Python... - assert nnz.max_value is not None - zeros = np.zeros(nnz.max_value, dtype=self.dtype) - for row_idx in range(rcpt.count): - cstart = sparsity.axes.offset([row_idx, 0]) - try: - cstop = sparsity.axes.offset([row_idx + 1, 0]) - except IndexError: - # catch the last one - cstop = len(sparsity.data_ro) - # truncate zeros - mat.setValuesLocal( - [row_idx], sparsity.data_ro[cstart:cstop], zeros[: cstop - cstart] - ) - mat.assemble() + self.petscmat = _alloc_mat(raxes, caxes, sparsity) self.raxis = raxes.root self.caxis = caxes.root @@ -167,17 +147,33 @@ def __init__(self, raxes, caxes, sparsity, *, comm=None, name: str = None): self.axes = AxisTree.from_nest({self.raxis: self.caxis}) - # copy only needed if we reuse the zero matrix - self.petscmat = mat.copy() - # like Dat, bad name? handle? - @property - def array(self): - return self.petscmat +class PetscMatBAIJ(MonolithicPetscMat): + def __init__(self, raxes, caxes, sparsity, bsize, *, name: str = None): + raxes = as_axis_tree(raxes) + caxes = as_axis_tree(caxes) + if isinstance(bsize, numbers.Integral): + bsize = (bsize, bsize) -class PetscMatBAIJ(PetscMat): - ... + super().__init__(name) + if any(axes.depth > 1 for axes in [raxes, caxes]): + # TODO, good exceptions + # raise InvalidDimensionException("Cannot instantiate PetscMats with nested axis trees") + raise RuntimeError + if any(len(axes.root.components) > 1 for axes in [raxes, caxes]): + # TODO, good exceptions + raise RuntimeError + + self.petscmat = _alloc_mat(raxes, caxes, sparsity, bsize) + + self.raxis = raxes.root + self.caxis = caxes.root + self.sparsity = sparsity + self.bsize = bsize + + # TODO include bsize here? + self.axes = AxisTree.from_nest({self.raxis: self.caxis}) class PetscMatNest(PetscMat): @@ -190,3 +186,46 @@ class PetscMatDense(PetscMat): class PetscMatPython(PetscMat): ... + + +# TODO cache this function and return a copy if possible +def _alloc_mat(raxes, caxes, sparsity, bsize=None): + comm = single_valued([raxes.comm, caxes.comm]) + + sizes = (raxes.leaf_component.count, caxes.leaf_component.count) + nnz = sparsity.axes.leaf_component.count + + if bsize is None: + mat = PETSc.Mat().createAIJ(sizes, nnz=nnz.data, comm=comm) + else: + mat = PETSc.Mat().createBAIJ(sizes, bsize, nnz=nnz.data, comm=comm) + + # fill with zeros (this should be cached) + # this could be done as a pyop3 loop (if we get ragged local working) or + # explicitly in cython + raxis, rcpt = raxes.leaf + caxis, ccpt = caxes.leaf + # e.g. + # map_ = Map({pmap({raxis.label: rcpt.label}): [TabulatedMapComponent(caxes.label, ccpt.label, sparsity)]}) + # do_loop(p := raxes.index(), write(zeros, mat[p, map_(p)])) + + # but for now do in Python... + assert nnz.max_value is not None + if bsize is None: + shape = (nnz.max_value,) + set_values = mat.setValuesLocal + else: + rbsize, _ = bsize + shape = (nnz.max_value, rbsize) + set_values = mat.setValuesBlockedLocal + zeros = np.zeros(shape, dtype=PetscMat.dtype) + for row_idx in range(rcpt.count): + cstart = sparsity.axes.offset([row_idx, 0]) + try: + cstop = sparsity.axes.offset([row_idx + 1, 0]) + except IndexError: + # catch the last one + cstop = len(sparsity.data_ro) + set_values([row_idx], sparsity.data_ro[cstart:cstop], zeros[: cstop - cstart]) + mat.assemble() + return mat diff --git a/pyop3/axtree/parallel.py b/pyop3/axtree/parallel.py index f4bce59a..93b086c7 100644 --- a/pyop3/axtree/parallel.py +++ b/pyop3/axtree/parallel.py @@ -49,18 +49,6 @@ def partition_ghost_points(axis, sf): return numbering -# stolen from stackoverflow -# https://stackoverflow.com/questions/11649577/how-to-invert-a-permutation-array-in-numpy -def invert(p): - """Return an array s with which np.array_equal(arr[p][s], arr) is True. - The array_like argument p must be some permutation of 0, 1, ..., len(p)-1. - """ - p = np.asanyarray(p) # in case p is a tuple, etc. - s = np.empty_like(p) - s[p] = np.arange(p.size) - return s - - def collect_sf_graphs(axes, axis=None, path=pmap(), indices=pmap()): # NOTE: This function does not check for nested SFs (which should error) axis = axis or axes.root diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index 0b9020db..1d6f57d9 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -42,6 +42,7 @@ flatten, frozen_record, has_unique_entries, + invert, is_single_valued, just_one, merge_dicts, @@ -330,6 +331,10 @@ def from_serial(cls, serial: Axis, sf): numbering = partition_ghost_points(serial, sf) return cls(serial.components, serial.label, numbering=numbering, sf=sf) + @property + def comm(self): + return self.sf.comm if self.sf else None + @property def size(self): return as_axis_tree(self).size @@ -411,6 +416,14 @@ def as_tree(self) -> AxisTree: """ return self._tree + def component_numbering(self, component): + cidx = self.component_index(component) + return self._default_to_applied_numbering[cidx] + + def component_permutation(self, component): + cidx = self.component_index(component) + return self._default_to_applied_permutation[cidx] + def default_to_applied_component_number(self, component, number): cidx = self.component_index(component) return self._default_to_applied_numbering[cidx][number] @@ -443,7 +456,11 @@ def _default_to_applied_numbering(self): old_cpt_pt = pt - self._component_offsets[cidx] renumbering[cidx][old_cpt_pt] = next(counters[cidx]) assert all(next(counters[i]) == c.count for i, c in enumerate(self.components)) - return renumbering + return tuple(renumbering) + + @cached_property + def _default_to_applied_permutation(self): + return tuple(invert(num) for num in self._default_to_applied_numbering) @cached_property def _applied_to_default_numbering(self): @@ -698,6 +715,14 @@ def layouts(self): def sf(self): return self._default_sf() + @property + def comm(self): + paraxes = [axis for axis in self.nodes if axis.sf is not None] + if not paraxes: + return None + else: + return single_valued(ax.comm for ax in paraxes) + @cached_property def datamap(self): if self.is_empty: diff --git a/pyop3/sf.py b/pyop3/sf.py index 292a9bb6..96e4b232 100644 --- a/pyop3/sf.py +++ b/pyop3/sf.py @@ -25,6 +25,10 @@ def from_graph(cls, size: int, nroots: int, ilocal, iremote, comm=None): sf.setGraph(nroots, ilocal, iremote) return cls(sf, size) + @property + def comm(self): + return self.sf.comm + @cached_property def iroot(self): """Return the indices of roots on the current process.""" diff --git a/pyop3/utils.py b/pyop3/utils.py index ac07329b..38cd0f90 100644 --- a/pyop3/utils.py +++ b/pyop3/utils.py @@ -196,6 +196,7 @@ def popwhen(predicate, iterable): def steps(sizes): + sizes = tuple(sizes) return (0,) + tuple(np.cumsum(sizes, dtype=int)) @@ -203,6 +204,18 @@ def pairwise(iterable): return zip(iterable, iterable[1:]) +# stolen from stackoverflow +# https://stackoverflow.com/questions/11649577/how-to-invert-a-permutation-array-in-numpy +def invert(p): + """Return an array s with which np.array_equal(arr[p][s], arr) is True. + The array_like argument p must be some permutation of 0, 1, ..., len(p)-1. + """ + p = np.asanyarray(p) # in case p is a tuple, etc. + s = np.empty_like(p) + s[p] = np.arange(p.size) + return s + + def strict_cast(obj, cast): new_obj = cast(obj) if new_obj != obj: diff --git a/tests/integration/test_subsets.py b/tests/integration/test_subsets.py index c85e8cc6..fdee5ed2 100644 --- a/tests/integration/test_subsets.py +++ b/tests/integration/test_subsets.py @@ -1,30 +1,9 @@ import loopy as lp import numpy as np import pytest -from pyrsistent import pmap - -from pyop3 import ( - INC, - READ, - WRITE, - AffineSliceComponent, - Axis, - AxisComponent, - AxisTree, - Index, - IndexTree, - IntType, - Map, - MultiArray, - ScalarType, - Slice, - SliceComponent, - TabulatedMapComponent, - do_loop, - loop, -) + +import pyop3 as op3 from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET -from pyop3.utils import flatten @pytest.mark.parametrize( @@ -37,27 +16,65 @@ ) def test_loop_over_slices(scalar_copy_kernel, touched, untouched): npoints = 10 - axes = AxisTree(Axis(npoints)) - dat0 = MultiArray(axes, name="dat0", data=np.arange(npoints, dtype=ScalarType)) - dat1 = MultiArray(axes, name="dat1", dtype=dat0.dtype) + axes = op3.Axis(npoints) + dat0 = op3.HierarchicalArray( + axes, name="dat0", data=np.arange(npoints), dtype=op3.ScalarType + ) + dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype) - do_loop(p := axes[touched].index(), scalar_copy_kernel(dat0[p], dat1[p])) - assert np.allclose(dat1.data[untouched], 0) - assert np.allclose(dat1.data[touched], dat0.data[touched]) + op3.do_loop(p := axes[touched].index(), scalar_copy_kernel(dat0[p], dat1[p])) + assert np.allclose(dat1.data_ro[untouched], 0) + assert np.allclose(dat1.data_ro[touched], dat0.data_ro[touched]) @pytest.mark.parametrize("size,touched", [(6, [2, 3, 5, 0])]) def test_scalar_copy_of_subset(scalar_copy_kernel, size, touched): untouched = list(set(range(size)) - set(touched)) - subset_axes = Axis([AxisComponent(len(touched), "pt0")], "ax0") - subset = MultiArray( - subset_axes, name="subset0", data=np.asarray(touched, dtype=IntType) + subset_axes = op3.Axis({"pt0": len(touched)}, "ax0") + subset = op3.HierarchicalArray( + subset_axes, name="subset0", data=np.asarray(touched), dtype=op3.IntType ) - axes = Axis([AxisComponent(size, "pt0")], "ax0") - dat0 = MultiArray(axes, name="dat0", data=np.arange(axes.size, dtype=ScalarType)) - dat1 = MultiArray(axes, name="dat1", dtype=dat0.dtype) + axes = op3.Axis({"pt0": size}, "ax0") + dat0 = op3.HierarchicalArray( + axes, name="dat0", data=np.arange(axes.size), dtype=op3.ScalarType + ) + dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype) + + op3.do_loop(p := axes[subset].index(), scalar_copy_kernel(dat0[p], dat1[p])) + assert np.allclose(dat1.data_ro[touched], dat0.data_ro[touched]) + assert np.allclose(dat1.data_ro[untouched], 0) + + +@pytest.mark.parametrize("size,indices", [(6, [2, 3, 5, 0])]) +def test_write_to_subset(scalar_copy_kernel, size, indices): + n = len(indices) + + subset_axes = op3.Axis({"pt0": n}, "ax0") + subset = op3.HierarchicalArray( + subset_axes, name="subset0", data=np.asarray(indices), dtype=op3.IntType + ) + + axes = op3.Axis({"pt0": size}, "ax0") + dat0 = op3.HierarchicalArray( + axes, name="dat0", data=np.arange(axes.size), dtype=op3.IntType + ) + dat1 = op3.HierarchicalArray(subset_axes, name="dat1", dtype=dat0.dtype) + + kernel = op3.Function( + lp.make_kernel( + f"{{ [i]: 0 <= i < {n} }}", + "y[i] = x[i]", + [ + lp.GlobalArg("x", shape=(n,), dtype=dat0.dtype), + lp.GlobalArg("y", shape=(n,), dtype=dat0.dtype), + ], + name="copy", + target=LOOPY_TARGET, + lang_version=LOOPY_LANG_VERSION, + ), + [op3.READ, op3.WRITE], + ) - do_loop(p := axes[subset].index(), scalar_copy_kernel(dat0[p], dat1[p])) - assert np.allclose(dat1.data[touched], dat0.data[touched]) - assert np.allclose(dat1.data[untouched], 0) + op3.do_loop(op3.Axis(1).index(), kernel(dat0[subset], dat1)) + assert (dat1.data_ro == indices).all()