Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed May 7, 2024
1 parent f86f39c commit efa23a2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 30 deletions.
39 changes: 9 additions & 30 deletions pyop3/axtree/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from pyop3.dtypes import IntType
from pyop3.utils import (
StrictlyUniqueDict,
as_tuple,
checked_zip,
just_one,
Expand Down Expand Up @@ -85,11 +86,10 @@ def _make_layout_per_axis_component(
inner_loop_vars = frozenset()
inner_loop_vars_with_self = _collect_inner_loop_vars(axes, axis, loop_vars)

layouts = {}
layouts = StrictlyUniqueDict()

# Post-order traversal
csubtrees = []
sublayoutss = []
for cpt in axis.components:
layout_path_ = layout_path | {axis.label: cpt.label}

Expand All @@ -100,11 +100,10 @@ def _make_layout_per_axis_component(
) = _make_layout_per_axis_component(
axes, loop_vars, subaxis, layout_path_,
)
sublayoutss.append(sublayouts)
layouts.update(sublayouts)
csubtrees.append(csubtree)
else:
csubtrees.append(None)
sublayoutss.append(defaultdict(list))

"""
There are two conditions that we need to worry about:
Expand Down Expand Up @@ -143,7 +142,7 @@ def _make_layout_per_axis_component(
) or (has_halo(axes, axis) and axis != axes.root):
if has_halo(axes, axis) or not all(
has_constant_step(axes, axis, c, inner_loop_vars)
for i, c in enumerate(axis.components)
for c in axis.components
):
ctree = AxisTree(axis.copy(numbering=None))

Expand All @@ -157,19 +156,15 @@ def _make_layout_per_axis_component(
# add to shape of things
# in theory if we are ragged and permuted then we do want to include this level
ctree = None
for i, c in enumerate(axis.components):
for c in axis.components:
step = step_size(axes, axis, c)
if (axis.id, c.label) in loop_vars:
axis_var = loop_vars[axis.id, c.label][axis.label]
else:
axis_var = AxisVariable(axis.label)
layouts.update({layout_path | {axis.label: c.label}: axis_var * step})

layouts.update(merge_dicts(sublayoutss))
return (
layouts,
ctree,
)
return (layouts, ctree)

# 2. add layouts here
else:
Expand Down Expand Up @@ -215,10 +210,6 @@ def _make_layout_per_axis_component(
)

for subpath, offset_data in fulltree.items():
# offset_data must be linear so we can unroll the indices
# flat_indices = {
# ax: expr
# }
source_path = offset_data.axes.path_with_nodes(*offset_data.axes.leaf)
index_keys = [None] + [
(axis.id, cpt) for axis, cpt in source_path.items()
Expand All @@ -232,38 +223,26 @@ def _make_layout_per_axis_component(
offset_var = ArrayVar(offset_data, myindices, mytargetpath)

layouts[layout_path | subpath] = offset_var
ctree = None

layouts.update(merge_dicts(sublayoutss))
return (
layouts,
ctree,
)
return (layouts, None)

# must therefore be affine
else:
assert all(sub is None for sub in csubtrees)
layouts = {}
steps = [
step_size(axes, axis, c)
for i, c in enumerate(axis.components)
]
start = 0
for cidx, step in enumerate(steps):
mycomponent = axis.components[cidx]
sublayouts = sublayoutss[cidx].copy()

axis_var = AxisVariable(axis.label)
new_layout = axis_var * step + start

sublayouts[layout_path | {axis.label: mycomponent.label}] = new_layout
layouts[layout_path | {axis.label: mycomponent.label}] = new_layout
start += _axis_component_size(axes, axis, mycomponent)

layouts.update(sublayouts)
return (
layouts,
None,
)
return (layouts, None)



Expand Down
21 changes: 21 additions & 0 deletions pyop3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyrsistent import pmap

from pyop3.config import config
from pyop3.exceptions import Pyop3Exception


class UniqueNameGenerator(pytools.UniqueNameGenerator):
Expand Down Expand Up @@ -64,8 +65,28 @@ def __init__(self, id=None):
Identified.__init__(self, id)


class KeyAlreadyExistsException(Pyop3Exception):
pass


class StrictlyUniqueDict(dict):
"""A dictionary where overwriting entries will raise an error."""

def __setitem__(self, key, value, /) -> None:
if key in self:
raise KeyAlreadyExistsException
return super().__setitem__(key, value)

def update(self, other) -> None:
shared_keys = self.keys() & other.keys()
if len(shared_keys) > 0:
raise KeyAlreadyExistsException
super().update(other)


class OrderedSet:
"""An ordered set."""

def __init__(self):
# Python dicts are ordered so we use one to keep the ordering
# and also have O(1) access.
Expand Down

0 comments on commit efa23a2

Please sign in to comment.