Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix older pythons #12

Merged
merged 3 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,8 @@ jobs:
shell: bash
working-directory: pyop3
run: |
# Gross:
pip install toml
python scripts/requirements.py build | pip install -r /dev/stdin
pip install --no-build-isolation .
python scripts/requirements.py run | pip install -r /dev/stdin
pip install pytest pytest-cov pytest-timeout pytest-xdist pytest-timeout
pip install .

- name: Run tests
shell: bash
Expand All @@ -92,5 +88,3 @@ jobs:
-n 12 --dist worksteal \
-v tests
timeout-minutes: 10


96 changes: 50 additions & 46 deletions pyop3/axes/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def with_context(self, context):
key = {}
for loop_index, path in context.items():
if loop_index in self.keys:
key |= {loop_index: path}
key.update({loop_index: path})
key = pmap(key)
return self.context_map[key]

Expand Down Expand Up @@ -474,9 +474,10 @@ def _collect_datamap(axis, *subdatamaps, axes):
datamap = {}
for cidx, component in enumerate(axis.components):
if isinstance(count := component.count, MultiArray):
datamap |= count.datamap
datamap.update(count.datamap)

return datamap | merge_dicts(subdatamaps)
datamap.update(merge_dicts(subdatamaps))
return datamap


class AxisComponent(LabelledImmutableRecord):
Expand Down Expand Up @@ -516,7 +517,7 @@ class AxisComponent(LabelledImmutableRecord):
def __init__(
self,
count,
label: Hashable | None = None,
label: Optional[Hashable] = None,
*,
indices=None,
overlap=None,
Expand Down Expand Up @@ -623,10 +624,10 @@ class Axis(StrictLabelledNode, LoopIterable):

def __init__(
self,
components: Sequence[AxisComponent] | AxisComponent | int,
label: Hashable | None = None,
components: Union[Sequence[AxisComponent], AxisComponent, int],
label: Optional[Hashable] = None,
*,
permutation: Sequence[int] | None = None,
permutation: Optional[Sequence[int]] = None,
**kwargs,
):
components = tuple(_as_axis_component(cpt) for cpt in as_tuple(components))
Expand Down Expand Up @@ -745,8 +746,8 @@ class AxisTree(StrictLabelledTree, LoopIterable, ContextFree):
# fields = StrictLabelledTree.fields | {"target_paths", "index_exprs", "layout_exprs", "orig_axes", "sf", "shared_sf", "comm"}
def __init__(
self,
root: MultiAxis | None = None,
parent_to_children: dict | None = None,
root: Optional[MultiAxis] = None,
parent_to_children: Optional[Dict] = None,
*,
target_paths=None,
index_exprs=None,
Expand Down Expand Up @@ -914,11 +915,9 @@ def parse_bits(
if target_axis.id in new_visited_target_axes:
continue
new_visited_target_axes |= {target_axis.id}
new_target_path_per_cpt[
axis.id, component.label
] |= self.target_path_per_component[
target_axis.id, target_cpt.label
]
new_target_path_per_cpt[axis.id, component.label].update(
self.target_path_per_component[target_axis.id, target_cpt.label]
)

# do a replacement
orig_index_exprs = self.index_exprs_per_component[
Expand Down Expand Up @@ -953,9 +952,9 @@ def parse_bits(
partial_layout_exprs=new_partial_layout_exprs,
visited_target_axes=new_visited_target_axes,
)
new_target_path_per_cpt |= retval[0]
new_index_exprs_per_cpt |= retval[1]
new_layout_exprs_per_cpt |= retval[2]
new_target_path_per_cpt.update(retval[0])
new_index_exprs_per_cpt.update(retval[1])
new_layout_exprs_per_cpt.update(retval[2])

else:
pass
Expand Down Expand Up @@ -1061,17 +1060,16 @@ def datamap(self) -> dict[str:DistributedArray]:
for cleverdict in [self.layouts, self.orig_layout_fn]:
for layout in cleverdict.values():
for array in MultiArrayCollector()(layout):
dmap |= array.datamap
dmap.update(array.datamap)

# TODO
# for cleverdict in [self.index_exprs, self.layout_exprs]:
for cleverdict in [self.index_exprs_per_component]:
for exprs in cleverdict.values():
for expr in exprs.values():
for array in MultiArrayCollector()(expr):
dmap |= array.datamap
# breakpoint()
return dmap
dmap.update(array.datamap)
return pmap(dmap)

def _make_target_paths(self):
return tuple(self.path(ax, cpt) for ax, cpt in self.leaves)
Expand Down Expand Up @@ -1191,8 +1189,8 @@ def leaf_component(self):
return self.leaf[1]

def child(
self, parent: Axis, component: AxisComponent | ComponentLabel
) -> Axis | None:
self, parent: Axis, component: Union[AxisComponent, ComponentLabel]
) -> Optional[Axis]:
cpt_label = _as_axis_component_label(component)
return super().child(parent, cpt_label)

Expand Down Expand Up @@ -1540,7 +1538,7 @@ def _compute_layouts(
)
sublayoutss.append(sublayouts)
csubtrees.append(csubtree)
steps |= substeps
steps.update(substeps)
else:
csubtrees.append(None)
sublayoutss.append(collections.defaultdict(list))
Expand Down Expand Up @@ -1584,26 +1582,29 @@ def _compute_layouts(
ctree = None
for c in axis.components:
step = step_size(axes, axis, c)
layouts |= {
path
# | {axis.label: c.label}: AffineLayout(axis.label, c.label, step)
| {axis.label: c.label}: AxisVariable(axis.label) * step
}
layouts.update(
{
path
# | {axis.label: c.label}: AffineLayout(axis.label, c.label, step)
| {axis.label: c.label}: AxisVariable(axis.label) * step
}
)

else:
croot = CustomNode(
[(cpt.count, axis.label, cpt.label) for cpt in axis.components]
)
if strictly_all(sub is not None for sub in csubtrees):
cparent_to_children = {
croot.id: [sub.root for sub in csubtrees]
} | merge_dicts(sub.parent_to_children for sub in csubtrees)
cparent_to_children = pmap(
{croot.id: [sub.root for sub in csubtrees]}
) | merge_dicts(sub.parent_to_children for sub in csubtrees)
else:
cparent_to_children = {}
ctree = StrictLabelledTree(croot, cparent_to_children)

# layouts and steps are just propagated from below
return layouts | merge_dicts(sublayoutss), ctree, steps
layouts.update(merge_dicts(sublayoutss))
return layouts, ctree, steps

# 2. add layouts here
else:
Expand All @@ -1623,9 +1624,9 @@ def _compute_layouts(
bits.append((cpt.count, axlabel, clabel))
croot = CustomNode(bits)
if strictly_all(sub is not None for sub in csubtrees):
cparent_to_children = {
croot.id: [sub.root for sub in csubtrees]
} | merge_dicts(sub.parent_to_children for sub in csubtrees)
cparent_to_children = pmap(
{croot.id: [sub.root for sub in csubtrees]}
) | merge_dicts(sub.parent_to_children for sub in csubtrees)
else:
cparent_to_children = {}
ctree = StrictLabelledTree(croot, cparent_to_children)
Expand All @@ -1641,7 +1642,8 @@ def _compute_layouts(
ctree = None
steps = {path: _axis_size(axes, axis)}

return layouts | merge_dicts(sublayoutss), ctree, steps
layouts.update(merge_dicts(sublayoutss))
return layouts, ctree, steps

# must therefore be affine
else:
Expand All @@ -1661,7 +1663,7 @@ def _compute_layouts(
sublayouts[path | {axis.label: mycomponent.label}] = new_layout
start += _axis_component_size(axes, axis, mycomponent)

layouts |= sublayouts
layouts.update(sublayouts)
steps = {path: _axis_size(axes, axis)}
return layouts, None, steps

Expand Down Expand Up @@ -1698,11 +1700,13 @@ def _create_count_array_tree(
)
arrays[new_path] = countarray
else:
arrays |= _create_count_array_tree(
ctree,
child,
counts | current_node.counts[cidx],
new_path,
arrays.update(
_create_count_array_tree(
ctree,
child,
counts | current_node.counts[cidx],
new_path,
)
)

return arrays
Expand Down Expand Up @@ -1790,7 +1794,7 @@ def _tabulate_count_array_tree(
def _collect_at_leaves(
axes,
values,
axis: Axis | None = None,
axis: Optional[Axis] = None,
path=pmap(),
prior=0,
):
Expand All @@ -1804,7 +1808,7 @@ def _collect_at_leaves(
else:
prior_ = prior
if subaxis := axes.child(axis, cpt):
acc |= _collect_at_leaves(axes, values, subaxis, new_path, prior_)
acc.update(_collect_at_leaves(axes, values, subaxis, new_path, prior_))
else:
acc[new_path] = prior_

Expand Down Expand Up @@ -1877,7 +1881,7 @@ def _(arg: numbers.Real, path: Mapping, indices: Mapping):

def _path_and_indices_from_index_tuple(
axes, index_tuple
) -> tuple[pmap[Label, Label], pmap[Label, int]]:
) -> Tuple[pmap[Label, Label], pmap[Label, int]]:
path = pmap()
indices = pmap()
axis = axes.root
Expand Down
2 changes: 1 addition & 1 deletion pyop3/codegen/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def map_called_map(self, expr):
map_array,
pmap({rootaxis.label: just_one(rootaxis.components).label})
| pmap({inner_axis.label: inner_cpt.label}),
{rootaxis.label: inner_expr[0]} | {inner_axis.label: inner_expr[1]},
{rootaxis.label: inner_expr[0], inner_axis.label: inner_expr[1]},
self._codegen_context,
)
return jname_expr
Expand Down
11 changes: 5 additions & 6 deletions pyop3/distarray/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,12 @@ def data_wo(self):

@functools.cached_property
def datamap(self) -> dict[str:DistributedArray]:
# FIXME when we use proper index trees
# return {self.name: self} | self.axes.datamap | merge_dicts([idxs.datamap for idxs in self.indicess])
return (
{self.name: self}
| self.axes.datamap
| merge_dicts([idx.datamap for idxs in self.indicess for idx in idxs])
datamap = {self.name: self}
datamap.update(self.axes.datamap)
datamap.update(
merge_dicts([idx.datamap for idxs in self.indicess for idx in idxs])
)
return datamap

@property
def alloc_size(self):
Expand Down
8 changes: 6 additions & 2 deletions pyop3/extras/debug.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Optional, Union

from mpi4py import MPI
from petsc4py import PETSc


def print_with_rank(*args, comm: PETSc.Comm | MPI.Comm | None = None) -> None:
def print_with_rank(*args, comm: Optional[Union[PETSc.Comm, MPI.Comm]] = None) -> None:
comm = comm or PETSc.Sys.getDefaultComm()
print(f"[rank {comm.rank}] : ", *args, sep="", flush=True)


def print_if_rank(rank: int, *args, comm: PETSc.Comm | MPI.Comm | None = None) -> None:
def print_if_rank(
rank: int, *args, comm: Optional[Union[PETSc.Comm, MPI.Comm]] = None
) -> None:
comm = comm or PETSc.Sys.getDefaultComm()
if rank == comm.rank:
print(*args, flush=True)
Loading