Skip to content

Commit

Permalink
Fix older pythons (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward authored Sep 28, 2023
1 parent 9d7368d commit 8526bed
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 197 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,8 @@ jobs:
shell: bash
working-directory: pyop3
run: |
# Gross:
pip install toml
python scripts/requirements.py build | pip install -r /dev/stdin
pip install --no-build-isolation .
python scripts/requirements.py run | pip install -r /dev/stdin
pip install pytest pytest-cov pytest-timeout pytest-xdist pytest-timeout
pip install .
- name: Run tests
shell: bash
Expand All @@ -92,5 +88,3 @@ jobs:
-n 12 --dist worksteal \
-v tests
timeout-minutes: 10


96 changes: 50 additions & 46 deletions pyop3/axes/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def with_context(self, context):
key = {}
for loop_index, path in context.items():
if loop_index in self.keys:
key |= {loop_index: path}
key.update({loop_index: path})
key = pmap(key)
return self.context_map[key]

Expand Down Expand Up @@ -474,9 +474,10 @@ def _collect_datamap(axis, *subdatamaps, axes):
datamap = {}
for cidx, component in enumerate(axis.components):
if isinstance(count := component.count, MultiArray):
datamap |= count.datamap
datamap.update(count.datamap)

return datamap | merge_dicts(subdatamaps)
datamap.update(merge_dicts(subdatamaps))
return datamap


class AxisComponent(LabelledImmutableRecord):
Expand Down Expand Up @@ -516,7 +517,7 @@ class AxisComponent(LabelledImmutableRecord):
def __init__(
self,
count,
label: Hashable | None = None,
label: Optional[Hashable] = None,
*,
indices=None,
overlap=None,
Expand Down Expand Up @@ -623,10 +624,10 @@ class Axis(StrictLabelledNode, LoopIterable):

def __init__(
self,
components: Sequence[AxisComponent] | AxisComponent | int,
label: Hashable | None = None,
components: Union[Sequence[AxisComponent], AxisComponent, int],
label: Optional[Hashable] = None,
*,
permutation: Sequence[int] | None = None,
permutation: Optional[Sequence[int]] = None,
**kwargs,
):
components = tuple(_as_axis_component(cpt) for cpt in as_tuple(components))
Expand Down Expand Up @@ -745,8 +746,8 @@ class AxisTree(StrictLabelledTree, LoopIterable, ContextFree):
# fields = StrictLabelledTree.fields | {"target_paths", "index_exprs", "layout_exprs", "orig_axes", "sf", "shared_sf", "comm"}
def __init__(
self,
root: MultiAxis | None = None,
parent_to_children: dict | None = None,
root: Optional[MultiAxis] = None,
parent_to_children: Optional[Dict] = None,
*,
target_paths=None,
index_exprs=None,
Expand Down Expand Up @@ -914,11 +915,9 @@ def parse_bits(
if target_axis.id in new_visited_target_axes:
continue
new_visited_target_axes |= {target_axis.id}
new_target_path_per_cpt[
axis.id, component.label
] |= self.target_path_per_component[
target_axis.id, target_cpt.label
]
new_target_path_per_cpt[axis.id, component.label].update(
self.target_path_per_component[target_axis.id, target_cpt.label]
)

# do a replacement
orig_index_exprs = self.index_exprs_per_component[
Expand Down Expand Up @@ -953,9 +952,9 @@ def parse_bits(
partial_layout_exprs=new_partial_layout_exprs,
visited_target_axes=new_visited_target_axes,
)
new_target_path_per_cpt |= retval[0]
new_index_exprs_per_cpt |= retval[1]
new_layout_exprs_per_cpt |= retval[2]
new_target_path_per_cpt.update(retval[0])
new_index_exprs_per_cpt.update(retval[1])
new_layout_exprs_per_cpt.update(retval[2])

else:
pass
Expand Down Expand Up @@ -1061,17 +1060,16 @@ def datamap(self) -> dict[str:DistributedArray]:
for cleverdict in [self.layouts, self.orig_layout_fn]:
for layout in cleverdict.values():
for array in MultiArrayCollector()(layout):
dmap |= array.datamap
dmap.update(array.datamap)

# TODO
# for cleverdict in [self.index_exprs, self.layout_exprs]:
for cleverdict in [self.index_exprs_per_component]:
for exprs in cleverdict.values():
for expr in exprs.values():
for array in MultiArrayCollector()(expr):
dmap |= array.datamap
# breakpoint()
return dmap
dmap.update(array.datamap)
return pmap(dmap)

def _make_target_paths(self):
return tuple(self.path(ax, cpt) for ax, cpt in self.leaves)
Expand Down Expand Up @@ -1191,8 +1189,8 @@ def leaf_component(self):
return self.leaf[1]

def child(
self, parent: Axis, component: AxisComponent | ComponentLabel
) -> Axis | None:
self, parent: Axis, component: Union[AxisComponent, ComponentLabel]
) -> Optional[Axis]:
cpt_label = _as_axis_component_label(component)
return super().child(parent, cpt_label)

Expand Down Expand Up @@ -1540,7 +1538,7 @@ def _compute_layouts(
)
sublayoutss.append(sublayouts)
csubtrees.append(csubtree)
steps |= substeps
steps.update(substeps)
else:
csubtrees.append(None)
sublayoutss.append(collections.defaultdict(list))
Expand Down Expand Up @@ -1584,26 +1582,29 @@ def _compute_layouts(
ctree = None
for c in axis.components:
step = step_size(axes, axis, c)
layouts |= {
path
# | {axis.label: c.label}: AffineLayout(axis.label, c.label, step)
| {axis.label: c.label}: AxisVariable(axis.label) * step
}
layouts.update(
{
path
# | {axis.label: c.label}: AffineLayout(axis.label, c.label, step)
| {axis.label: c.label}: AxisVariable(axis.label) * step
}
)

else:
croot = CustomNode(
[(cpt.count, axis.label, cpt.label) for cpt in axis.components]
)
if strictly_all(sub is not None for sub in csubtrees):
cparent_to_children = {
croot.id: [sub.root for sub in csubtrees]
} | merge_dicts(sub.parent_to_children for sub in csubtrees)
cparent_to_children = pmap(
{croot.id: [sub.root for sub in csubtrees]}
) | merge_dicts(sub.parent_to_children for sub in csubtrees)
else:
cparent_to_children = {}
ctree = StrictLabelledTree(croot, cparent_to_children)

# layouts and steps are just propagated from below
return layouts | merge_dicts(sublayoutss), ctree, steps
layouts.update(merge_dicts(sublayoutss))
return layouts, ctree, steps

# 2. add layouts here
else:
Expand All @@ -1623,9 +1624,9 @@ def _compute_layouts(
bits.append((cpt.count, axlabel, clabel))
croot = CustomNode(bits)
if strictly_all(sub is not None for sub in csubtrees):
cparent_to_children = {
croot.id: [sub.root for sub in csubtrees]
} | merge_dicts(sub.parent_to_children for sub in csubtrees)
cparent_to_children = pmap(
{croot.id: [sub.root for sub in csubtrees]}
) | merge_dicts(sub.parent_to_children for sub in csubtrees)
else:
cparent_to_children = {}
ctree = StrictLabelledTree(croot, cparent_to_children)
Expand All @@ -1641,7 +1642,8 @@ def _compute_layouts(
ctree = None
steps = {path: _axis_size(axes, axis)}

return layouts | merge_dicts(sublayoutss), ctree, steps
layouts.update(merge_dicts(sublayoutss))
return layouts, ctree, steps

# must therefore be affine
else:
Expand All @@ -1661,7 +1663,7 @@ def _compute_layouts(
sublayouts[path | {axis.label: mycomponent.label}] = new_layout
start += _axis_component_size(axes, axis, mycomponent)

layouts |= sublayouts
layouts.update(sublayouts)
steps = {path: _axis_size(axes, axis)}
return layouts, None, steps

Expand Down Expand Up @@ -1698,11 +1700,13 @@ def _create_count_array_tree(
)
arrays[new_path] = countarray
else:
arrays |= _create_count_array_tree(
ctree,
child,
counts | current_node.counts[cidx],
new_path,
arrays.update(
_create_count_array_tree(
ctree,
child,
counts | current_node.counts[cidx],
new_path,
)
)

return arrays
Expand Down Expand Up @@ -1790,7 +1794,7 @@ def _tabulate_count_array_tree(
def _collect_at_leaves(
axes,
values,
axis: Axis | None = None,
axis: Optional[Axis] = None,
path=pmap(),
prior=0,
):
Expand All @@ -1804,7 +1808,7 @@ def _collect_at_leaves(
else:
prior_ = prior
if subaxis := axes.child(axis, cpt):
acc |= _collect_at_leaves(axes, values, subaxis, new_path, prior_)
acc.update(_collect_at_leaves(axes, values, subaxis, new_path, prior_))
else:
acc[new_path] = prior_

Expand Down Expand Up @@ -1877,7 +1881,7 @@ def _(arg: numbers.Real, path: Mapping, indices: Mapping):

def _path_and_indices_from_index_tuple(
axes, index_tuple
) -> tuple[pmap[Label, Label], pmap[Label, int]]:
) -> Tuple[pmap[Label, Label], pmap[Label, int]]:
path = pmap()
indices = pmap()
axis = axes.root
Expand Down
2 changes: 1 addition & 1 deletion pyop3/codegen/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def map_called_map(self, expr):
map_array,
pmap({rootaxis.label: just_one(rootaxis.components).label})
| pmap({inner_axis.label: inner_cpt.label}),
{rootaxis.label: inner_expr[0]} | {inner_axis.label: inner_expr[1]},
{rootaxis.label: inner_expr[0], inner_axis.label: inner_expr[1]},
self._codegen_context,
)
return jname_expr
Expand Down
11 changes: 5 additions & 6 deletions pyop3/distarray/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,12 @@ def data_wo(self):

@functools.cached_property
def datamap(self) -> dict[str:DistributedArray]:
# FIXME when we use proper index trees
# return {self.name: self} | self.axes.datamap | merge_dicts([idxs.datamap for idxs in self.indicess])
return (
{self.name: self}
| self.axes.datamap
| merge_dicts([idx.datamap for idxs in self.indicess for idx in idxs])
datamap = {self.name: self}
datamap.update(self.axes.datamap)
datamap.update(
merge_dicts([idx.datamap for idxs in self.indicess for idx in idxs])
)
return datamap

@property
def alloc_size(self):
Expand Down
8 changes: 6 additions & 2 deletions pyop3/extras/debug.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Optional, Union

from mpi4py import MPI
from petsc4py import PETSc


def print_with_rank(*args, comm: PETSc.Comm | MPI.Comm | None = None) -> None:
def print_with_rank(*args, comm: Optional[Union[PETSc.Comm, MPI.Comm]] = None) -> None:
comm = comm or PETSc.Sys.getDefaultComm()
print(f"[rank {comm.rank}] : ", *args, sep="", flush=True)


def print_if_rank(rank: int, *args, comm: PETSc.Comm | MPI.Comm | None = None) -> None:
def print_if_rank(
rank: int, *args, comm: Optional[Union[PETSc.Comm, MPI.Comm]] = None
) -> None:
comm = comm or PETSc.Sys.getDefaultComm()
if rank == comm.rank:
print(*args, flush=True)
Loading

0 comments on commit 8526bed

Please sign in to comment.