Add subset test
connorjward committed Dec 12, 2023
1 parent 39ea9d3 commit 1834ee1
Showing 6 changed files with 176 additions and 90 deletions.
117 changes: 78 additions & 39 deletions pyop3/array/petsc.py
@@ -66,7 +66,7 @@ class MatType(enum.Enum):
 DEFAULT_MAT_TYPE = MatType.AIJ
 
 
-class PetscMat(PetscObject):
+class PetscMat(PetscObject, abc.ABC):
     prefix = "mat"
 
     def __new__(cls, *args, **kwargs):
@@ -78,6 +78,13 @@ def __new__(cls, *args, **kwargs):
         else:
             raise AssertionError
 
+    # like Dat, bad name? handle?
+    @property
+    def array(self):
+        return self.petscmat
+
+
+class MonolithicPetscMat(PetscMat, abc.ABC):
     def __getitem__(self, indices):
         if len(indices) != 2:
             raise ValueError
@@ -118,8 +125,8 @@ class ContextSensitiveIndexedPetscMat(ContextSensitive):
     pass
 
 
-class PetscMatAIJ(PetscMat):
-    def __init__(self, raxes, caxes, sparsity, *, comm=None, name: str = None):
+class PetscMatAIJ(MonolithicPetscMat):
+    def __init__(self, raxes, caxes, sparsity, *, name: str = None):
         raxes = as_axis_tree(raxes)
         caxes = as_axis_tree(caxes)
 
@@ -132,52 +139,41 @@ def __init__(self, raxes, caxes, sparsity, *, comm=None, name: str = None):
             # TODO, good exceptions
             raise RuntimeError
 
-        sizes = (raxes.leaf_component.count, caxes.leaf_component.count)
-        nnz = sparsity.axes.leaf_component.count
-        mat = PETSc.Mat().createAIJ(sizes, nnz=nnz.data, comm=comm)
-
-        # fill with zeros (this should be cached)
-        # this could be done as a pyop3 loop (if we get ragged local working) or
-        # explicitly in cython
-        raxis, rcpt = raxes.leaf
-        caxis, ccpt = caxes.leaf
-        # e.g.
-        # map_ = Map({pmap({raxis.label: rcpt.label}): [TabulatedMapComponent(caxes.label, ccpt.label, sparsity)]})
-        # do_loop(p := raxes.index(), write(zeros, mat[p, map_(p)]))
-
-        # but for now do in Python...
-        assert nnz.max_value is not None
-        zeros = np.zeros(nnz.max_value, dtype=self.dtype)
-        for row_idx in range(rcpt.count):
-            cstart = sparsity.axes.offset([row_idx, 0])
-            try:
-                cstop = sparsity.axes.offset([row_idx + 1, 0])
-            except IndexError:
-                # catch the last one
-                cstop = len(sparsity.data_ro)
-            # truncate zeros
-            mat.setValuesLocal(
-                [row_idx], sparsity.data_ro[cstart:cstop], zeros[: cstop - cstart]
-            )
-        mat.assemble()
+        self.petscmat = _alloc_mat(raxes, caxes, sparsity)
 
         self.raxis = raxes.root
         self.caxis = caxes.root
         self.sparsity = sparsity
 
         self.axes = AxisTree.from_nest({self.raxis: self.caxis})
 
-        # copy only needed if we reuse the zero matrix
-        self.petscmat = mat.copy()
-
-    # like Dat, bad name? handle?
-    @property
-    def array(self):
-        return self.petscmat
 
 
-class PetscMatBAIJ(PetscMat):
-    ...
+class PetscMatBAIJ(MonolithicPetscMat):
+    def __init__(self, raxes, caxes, sparsity, bsize, *, name: str = None):
+        raxes = as_axis_tree(raxes)
+        caxes = as_axis_tree(caxes)
+
+        if isinstance(bsize, numbers.Integral):
+            bsize = (bsize, bsize)
+
+        super().__init__(name)
+        if any(axes.depth > 1 for axes in [raxes, caxes]):
+            # TODO, good exceptions
+            # raise InvalidDimensionException("Cannot instantiate PetscMats with nested axis trees")
+            raise RuntimeError
+        if any(len(axes.root.components) > 1 for axes in [raxes, caxes]):
+            # TODO, good exceptions
+            raise RuntimeError
+
+        self.petscmat = _alloc_mat(raxes, caxes, sparsity, bsize)
+
+        self.raxis = raxes.root
+        self.caxis = caxes.root
+        self.sparsity = sparsity
+        self.bsize = bsize
+
+        # TODO include bsize here?
+        self.axes = AxisTree.from_nest({self.raxis: self.caxis})
 
 
 class PetscMatNest(PetscMat):
@@ -190,3 +186,46 @@ class PetscMatDense(PetscMat):
 
 
 class PetscMatPython(PetscMat):
     ...
+
+
+# TODO cache this function and return a copy if possible
+def _alloc_mat(raxes, caxes, sparsity, bsize=None):
+    comm = single_valued([raxes.comm, caxes.comm])
+
+    sizes = (raxes.leaf_component.count, caxes.leaf_component.count)
+    nnz = sparsity.axes.leaf_component.count
+
+    if bsize is None:
+        mat = PETSc.Mat().createAIJ(sizes, nnz=nnz.data, comm=comm)
+    else:
+        mat = PETSc.Mat().createBAIJ(sizes, bsize, nnz=nnz.data, comm=comm)
+
+    # fill with zeros (this should be cached)
+    # this could be done as a pyop3 loop (if we get ragged local working) or
+    # explicitly in cython
+    raxis, rcpt = raxes.leaf
+    caxis, ccpt = caxes.leaf
+    # e.g.
+    # map_ = Map({pmap({raxis.label: rcpt.label}): [TabulatedMapComponent(caxes.label, ccpt.label, sparsity)]})
+    # do_loop(p := raxes.index(), write(zeros, mat[p, map_(p)]))
+
+    # but for now do in Python...
+    assert nnz.max_value is not None
+    if bsize is None:
+        shape = (nnz.max_value,)
+        set_values = mat.setValuesLocal
+    else:
+        rbsize, _ = bsize
+        shape = (nnz.max_value, rbsize)
+        set_values = mat.setValuesBlockedLocal
+    zeros = np.zeros(shape, dtype=PetscMat.dtype)
+    for row_idx in range(rcpt.count):
+        cstart = sparsity.axes.offset([row_idx, 0])
+        try:
+            cstop = sparsity.axes.offset([row_idx + 1, 0])
+        except IndexError:
+            # catch the last one
+            cstop = len(sparsity.data_ro)
+        set_values([row_idx], sparsity.data_ro[cstart:cstop], zeros[: cstop - cstart])
+    mat.assemble()
+    return mat
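Note: the allocate-then-zero pattern `_alloc_mat` follows can be reproduced standalone with petsc4py. A minimal sketch, assuming a toy 4x4 sparsity with two nonzeros per row (pyop3 derives the real sizes and nnz arrays from the axis trees and uses the local/blocked `setValues*` variants instead):

import numpy as np
from petsc4py import PETSc

# toy sparsity: 4x4 matrix, two preallocated nonzeros per row
nnz = np.full(4, 2, dtype=PETSc.IntType)
mat = PETSc.Mat().createAIJ((4, 4), nnz=nnz, comm=PETSc.COMM_SELF)

for row in range(4):
    cols = [row, (row + 1) % 4]  # column indices of this row's nonzeros
    mat.setValues([row], cols, np.zeros(len(cols)))  # insert explicit zeros
mat.assemble()  # the inserted zeros now define the nonzero structure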
12 changes: 0 additions & 12 deletions pyop3/axtree/parallel.py
@@ -49,18 +49,6 @@ def partition_ghost_points(axis, sf):
     return numbering
 
 
-# stolen from stackoverflow
-# https://stackoverflow.com/questions/11649577/how-to-invert-a-permutation-array-in-numpy
-def invert(p):
-    """Return an array s with which np.array_equal(arr[p][s], arr) is True.
-    The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
-    """
-    p = np.asanyarray(p)  # in case p is a tuple, etc.
-    s = np.empty_like(p)
-    s[p] = np.arange(p.size)
-    return s
-
-
 def collect_sf_graphs(axes, axis=None, path=pmap(), indices=pmap()):
     # NOTE: This function does not check for nested SFs (which should error)
     axis = axis or axes.root
27 changes: 26 additions & 1 deletion pyop3/axtree/tree.py
@@ -42,6 +42,7 @@
     flatten,
     frozen_record,
     has_unique_entries,
+    invert,
     is_single_valued,
     just_one,
     merge_dicts,
@@ -330,6 +331,10 @@ def from_serial(cls, serial: Axis, sf):
         numbering = partition_ghost_points(serial, sf)
         return cls(serial.components, serial.label, numbering=numbering, sf=sf)
 
+    @property
+    def comm(self):
+        return self.sf.comm if self.sf else None
+
     @property
     def size(self):
         return as_axis_tree(self).size
@@ -411,6 +416,14 @@ def as_tree(self) -> AxisTree:
         """
         return self._tree
 
+    def component_numbering(self, component):
+        cidx = self.component_index(component)
+        return self._default_to_applied_numbering[cidx]
+
+    def component_permutation(self, component):
+        cidx = self.component_index(component)
+        return self._default_to_applied_permutation[cidx]
+
     def default_to_applied_component_number(self, component, number):
         cidx = self.component_index(component)
         return self._default_to_applied_numbering[cidx][number]
@@ -443,7 +456,11 @@ def _default_to_applied_numbering(self):
                 old_cpt_pt = pt - self._component_offsets[cidx]
                 renumbering[cidx][old_cpt_pt] = next(counters[cidx])
         assert all(next(counters[i]) == c.count for i, c in enumerate(self.components))
-        return renumbering
+        return tuple(renumbering)
+
+    @cached_property
+    def _default_to_applied_permutation(self):
+        return tuple(invert(num) for num in self._default_to_applied_numbering)
 
     @cached_property
     def _applied_to_default_numbering(self):
@@ -698,6 +715,14 @@ def layouts(self):
     def sf(self):
         return self._default_sf()
 
+    @property
+    def comm(self):
+        paraxes = [axis for axis in self.nodes if axis.sf is not None]
+        if not paraxes:
+            return None
+        else:
+            return single_valued(ax.comm for ax in paraxes)
+
     @cached_property
     def datamap(self):
         if self.is_empty:
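Note: `component_numbering` and `component_permutation` expose mutually inverse arrays. A sketch with a hypothetical three-point component, using `invert` from pyop3.utils:

import numpy as np
from pyop3.utils import invert

numbering = np.array([1, 2, 0])  # default point i is renumbered to numbering[i]
permutation = invert(numbering)  # [2, 0, 1]: applied slot j holds default point permutation[j]
assert np.array_equal(numbering[permutation], np.arange(3))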
4 changes: 4 additions & 0 deletions pyop3/sf.py
@@ -25,6 +25,10 @@ def from_graph(cls, size: int, nroots: int, ilocal, iremote, comm=None):
         sf.setGraph(nroots, ilocal, iremote)
         return cls(sf, size)
 
+    @property
+    def comm(self):
+        return self.sf.comm
+
     @cached_property
     def iroot(self):
         """Return the indices of roots on the current process."""
13 changes: 13 additions & 0 deletions pyop3/utils.py
@@ -196,13 +196,26 @@ def popwhen(predicate, iterable):
 
 
 def steps(sizes):
+    sizes = tuple(sizes)
     return (0,) + tuple(np.cumsum(sizes, dtype=int))
 
 
 def pairwise(iterable):
     return zip(iterable, iterable[1:])
 
 
+# stolen from stackoverflow
+# https://stackoverflow.com/questions/11649577/how-to-invert-a-permutation-array-in-numpy
+def invert(p):
+    """Return an array s with which np.array_equal(arr[p][s], arr) is True.
+    The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
+    """
+    p = np.asanyarray(p)  # in case p is a tuple, etc.
+    s = np.empty_like(p)
+    s[p] = np.arange(p.size)
+    return s
+
+
 def strict_cast(obj, cast):
     new_obj = cast(obj)
     if new_obj != obj:
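Note: a small usage sketch of the three helpers touched in this file (example values chosen arbitrarily, imports assuming the post-commit layout):

import numpy as np
from pyop3.utils import invert, pairwise, steps

assert steps([3, 1, 2]) == (0, 3, 4, 6)                          # cumulative start offsets
assert list(pairwise((0, 3, 4, 6))) == [(0, 3), (3, 4), (4, 6)]  # consecutive (start, stop) pairs

p = np.array([2, 0, 1])
s = invert(p)                        # s == [1, 2, 0]
arr = np.array(["a", "b", "c"])
assert np.array_equal(arr[p][s], arr)  # arr[p] scrambles, arr[p][s] restores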
93 changes: 55 additions & 38 deletions tests/integration/test_subsets.py
@@ -1,30 +1,9 @@
 import loopy as lp
 import numpy as np
 import pytest
-from pyrsistent import pmap
 
-from pyop3 import (
-    INC,
-    READ,
-    WRITE,
-    AffineSliceComponent,
-    Axis,
-    AxisComponent,
-    AxisTree,
-    Index,
-    IndexTree,
-    IntType,
-    Map,
-    MultiArray,
-    ScalarType,
-    Slice,
-    SliceComponent,
-    TabulatedMapComponent,
-    do_loop,
-    loop,
-)
-
+import pyop3 as op3
 from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET
 from pyop3.utils import flatten


@pytest.mark.parametrize(
@@ -37,27 +16,65 @@
 )
 def test_loop_over_slices(scalar_copy_kernel, touched, untouched):
     npoints = 10
-    axes = AxisTree(Axis(npoints))
-    dat0 = MultiArray(axes, name="dat0", data=np.arange(npoints, dtype=ScalarType))
-    dat1 = MultiArray(axes, name="dat1", dtype=dat0.dtype)
+    axes = op3.Axis(npoints)
+    dat0 = op3.HierarchicalArray(
+        axes, name="dat0", data=np.arange(npoints), dtype=op3.ScalarType
+    )
+    dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype)
 
-    do_loop(p := axes[touched].index(), scalar_copy_kernel(dat0[p], dat1[p]))
-    assert np.allclose(dat1.data[untouched], 0)
-    assert np.allclose(dat1.data[touched], dat0.data[touched])
+    op3.do_loop(p := axes[touched].index(), scalar_copy_kernel(dat0[p], dat1[p]))
+    assert np.allclose(dat1.data_ro[untouched], 0)
+    assert np.allclose(dat1.data_ro[touched], dat0.data_ro[touched])
 
 
 @pytest.mark.parametrize("size,touched", [(6, [2, 3, 5, 0])])
 def test_scalar_copy_of_subset(scalar_copy_kernel, size, touched):
     untouched = list(set(range(size)) - set(touched))
-    subset_axes = Axis([AxisComponent(len(touched), "pt0")], "ax0")
-    subset = MultiArray(
-        subset_axes, name="subset0", data=np.asarray(touched, dtype=IntType)
+    subset_axes = op3.Axis({"pt0": len(touched)}, "ax0")
+    subset = op3.HierarchicalArray(
+        subset_axes, name="subset0", data=np.asarray(touched), dtype=op3.IntType
     )
 
-    axes = Axis([AxisComponent(size, "pt0")], "ax0")
-    dat0 = MultiArray(axes, name="dat0", data=np.arange(axes.size, dtype=ScalarType))
-    dat1 = MultiArray(axes, name="dat1", dtype=dat0.dtype)
+    axes = op3.Axis({"pt0": size}, "ax0")
+    dat0 = op3.HierarchicalArray(
+        axes, name="dat0", data=np.arange(axes.size), dtype=op3.ScalarType
+    )
+    dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype)
 
-    do_loop(p := axes[subset].index(), scalar_copy_kernel(dat0[p], dat1[p]))
-    assert np.allclose(dat1.data[touched], dat0.data[touched])
-    assert np.allclose(dat1.data[untouched], 0)
+    op3.do_loop(p := axes[subset].index(), scalar_copy_kernel(dat0[p], dat1[p]))
+    assert np.allclose(dat1.data_ro[touched], dat0.data_ro[touched])
+    assert np.allclose(dat1.data_ro[untouched], 0)
+
+
+@pytest.mark.parametrize("size,indices", [(6, [2, 3, 5, 0])])
+def test_write_to_subset(scalar_copy_kernel, size, indices):
+    n = len(indices)
+
+    subset_axes = op3.Axis({"pt0": n}, "ax0")
+    subset = op3.HierarchicalArray(
+        subset_axes, name="subset0", data=np.asarray(indices), dtype=op3.IntType
+    )
+
+    axes = op3.Axis({"pt0": size}, "ax0")
+    dat0 = op3.HierarchicalArray(
+        axes, name="dat0", data=np.arange(axes.size), dtype=op3.IntType
+    )
+    dat1 = op3.HierarchicalArray(subset_axes, name="dat1", dtype=dat0.dtype)
+
+    kernel = op3.Function(
+        lp.make_kernel(
+            f"{{ [i]: 0 <= i < {n} }}",
+            "y[i] = x[i]",
+            [
+                lp.GlobalArg("x", shape=(n,), dtype=dat0.dtype),
+                lp.GlobalArg("y", shape=(n,), dtype=dat0.dtype),
+            ],
+            name="copy",
+            target=LOOPY_TARGET,
+            lang_version=LOOPY_LANG_VERSION,
+        ),
+        [op3.READ, op3.WRITE],
+    )
+
+    op3.do_loop(op3.Axis(1).index(), kernel(dat0[subset], dat1))
+    assert (dat1.data_ro == indices).all()
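Note: the expected result of `test_write_to_subset` is a plain gather; a NumPy-only sketch with the parametrized values:

import numpy as np

indices = [2, 3, 5, 0]
dat0 = np.arange(6)     # matches dat0.data in the test
dat1 = dat0[indices]    # what kernel(dat0[subset], dat1) computes
assert (dat1 == np.asarray(indices)).all()  # dat0[i] == i, so the gather returns the indices themselves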
