Parallel #16

Merged
merged 38 commits into from Nov 16, 2023
Commits
1d8d4bb
WIP Begin making parallel work, need to fix numbering
connorjward Oct 25, 2023
3f1d49b
Fix numbering
connorjward Oct 25, 2023
6731d26
Implement some parallel tests
connorjward Oct 27, 2023
5ce34da
Algorithm for storing ghost data at the end appears to work
connorjward Oct 30, 2023
483212d
Parallel tests appear to work
connorjward Oct 30, 2023
cd207c5
All tests passing
connorjward Oct 30, 2023
de92594
Make sure to install pytest-mpi
connorjward Oct 30, 2023
c75325f
fixup
connorjward Oct 30, 2023
9bf9c4e
fixup
connorjward Nov 2, 2023
20c3b49
Add iter method to loop index
connorjward Nov 2, 2023
15b65b3
Add loop index iter tests
connorjward Nov 2, 2023
271fbbe
Add simple partition iterset test
connorjward Nov 2, 2023
0e1a698
WIP start cleaning up pymbolic expr eval stuff
connorjward Nov 2, 2023
ca929b0
WIP nearly passes existing tests, only one simple fix to go
connorjward Nov 2, 2023
65bd328
fix petscmat test
connorjward Nov 3, 2023
6ccfa5c
fixup
connorjward Nov 3, 2023
91886c0
FINALLY, all tests pass
connorjward Nov 3, 2023
427ddc6
All tests passing apart from one
connorjward Nov 7, 2023
8c8c5e6
Oh... that was easy...
connorjward Nov 7, 2023
a463d49
Cleanup
connorjward Nov 7, 2023
6b7b1af
cleanup
connorjward Nov 7, 2023
14ba8ce
Slight performance boost
connorjward Nov 7, 2023
41b4b8c
Improve layout tabulation
connorjward Nov 8, 2023
ff5d55c
Cleanup parallel numbering a bit
connorjward Nov 8, 2023
6885967
Add data_X_with_ghosts properties
connorjward Nov 8, 2023
81953ec
WIP Begin adding accessor logic
connorjward Nov 8, 2023
0a190f0
Add convenience StarForest class
connorjward Nov 8, 2023
dd1fa07
cleanup
connorjward Nov 8, 2023
8ec3637
All tests pass
connorjward Nov 8, 2023
4738673
Remove meshdata subpackage
connorjward Nov 9, 2023
594032c
WIP Begin writing halo exchange logic
connorjward Nov 9, 2023
9cbe1ce
All tests passing with new DistributedArray
connorjward Nov 10, 2023
5dcdc4f
WIP, tests failing because I am collecting arguments badly, need to h…
connorjward Nov 10, 2023
183b269
Rename a bunch of files, still not happy
connorjward Nov 15, 2023
6a9856c
Some parallel stuff working
connorjward Nov 16, 2023
6e9c50d
All tests passing
connorjward Nov 16, 2023
9adb6fe
Add parametrized loop extents and test parallel meshes
connorjward Nov 16, 2023
d2202e5
Fix things and add tests
connorjward Nov 16, 2023
6 changes: 3 additions & 3 deletions .github/workflows/test.yaml
@@ -55,7 +55,7 @@ jobs:
working-directory: ${{ env.PETSC_DIR }}/src/binding/petsc4py
run: |
pip install --upgrade pip
pip install --upgrade wheel 'cython<3' numpy
pip install --upgrade wheel cython numpy
pip install --no-deps .

- name: Checkout pyop3
@@ -67,8 +67,8 @@
shell: bash
working-directory: pyop3
run: |
pip install pytest pytest-cov pytest-timeout pytest-xdist pytest-timeout
pip install .
pip install ".[test]"
pip install pytest-cov pytest-timeout pytest-xdist

- name: Run tests
shell: bash
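The workflow change above replaces an explicit list of pytest packages with the package's own test extra. A rough local equivalent, assuming pyop3 declares a `test` extra as the diff suggests:

```shell
# Build prerequisites for petsc4py (the 'cython<3' pin is dropped in this PR)
pip install --upgrade pip wheel cython numpy

# pyop3 together with its declared test dependencies, then CI-only pytest plugins
pip install ".[test]"
pip install pytest-cov pytest-timeout pytest-xdist
```

Installing via the extra keeps the test dependency list (including the newly required pytest-mpi) in one place in the package metadata rather than duplicated in the workflow file.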
10 changes: 5 additions & 5 deletions pyop3/__init__.py
@@ -7,21 +7,23 @@
del pytools


import pyop3.transforms

from pyop3.axes import Axis, AxisComponent, AxisTree # noqa: F401
from pyop3.distarray import MultiArray, PetscMat # noqa: F401
from pyop3.axtree import Axis, AxisComponent, AxisTree # noqa: F401

from pyop3.distarray import Dat, MultiArray, PetscMat # noqa: F401

from pyop3.distarray2 import DistributedArray # noqa: F401

from pyop3.dtypes import IntType, ScalarType # noqa: F401

from pyop3.indices import ( # noqa: F401
from pyop3.itree import ( # noqa: F401

AffineSliceComponent,
Index,
IndexTree,
LoopIndex,
Map,
Slice,
SliceComponent,
Subset,
TabulatedMapComponent,
)
from pyop3.lang import ( # noqa: F401

INC,
MAX_RW,
MAX_WRITE,
@@ -36,5 +38,3 @@
loop,
offset,
)
from pyop3.meshdata import Const, Dat, Mat # noqa: F401
from pyop3.space import ConstrainedAxis, Space # noqa: F401
2 changes: 1 addition & 1 deletion pyop3/axes/__init__.py → pyop3/axtree/__init__.py
@@ -1,4 +1,4 @@
from pyop3.axes.tree import (
from .tree import (

Axis,
AxisComponent,
AxisTree,
166 changes: 166 additions & 0 deletions pyop3/axtree/parallel.py
@@ -0,0 +1,166 @@
from __future__ import annotations

import numpy as np
from mpi4py import MPI
from petsc4py import PETSc

from pyrsistent import pmap

from pyop3.axtree.tree import _as_int, _axis_component_size, step_size
from pyop3.dtypes import IntType, get_mpi_dtype
from pyop3.extras.debug import print_with_rank
from pyop3.utils import checked_zip, just_one, strict_int


def partition_ghost_points(axis, sf):
npoints = sf.size
is_owned = np.full(npoints, True, dtype=bool)
is_owned[sf.ileaf] = False

numbering = np.empty(npoints, dtype=IntType)
owned_ptr = 0
ghost_ptr = npoints - sf.nleaves
points = axis.numbering if axis.numbering is not None else range(npoints)
for pt in points:
if is_owned[pt]:
numbering[owned_ptr] = pt
owned_ptr += 1
else:
numbering[ghost_ptr] = pt
ghost_ptr += 1

assert owned_ptr == npoints - sf.nleaves
assert ghost_ptr == npoints
return numbering
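The partitioning loop above can be exercised without MPI. A minimal sketch using hypothetical `SimpleNamespace` stand-ins for the axis and star forest (not the real pyop3 classes): six points, of which points 2 and 5 are ghosts (SF leaves), with no prior renumbering:

```python
import numpy as np
from types import SimpleNamespace

# Hypothetical stand-ins, mirroring only the attributes the loop reads
sf = SimpleNamespace(size=6, nleaves=2, ileaf=np.array([2, 5]))
axis = SimpleNamespace(numbering=None)

npoints = sf.size
is_owned = np.full(npoints, True, dtype=bool)
is_owned[sf.ileaf] = False

numbering = np.empty(npoints, dtype=int)
owned_ptr = 0
ghost_ptr = npoints - sf.nleaves  # ghosts start after the owned block
points = axis.numbering if axis.numbering is not None else range(npoints)
for pt in points:
    if is_owned[pt]:
        numbering[owned_ptr] = pt
        owned_ptr += 1
    else:
        numbering[ghost_ptr] = pt
        ghost_ptr += 1

print(numbering)  # owned points first, ghosts at the end: [0 1 3 4 2 5]
```

The owned points keep their relative order at the front of the numbering while the ghost points are compacted into a contiguous block at the end, which is what lets ghost data be stored after owned data in a single array.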


# stolen from stackoverflow
# https://stackoverflow.com/questions/11649577/how-to-invert-a-permutation-array-in-numpy
def invert(p):
"""Return an array s with which np.array_equal(arr[p][s], arr) is True.
The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
"""
p = np.asanyarray(p) # in case p is a tuple, etc.
s = np.empty_like(p)
s[p] = np.arange(p.size)
return s
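The inversion property is easy to check in isolation; a small self-contained sketch (duplicating the function above so it runs standalone):

```python
import numpy as np

def invert(p):
    """Return s such that arr[p][s] recovers arr, for p a permutation of 0..len(p)-1."""
    p = np.asanyarray(p)
    s = np.empty_like(p)
    s[p] = np.arange(p.size)
    return s

arr = np.array([10, 20, 30])
p = np.array([2, 0, 1])   # arr[p] == [30, 10, 20]
s = invert(p)
print(s)                  # [1 2 0]
assert np.array_equal(arr[p][s], arr)
```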


def collect_sf_graphs(axes, axis=None, path=pmap(), indices=pmap()):
# NOTE: This function does not check for nested SFs (which should error)
axis = axis or axes.root

if axis.sf is not None:
return (grow_dof_sf(axes, axis, path, indices),)
else:
graphs = []
for component in axis.components:
subaxis = axes.child(axis, component)
if subaxis is not None:
for pt in range(_as_int(component.count, path, indices)):
graphs.extend(
collect_sf_graphs(
axes,
subaxis,
path | {axis.label: component.label},
indices | {axis.label: pt},
)
)
return tuple(graphs)


# perhaps I can defer renumbering the SF to here?
def grow_dof_sf(axes: FrozenAxisTree, axis, path, indices):
point_sf = axis.sf
# TODO, use convenience methods
nroots, ilocal, iremote = point_sf._graph

component_counts = tuple(c.count for c in axis.components)
component_offsets = [0] + list(np.cumsum(component_counts))
npoints = component_offsets[-1]

# renumbering per component, can skip if no renumbering present
renumbering = [np.empty(c.count, dtype=int) for c in axis.components]
counters = [0] * len(axis.components)
for new_pt, old_pt in enumerate(axis.numbering):
for cidx, (min_, max_) in enumerate(
zip(component_offsets, component_offsets[1:])
):
if min_ <= old_pt < max_:
renumbering[cidx][old_pt - min_] = counters[cidx]
counters[cidx] += 1
break
assert all(count == c.count for count, c in checked_zip(counters, axis.components))

# effectively build the section
root_offsets = np.full(npoints, -1, IntType)
for pt in point_sf.iroot:
# convert to a component-wise numbering
selected_component = None
component_num = None
for cidx, (min_, max_) in enumerate(
zip(component_offsets, component_offsets[1:])
):
if min_ <= pt < max_:
selected_component = axis.components[cidx]
component_num = renumbering[cidx][pt - component_offsets[cidx]]
break
assert selected_component is not None
assert component_num is not None

offset = axes.offset(
path | {axis.label: selected_component.label},
indices | {axis.label: component_num},
insert_zeros=True,
)
root_offsets[pt] = offset

print_with_rank("root offsets before", root_offsets)

point_sf.broadcast(root_offsets, MPI.REPLACE)

# for sanity reasons remove the original root values from the buffer
root_offsets[point_sf.iroot] = -1

local_leaf_offsets = np.empty(point_sf.nleaves, dtype=IntType)
leaf_ndofs = local_leaf_offsets.copy()
for myindex, pt in enumerate(ilocal):
# convert to a component-wise numbering
selected_component = None
component_num = None
for cidx, (min_, max_) in enumerate(
zip(component_offsets, component_offsets[1:])
):
if min_ <= pt < max_:
selected_component = axis.components[cidx]
component_num = renumbering[cidx][pt - component_offsets[cidx]]
break
assert selected_component is not None
assert component_num is not None

offset = axes.offset(
path | {axis.label: selected_component.label},
indices | {axis.label: component_num},
insert_zeros=True,
)
local_leaf_offsets[myindex] = offset
leaf_ndofs[myindex] = step_size(axes, axis, selected_component)

# construct a new SF with these offsets
ndofs = sum(leaf_ndofs)
local_leaf_dof_offsets = np.empty(ndofs, dtype=IntType)
remote_leaf_dof_offsets = np.empty((ndofs, 2), dtype=IntType)
counter = 0
for leaf, pos in enumerate(point_sf.ilocal):
for d in range(leaf_ndofs[leaf]):
local_leaf_dof_offsets[counter] = local_leaf_offsets[leaf] + d

rank = point_sf.iremote[leaf][0]
remote_leaf_dof_offsets[counter] = [rank, root_offsets[pos] + d]
counter += 1

print_with_rank("root offsets: ", root_offsets)
print_with_rank("local leaf offsets", local_leaf_offsets)
print_with_rank("local dof offsets: ", local_leaf_dof_offsets)
print_with_rank("remote offsets: ", remote_leaf_dof_offsets)

return (nroots, local_leaf_dof_offsets, remote_leaf_dof_offsets)
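The final expansion from a point SF to a DoF SF can be illustrated with plain arrays. All values below are hypothetical (two leaf points with two DoFs each, both owned by rank 1), not taken from a real run, and the broadcast step is replaced by a precomputed dictionary of received root offsets:

```python
import numpy as np

leaf_ndofs = np.array([2, 2])            # DoFs per leaf point
local_leaf_offsets = np.array([6, 8])    # local offset of each leaf point's data
root_offsets = {3: 10, 7: 14}            # remote offset per point, as if broadcast
ilocal = np.array([3, 7])                # leaf points in the point SF
iremote_ranks = np.array([1, 1])         # owning rank of each leaf point

ndofs = leaf_ndofs.sum()
local_dofs = np.empty(ndofs, dtype=int)
remote_dofs = np.empty((ndofs, 2), dtype=int)
counter = 0
for leaf, pos in enumerate(ilocal):
    # each point expands into one SF entry per DoF
    for d in range(leaf_ndofs[leaf]):
        local_dofs[counter] = local_leaf_offsets[leaf] + d
        remote_dofs[counter] = [iremote_ranks[leaf], root_offsets[pos] + d]
        counter += 1

print(local_dofs)   # [6 7 8 9]
print(remote_dofs)  # [[ 1 10] [ 1 11] [ 1 14] [ 1 15]]
```

Each point-level edge becomes `leaf_ndofs` DoF-level edges, pairing consecutive local offsets with consecutive offsets on the owning rank, which is exactly the `(rank, offset)` graph the new SF is built from.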