Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused code #19

Merged
merged 2 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 1 addition & 208 deletions pyop3/array/harray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import collections
import functools
import itertools

Check warning on line 5 in pyop3/array/harray.py

View workflow job for this annotation

GitHub Actions / lint

F401 'itertools' imported but unused
import numbers

Check warning on line 6 in pyop3/array/harray.py

View workflow job for this annotation

GitHub Actions / lint

F401 'numbers' imported but unused
import operator

Check warning on line 7 in pyop3/array/harray.py

View workflow job for this annotation

GitHub Actions / lint

F401 'operator' imported but unused
import sys
import threading

Check warning on line 9 in pyop3/array/harray.py

View workflow job for this annotation

GitHub Actions / lint

F401 'threading' imported but unused
from functools import cached_property
from typing import Any, Optional, Sequence, Tuple, Union

Check warning on line 11 in pyop3/array/harray.py

View workflow job for this annotation

GitHub Actions / lint

F401 'typing.Any' imported but unused

import numpy as np
import pymbolic as pym
Expand Down Expand Up @@ -36,7 +36,6 @@
)
from pyop3.buffer import Buffer, DistributedBuffer
from pyop3.dtypes import IntType, ScalarType, get_mpi_dtype
from pyop3.extras.debug import print_if_rank, print_with_rank
from pyop3.itree import IndexTree, as_index_forest, index_axes
from pyop3.itree.tree import CalledMapVariable, collect_loop_indices, iter_axis_tree
from pyop3.lang import KernelArgument
Expand Down Expand Up @@ -106,14 +105,8 @@
):
super().__init__(name=name, prefix=prefix)

# TODO This is ugly
# temporary_axes = as_axis_tree(axes).freeze() # used for the temporary
# previously layout_axes
# drop index_exprs...
axes = as_axis_tree(axes)

# axes = as_layout_axes(axes)

if isinstance(data, Buffer):
# disable for now, temporaries hit this in an annoying way
# if data.sf is not axes.sf:
Expand All @@ -139,9 +132,7 @@
self.buffer = data

# instead implement "materialize"
# self.temporary_axes = temporary_axes
self.axes = axes
self.layout_axes = axes # used? likely don't need all these

self.max_value = max_value

Expand All @@ -167,7 +158,7 @@

loop_contexts = collect_loop_contexts(indices)
if not loop_contexts:
index_tree = just_one(as_index_forest(indices, axes=self.layout_axes))
index_tree = just_one(as_index_forest(indices, axes=self.axes))
(
indexed_axes,
target_path_per_indexed_cpt,
Expand Down Expand Up @@ -267,7 +258,7 @@
Have to be careful. If not setting all values (i.e. subsets) should call
`reduce_leaves_to_roots` first.

When this is called we set roots_valid, claiming that any (lazy) 'in-flight' writes

Check failure on line 261 in pyop3/array/harray.py

View workflow job for this annotation

GitHub Actions / lint

E501 line too long (91 > 88 characters)
can be dropped.
"""
return self.array.data_wo
Expand Down Expand Up @@ -350,14 +341,10 @@
return strict_int(offset)

def simple_offset(self, path, indices):
    """Return the integer offset of a single point in the backing buffer.

    Parameters
    ----------
    path :
        Key into ``self.layouts`` selecting the layout expression for the
        component being addressed.
    indices :
        Mapping of index names to concrete values used to evaluate the
        layout expression.

    Returns
    -------
    int
        The evaluated offset, normalised to a plain ``int``.
    """
    # Debug printing removed: the per-rank print ran on every offset
    # lookup and relied on the print_if_rank debug helper.
    offset = pym.evaluate(self.layouts[path], indices, ExpressionEvaluator)
    # Evaluation may produce a numpy integer; strict_int normalises it.
    return strict_int(offset)

def iter_indices(self, outer_map):
    """Iterate over the points of this array's axis tree.

    Delegates to ``iter_axis_tree`` with this array's axes, target paths
    and index expressions.

    Parameters
    ----------
    outer_map :
        Outer loop-index context forwarded to ``iter_axis_tree``.
    """
    # Debug printing removed: the print_with_rank call dumped
    # self.index_exprs on every iteration request.
    return iter_axis_tree(self.axes, self.target_paths, self.index_exprs, outer_map)

def _with_axes(self, axes):
Expand Down Expand Up @@ -530,197 +517,3 @@

def _shared_attr(self, attr: str):
return single_valued(getattr(a, attr) for a in self.context_map.values())


def replace_layout(orig_layout, replace_map):
    """Return *orig_layout* with its index expressions substituted per *replace_map*."""
    replacer = IndexExpressionReplacer(replace_map)
    return replacer(orig_layout)


def as_layout_axes(axes: AxisTree) -> AxisTree:
    """Return a copy of *axes* suitable for computing layouts.

    The copy keeps the tree structure, target paths, layout expressions,
    layouts and star forest of *axes*, but replaces its index expressions
    with the default ones (``_default_index_exprs``).
    """
    # drop index exprs, everything else drops out
    # NOTE(review): argument order must match the AxisTree constructor's
    # positional parameters — verify against the AxisTree definition.
    return AxisTree(
        axes.parent_to_children,
        axes.target_paths,
        axes._default_index_exprs(),
        axes.layout_exprs,
        axes.layouts,
        sf=axes.sf,
    )


def make_sparsity(
    iterindex,
    lmap,
    rmap,
    llabels=PrettyTuple(),
    rlabels=PrettyTuple(),
    lindices=PrettyTuple(),
    rindices=PrettyTuple(),
):
    """Recursively build a sparsity pattern from an iteration index and two maps.

    The function walks *iterindex* first, then *lmap*, then *rmap*,
    accumulating axis labels and concrete indices for the left (row) and
    right (column) sides.  At the bottom of the recursion it records a
    single (row, column) entry.

    Returns
    -------
    dict
        Mapping ``(llabel, rlabel) -> set of (lindex, rindex)`` pairs.

    NOTE(review): the PrettyTuple() defaults are shared across calls;
    this is only safe if PrettyTuple is immutable — confirm.
    """
    if iterindex:
        # Case 1: still consuming the iteration set.
        if iterindex.children:
            raise NotImplementedError(
                "Need to think about what to do when we have more complicated "
                "iteration sets that have multiple indices (e.g. extruded cells)"
            )

        if not isinstance(iterindex, Range):
            raise NotImplementedError(
                "Need to think about whether maps are reasonable here"
            )

        # The iteration index and both maps must start from the same root.
        if not is_single_valued(idx.id for idx in [iterindex, lmap, rmap]):
            raise ValueError("Indices must share common roots")

        sparsity = collections.defaultdict(set)
        for i in range(iterindex.size):
            # Recurse with the iteration index consumed; the same label
            # and loop value are appended to both sides.
            subsparsity = make_sparsity(
                None,
                lmap.child,
                rmap.child,
                llabels | iterindex.label,
                rlabels | iterindex.label,
                lindices | i,
                rindices | i,
            )
            # Merge the sub-sparsity into the accumulated result.
            for labels, indices in subsparsity.items():
                sparsity[labels].update(indices)
        return sparsity
    elif lmap:
        # Case 2: iteration set consumed, expand the left (row) map.
        if not isinstance(lmap, TabulatedMap):
            raise NotImplementedError("Need to think about other index types")
        if len(lmap.children) not in [0, 1]:
            raise NotImplementedError("Need to think about maps forking")

        new_labels = list(llabels)
        # first pop the old things
        for lbl in lmap.from_labels:
            if lbl != new_labels[-1]:
                raise ValueError("from_labels must match existing labels")
            new_labels.pop()
        # then append the new ones - only do the labels here, indices are
        # done inside the loop
        new_labels.extend(lmap.to_labels)
        new_labels = PrettyTuple(new_labels)

        sparsity = collections.defaultdict(set)
        for i in range(lmap.size):
            # Look up the mapped (row) index for this arity entry.
            new_indices = PrettyTuple([lmap.data.get_value(lindices | i)])
            subsparsity = make_sparsity(
                None, lmap.child, rmap, new_labels, rlabels, new_indices, rindices
            )
            for labels, indices in subsparsity.items():
                sparsity[labels].update(indices)
        return sparsity
    elif rmap:
        # Case 3: left map consumed, expand the right (column) map.
        # Mirrors the lmap branch above.
        if not isinstance(rmap, TabulatedMap):
            raise NotImplementedError("Need to think about other index types")
        if len(rmap.children) not in [0, 1]:
            raise NotImplementedError("Need to think about maps forking")

        new_labels = list(rlabels)
        # first pop the old labels
        for lbl in rmap.from_labels:
            if lbl != new_labels[-1]:
                raise ValueError("from_labels must match existing labels")
            new_labels.pop()
        # then append the new ones
        new_labels.extend(rmap.to_labels)
        new_labels = PrettyTuple(new_labels)

        sparsity = collections.defaultdict(set)
        for i in range(rmap.size):
            new_indices = PrettyTuple([rmap.data.get_value(rindices | i)])
            subsparsity = make_sparsity(
                None, lmap, rmap.child, llabels, new_labels, lindices, new_indices
            )
            for labels, indices in subsparsity.items():
                sparsity[labels].update(indices)
        return sparsity
    else:
        # at the bottom, record an entry
        # return {(llabels, rlabels): {(lindices, rindices)}}
        # TODO: For now assume single values for each of these
        llabel, rlabel = map(single_valued, [llabels, rlabels])
        lindex, rindex = map(single_valued, [lindices, rindices])
        return {(llabel, rlabel): {(lindex, rindex)}}


def distribute_sparsity(sparsity, ax1, ax2, owner="row"):
    """Redistribute a sparsity pattern so each entry lives on its owning rank.

    Entries whose row point is owned locally are kept; the rest are
    translated to global numbering (via the parts' ``lgmap``) and shipped
    to the owning rank using a PETSc star forest broadcast.

    Parameters
    ----------
    sparsity :
        Mapping ``(row_label, col_label) -> set of (row, col)`` local entries.
    ax1, ax2 :
        Row and column axes; each must have a single part.
    owner :
        Ownership rule; only ``"row"`` (row-major ownership) is implemented.

    Returns
    -------
    collections.defaultdict
        The redistributed sparsity in the same format as the input.
    """
    if any(ax.nparts > 1 for ax in [ax1, ax2]):
        raise NotImplementedError("Only dealing with single-part multi-axes for now")

    # how many points need to get sent to other processes?
    # how many points do I get from other processes?
    new_sparsity = collections.defaultdict(set)
    points_to_send = collections.defaultdict(set)
    for lindex, rindex in sparsity[ax1.part.label, ax2.part.label]:
        if owner == "row":
            olabel = ax1.part.overlap[lindex]
            if is_owned_by_process(olabel):
                # Locally owned row: keep the entry as-is.
                new_sparsity[ax1.part.label, ax2.part.label].add((lindex, rindex))
            else:
                # Remote-owned row: queue the globally-numbered entry for
                # the owning rank (olabel.root.rank).
                points_to_send[olabel.root.rank].add(
                    (ax1.part.lgmap[lindex], ax2.part.lgmap[rindex])
                )
        else:
            raise NotImplementedError

    # send points

    # first determine how many new points we are getting from each rank
    comm = single_valued([ax1.sf.comm, ax2.sf.comm]).tompi4py()
    npoints_to_send = np.array(
        [len(points_to_send[rank]) for rank in range(comm.size)], dtype=IntType
    )
    npoints_to_recv = np.empty_like(npoints_to_send)
    comm.Alltoall(npoints_to_send, npoints_to_recv)

    # communicate the offsets back
    from_offsets = np.cumsum(npoints_to_recv)
    to_offsets = np.empty_like(from_offsets)
    comm.Alltoall(from_offsets, to_offsets)

    # now send the globally numbered row, col values for each point that
    # needs to be sent. This is easiest with an SF.

    # nroots is the number of points to send
    nroots = sum(npoints_to_send)
    local_points = None  # contiguous storage

    # Build the remote (rank, offset) pairs describing where each
    # received point lives on the sending rank.
    idx = 0
    remote_points = []
    for rank in range(comm.size):
        for i in range(npoints_to_recv[rank]):
            remote_points.extend([rank, to_offsets[idx]])
            idx += 1

    sf = PETSc.SF().create(comm)
    sf.setGraph(nroots, local_points, remote_points)

    # create a buffer to hold the new values
    # x2 since we are sending row and column numbers
    new_points = np.empty(sum(npoints_to_recv) * 2, dtype=IntType)
    # Flatten the queued (row, col) pairs, rank by rank, into root data.
    rootdata = np.array(
        [
            num
            for rank in range(comm.size)
            for lnum, rnum in points_to_send[rank]
            for num in [lnum, rnum]
        ],
        dtype=new_points.dtype,
    )

    mpi_dtype, _ = get_mpi_dtype(np.dtype(IntType))
    mpi_op = MPI.REPLACE
    args = (mpi_dtype, rootdata, new_points, mpi_op)
    sf.bcastBegin(*args)
    sf.bcastEnd(*args)

    # Unpack the received flat buffer back into (row, col) entries.
    for i in range(sum(npoints_to_recv)):
        new_sparsity[ax1.part.label, ax2.part.label].add(
            (new_points[2 * i], new_points[2 * i + 1])
        )

    return new_sparsity
32 changes: 0 additions & 32 deletions pyop3/axtree/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,38 +114,6 @@ def step_size(
return 1


def make_star_forest_per_axis_part(part, comm):
    """Build a PETSc star forest describing point ownership for *part*.

    Parameters
    ----------
    part :
        An axis part carrying an ``overlap`` array of per-point ownership
        descriptors (``Shared`` entries with an optional remote ``root``).
    comm :
        Communicator used to create the star forest.

    Returns
    -------
    PETSc.SF
        Star forest whose roots are locally-owned shared points and whose
        leaves point at remote (rank, index) roots.

    Raises
    ------
    NotImplementedError
        If *part* is not distributed.
    """
    if part.is_distributed:
        # we have a root if a point is shared but doesn't point to another rank
        nroots = len(
            [pt for pt in part.overlap if isinstance(pt, Shared) and not pt.root]
        )

        # which local points are leaves?
        local_points = [
            i for i, pt in enumerate(part.overlap) if not is_owned_by_process(pt)
        ]

        # roots of other processes (rank, index)
        remote_points = utils.flatten(
            [pt.root.as_tuple() for pt in part.overlap if not is_owned_by_process(pt)]
        )

        sf = PETSc.SF().create(comm)
        sf.setGraph(nroots, local_points, remote_points)
        return sf
    else:
        raise NotImplementedError(
            "Need to think about concatenating star forests. This will happen if mixed."
        )


def attach_owned_star_forest(axis):
    """Unimplemented stub: attach an owned-points star forest to *axis*."""
    raise NotImplementedError


def has_halo(axes, axis):
if axis.sf is not None:
return True
Expand Down
27 changes: 18 additions & 9 deletions pyop3/axtree/parallel.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
from __future__ import annotations

import functools

import numpy as np
from mpi4py import MPI
from petsc4py import PETSc
from pyrsistent import pmap

from pyop3.axtree.layout import _as_int, _axis_component_size, step_size
from pyop3.dtypes import IntType, get_mpi_dtype
from pyop3.extras.debug import print_with_rank
from pyop3.dtypes import IntType, as_numpy_dtype, get_mpi_dtype
from pyop3.utils import checked_zip, just_one, strict_int


def reduction_op(op, invec, inoutvec, datatype):
    """MPI user-defined-op callback: fold *invec* into *inoutvec* with *op*.

    Both buffers are viewed as numpy arrays of the dtype corresponding to
    the MPI *datatype*, and the elementwise result of ``op`` is written
    back into *inoutvec* in place.
    """
    np_dtype = as_numpy_dtype(datatype)
    incoming = np.frombuffer(invec, dtype=np_dtype)
    accumulator = np.frombuffer(inoutvec, dtype=np_dtype)
    accumulator[:] = op(incoming, accumulator)


# Commutative user-defined MPI reduction operators over contiguous buffers,
# built from reduction_op: elementwise numpy minimum / maximum across ranks.
_contig_min_op = MPI.Op.Create(
    functools.partial(reduction_op, np.minimum), commute=True
)
_contig_max_op = MPI.Op.Create(
    functools.partial(reduction_op, np.maximum), commute=True
)


def partition_ghost_points(axis, sf):
npoints = sf.size
is_owned = np.full(npoints, True, dtype=bool)
Expand Down Expand Up @@ -114,8 +130,6 @@ def grow_dof_sf(axes, axis, path, indices):
)
root_offsets[pt] = offset

print_with_rank("root offsets before", root_offsets)

point_sf.broadcast(root_offsets, MPI.REPLACE)

# for sanity reasons remove the original root values from the buffer
Expand Down Expand Up @@ -158,9 +172,4 @@ def grow_dof_sf(axes, axis, path, indices):
remote_leaf_dof_offsets[counter] = [rank, root_offsets[pos] + d]
counter += 1

print_with_rank("root offsets: ", root_offsets)
print_with_rank("local leaf offsets", local_leaf_offsets)
print_with_rank("local dof offsets: ", local_leaf_dof_offsets)
print_with_rank("remote offsets: ", remote_leaf_dof_offsets)

return (nroots, local_leaf_dof_offsets, remote_leaf_dof_offsets)
Loading
Loading