Skip to content

Commit

Permalink
More cleanup, tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Dec 8, 2023
1 parent 052fba6 commit 328973e
Show file tree
Hide file tree
Showing 13 changed files with 1 addition and 943 deletions.
5 changes: 0 additions & 5 deletions pyop3/array/harray.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
)
from pyop3.buffer import Buffer, DistributedBuffer
from pyop3.dtypes import IntType, ScalarType, get_mpi_dtype
from pyop3.extras.debug import print_if_rank, print_with_rank
from pyop3.itree import IndexTree, as_index_forest, index_axes
from pyop3.itree.tree import CalledMapVariable, collect_loop_indices, iter_axis_tree
from pyop3.lang import KernelArgument
Expand Down Expand Up @@ -342,14 +341,10 @@ def offset(self, *args, allow_unused=False, insert_zeros=False):
return strict_int(offset)

def simple_offset(self, path, indices):
print_if_rank(0, "self.layouts", self.layouts)
print_if_rank(0, "path", path)
print_if_rank(0, "indices", indices)
offset = pym.evaluate(self.layouts[path], indices, ExpressionEvaluator)
return strict_int(offset)

def iter_indices(self, outer_map):
print_with_rank(0, "myiexpr!!!!!!!!!!!!!!!!!!", self.index_exprs)
return iter_axis_tree(self.axes, self.target_paths, self.index_exprs, outer_map)

def _with_axes(self, axes):
Expand Down
8 changes: 0 additions & 8 deletions pyop3/axtree/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from pyop3.axtree.layout import _as_int, _axis_component_size, step_size
from pyop3.dtypes import IntType, as_numpy_dtype, get_mpi_dtype
from pyop3.extras.debug import print_with_rank
from pyop3.utils import checked_zip, just_one, strict_int


Expand Down Expand Up @@ -131,8 +130,6 @@ def grow_dof_sf(axes, axis, path, indices):
)
root_offsets[pt] = offset

print_with_rank("root offsets before", root_offsets)

point_sf.broadcast(root_offsets, MPI.REPLACE)

# for sanity reasons remove the original root values from the buffer
Expand Down Expand Up @@ -175,9 +172,4 @@ def grow_dof_sf(axes, axis, path, indices):
remote_leaf_dof_offsets[counter] = [rank, root_offsets[pos] + d]
counter += 1

print_with_rank("root offsets: ", root_offsets)
print_with_rank("local leaf offsets", local_leaf_offsets)
print_with_rank("local dof offsets: ", local_leaf_dof_offsets)
print_with_rank("remote offsets: ", remote_leaf_dof_offsets)

return (nroots, local_leaf_dof_offsets, remote_leaf_dof_offsets)
93 changes: 1 addition & 92 deletions pyop3/axtree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from pyop3 import utils
from pyop3.dtypes import IntType, PointerType, get_mpi_dtype
from pyop3.extras.debug import print_if_rank, print_with_rank
from pyop3.sf import StarForest
from pyop3.tree import (
LabelledNodeComponent,
Expand Down Expand Up @@ -187,15 +186,9 @@ def map_called_map(self, expr):
# the inner_expr tells us the right mapping for the temporary, however,
# for maps that are arrays the innermost axis label does not always match
# the label used by the temporary. Therefore we need to do a swap here.
# I don't like this.
# print_if_rank(0, repr(array.axes))
# print_if_rank(0, "before: ",indices)
inner_axis = array.axes.leaf_axis
indices[inner_axis.label] = indices.pop(expr.function.full_map.name)

# print_if_rank(0, "after:",indices)
# print_if_rank(0, repr(expr))
# print_if_rank(0, self.context)
return array.get_value(path, indices)


Expand Down Expand Up @@ -580,55 +573,6 @@ def add_node(
parent_cpt_label = _as_axis_component_label(parent_component)
return super().add_node(axis, parent, parent_cpt_label, **kwargs)

# alias
add_subaxis = add_node

# currently untested but should keep
@classmethod
def from_layout(cls, layout: Sequence[ConstrainedMultiAxis]) -> Any: # TODO
return order_axes(layout)

# TODO this is just a regular tree search
@deprecated(internal=True) # I think?
def get_part_from_path(self, path, axis=None):
axis = axis or self.root

label, *sublabels = path

(component, component_index) = just_one(
[
(cpt, cidx)
for cidx, cpt in enumerate(axis.components)
if (axis.label, cidx) == label
]
)
if sublabels:
return self.get_part_from_path(
sublabels, self.component_child(axis, component)
)
else:
return axis, component

@deprecated(internal=True)
def drop_last(self):
"""Remove the last subaxis"""
if not self.part.subaxis:
return None
else:
return self.copy(
parts=[self.part.copy(subaxis=self.part.subaxis.drop_last())]
)

@property
@deprecated(internal=True)
def is_linear(self):
"""Return ``True`` if the multi-axis contains no branches at any level."""
if self.nparts == 1:
return self.part.subaxis.is_linear if self.part.subaxis else True
else:
return False

@deprecated()
def add_subaxis(self, subaxis, *loc):
return self.add_node(subaxis, *loc)

Expand Down Expand Up @@ -657,8 +601,6 @@ class AxisTree(PartialAxisTree, Indexed, ContextFreeLoopIterable):
"target_paths",
"index_exprs",
"layout_exprs",
"layouts",
"sf",
}

def __init__(
Expand All @@ -667,7 +609,6 @@ def __init__(
target_paths=None,
index_exprs=None,
layout_exprs=None,
sf=None,
):
if some_but_not_all(
arg is None for arg in [target_paths, index_exprs, layout_exprs]
Expand All @@ -678,7 +619,6 @@ def __init__(
self._target_paths = target_paths or self._default_target_paths()
self._index_exprs = index_exprs or self._default_index_exprs()
self.layout_exprs = layout_exprs or self._default_layout_exprs()
self.sf = sf or self._default_sf()

def __getitem__(self, indices):
from pyop3.itree.tree import as_index_forest, collect_loop_contexts, index_axes
Expand Down Expand Up @@ -762,7 +702,7 @@ def layouts(self):

@cached_property
def sf(self):
return cls._default_sf(tree)
return self._default_sf()

@cached_property
def datamap(self):
Expand All @@ -771,17 +711,6 @@ def datamap(self):
else:
dmap = postvisit(self, _collect_datamap, axes=self)

# for cleverdict in [self.layouts, self.orig_layout_fn]:
# for layout in cleverdict.values():
# for layout_expr in layout.values():
# # catch invalid layouts
# if isinstance(layout_expr, pym.primitives.NaN):
# continue
# for array in MultiArrayCollector()(layout_expr):
# dmap.update(array.datamap)

# TODO
# for cleverdict in [self.index_exprs, self.layout_exprs]:
for cleverdict in [self.index_exprs]:
for exprs in cleverdict.values():
for expr in exprs.values():
Expand Down Expand Up @@ -939,26 +868,6 @@ def datamap(self):
return merge_dicts(axes.datamap for axes in self.context_map.values())


@dataclasses.dataclass(frozen=True)
class Path:
# TODO Make a persistent dict?
from_axes: Tuple[Any] # axis part IDs I guess (or labels)
to_axess: Tuple[Any] # axis part IDs I guess (or labels)
arity: int
selector: Optional[Any] = None
"""The thing that chooses between the different possible output axes at runtime."""

@property
def degree(self):
return len(self.to_axess)

@property
def to_axes(self):
if self.degree != 1:
raise RuntimeError("Only for degree 1 paths")
return self.to_axess[0]


@functools.singledispatch
def as_axis_tree(arg: Any):
from pyop3.array import HierarchicalArray # cyclic import
Expand Down
1 change: 0 additions & 1 deletion pyop3/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from mpi4py import MPI

from pyop3.dtypes import ScalarType
from pyop3.extras.debug import print_if_rank
from pyop3.lang import KernelArgument
from pyop3.utils import UniqueNameGenerator, as_tuple, deprecated, readonly

Expand Down
45 changes: 0 additions & 45 deletions pyop3/ir/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pyop3.axtree.tree import ContextSensitiveAxisTree
from pyop3.buffer import DistributedBuffer, PackedBuffer
from pyop3.dtypes import IntType, PointerType
from pyop3.extras.debug import print_with_rank
from pyop3.itree import (
AffineSliceComponent,
CalledMap,
Expand Down Expand Up @@ -442,11 +441,6 @@ def parse_loop_properly_this_time(
# these aren't jnames!
my_index_exprs = axes.index_exprs.get((axis.id, component.label), {})

print_with_rank("myindexexprs", my_index_exprs)
print_with_rank("new_iname_rplacemap", new_iname_replace_map)
print_with_rank("jname_replace_map", jname_replace_map)
print_with_rank("outerreplac", outer_replace_map)

jname_extras = {}
for axis_label, index_expr in my_index_exprs.items():
jname_expr = JnameSubstitutor(
Expand Down Expand Up @@ -854,10 +848,6 @@ def array_expr():
array_ = array.with_context(context)
return make_array_expr(
array,
# I think...
# not calling substitute layouts from above so loop indices not
# present in the layout...
# subst_layout(axes, source_path, target_path),
array_.layouts[target_path],
target_path,
iname_replace_map | jname_replace_map,
Expand Down Expand Up @@ -919,14 +909,6 @@ def make_temp_expr(temporary, shape, path, jnames, ctx):
return pym.subscript(pym.var(temporary.name), extra_indices + (temp_offset_var,))


def subst_layout(axes, source_path, target_path):
replace_map = {}
for axis, cpt in axes.detailed_path(source_path).items():
replace_map.update(axes.layout_exprs[axis.id, cpt])

return IndexExpressionReplacer(replace_map)(axes.layouts[target_path])


class JnameSubstitutor(pym.mapper.IdentityMapper):
def __init__(self, replace_map, codegen_context):
self._labels_to_jnames = replace_map
Expand Down Expand Up @@ -1117,17 +1099,6 @@ def map_variable(self, expr):
return self._replace_map.get(expr.name, expr)


def collect_arrays(expr: pym.primitives.Expr):
collector = MultiArrayCollector()
return collector(expr)


def replace_variables(
expr: pym.primitives.Expr, replace_map: dict[str, pym.primitives.Variable]
):
return VariableReplacer(replace_map)(expr)


def _scalar_assignment(
array,
path,
Expand All @@ -1146,22 +1117,6 @@ def _scalar_assignment(
return rexpr


def find_axis(axes, path, target, current_axis=None):
"""Return the axis matching ``target`` along ``path``.
``path`` is a mapping between axis labels and the selected component indices.
"""
current_axis = current_axis or axes.root

if current_axis.label == target:
return current_axis
else:
subaxis = axes.child(current_axis, path[current_axis.label])
if not subaxis:
assert False, "oops"
return find_axis(axes, path, target, subaxis)


def context_from_indices(loop_indices):
loop_context = {}
for loop_index, (path, _) in loop_indices.items():
Expand Down
Loading

0 comments on commit 328973e

Please sign in to comment.