Skip to content

Commit

Permalink
Basic tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Sep 20, 2023
1 parent 8345422 commit b03cc61
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 274 deletions.
261 changes: 85 additions & 176 deletions pyop3/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,16 +839,16 @@ def __getitem__(self, indices):
# 3. layout exprs
layout_exprs_per_cpt = {}
else:
raise NotImplementedError("TODO ASAP")
# TODO make this a tree traversal combined with the empty case
(
new_target_path_per_leaf,
new_index_exprs_per_leaf,
new_layout_exprs_per_leaf,
target_path_per_cpt,
index_exprs_per_cpt,
layout_exprs_per_cpt,
) = self.parse_bits(
indexed_axes,
target_path_per_leaf,
index_exprs_per_leaf,
layout_exprs_per_leaf,
target_path_per_indexed_cpt,
index_exprs_per_indexed_cpt,
layout_exprs_per_indexed_cpt,
)

# breakpoint()
Expand All @@ -865,195 +865,104 @@ def __call__(self, *args):

def parse_bits(
self,
axes,
target_path_per_leaf,
index_exprs_per_leaf,
layout_exprs_per_leaf,
indexed_axes,
target_path_per_indexed_component,
index_exprs_per_indexed_component,
layout_exprs_per_indexed_component,
*,
axis=None,
source_path=pmap(),
partial_target_path=pmap(),
partial_index_exprs=pmap(),
partial_layout_exprs=pmap(),
):
from pyop3.distarray.multiarray import IndexExpressionReplacer

assert not axes.is_empty, "handled outside"
# TODO should handle here
assert not indexed_axes.is_empty, "handled outside"

axis = axis or axes.root
new_target_path_per_cpt = {}
new_index_exprs_per_cpt = {}
new_layout_exprs_per_cpt = {}
if axis is None:
partial_target_path |= target_path_per_indexed_component.get(None, {})
partial_index_exprs |= index_exprs_per_indexed_component.get(None, {})
partial_layout_exprs |= layout_exprs_per_indexed_component.get(None, {})

new_target_path_per_leaf = {}
new_index_exprs_per_leaf = {}
new_layout_exprs_per_leaf = {}
axis = axis or indexed_axes.root
for component in axis.components:
new_source_path = source_path | {axis.label: component.label}

if subaxis := axes.child(axis, component):
retval = self.parse_bits(
axes,
target_path_per_leaf,
index_exprs_per_leaf,
layout_exprs_per_leaf,
axis=subaxis,
source_path=new_source_path,
)
new_partial_target_path = (
partial_target_path
| target_path_per_indexed_component.get((axis.id, component.label), {})
)

new_target_path_per_leaf |= retval[0]
new_index_exprs_per_leaf |= retval[1]
new_layout_exprs_per_leaf |= retval[2]
new_partial_index_exprs = (
partial_index_exprs
| index_exprs_per_indexed_component.get((axis.id, component.label), {})
)
new_partial_layout_exprs = (
partial_layout_exprs
| layout_exprs_per_indexed_component.get((axis.id, component.label), {})
)

else:
# NOTE: This is NOT the final target path. This only targets
# the thing *before* the indexing took place. Subsequent indexing
# requires composition.
path = target_path_per_leaf[new_source_path]

# 1. target path
new_target_path = self.target_path_per_leaf[path]
new_target_path_per_leaf[new_source_path] = new_target_path

# 2. index exprs
index_expr_replace_map = index_exprs_per_leaf[new_source_path]
new_index_exprs = {}
for axis_label, index_expr in self.index_exprs_per_leaf[path].items():
new_index_expr = IndexExpressionReplacer(index_expr_replace_map)(
# if target_path is "complete" then do stuff, else pass responsibility to next func down
try:
target_node_path = self.path_with_nodes(
*self._node_from_path(new_partial_target_path), and_components=True
)
except:
raise NotImplementedError("TODO")

new_target_path_per_cpt[axis.id, component.label] = {}
new_index_exprs_per_cpt[axis.id, component.label] = {}
new_layout_exprs_per_cpt[axis.id, component.label] = {}
for target_axis, target_cpt in target_node_path.items():
new_target_path_per_cpt[
axis.id, component.label
] |= self.target_path_per_component[target_axis.id, target_cpt.label]

# do a replacement
orig_index_exprs = self.index_exprs_per_component[
target_axis.id, target_cpt.label
]
for axis_label, index_expr in orig_index_exprs.items():
new_index_expr = IndexExpressionReplacer(new_partial_index_exprs)(
index_expr
)
new_index_exprs[axis_label] = new_index_expr
new_index_exprs_per_leaf[new_source_path] = new_index_exprs
new_index_exprs_per_cpt[axis.id, component.label][
axis_label
] = new_index_expr

# 3. layout exprs
new_layout_exprs_per_leaf[new_source_path] = pmap() # TODO
return (
new_target_path_per_leaf,
new_index_exprs_per_leaf,
new_layout_exprs_per_leaf,
)
# TODO
new_layout_exprs_per_cpt[axis.id, component.label][
target_axis.label
] = NotImplemented

def parse_target_paths(
self,
indexed_axes,
indexed_axis,
target_path_per_axis_tuple,
targetpath,
minipath=(),
mypath=(),
):
assert False, "old code"
new_target_path_per_axis_tuple = {}

for cpt in indexed_axis.components:
new_minipath = minipath + ((indexed_axis, cpt),)
newmypath = mypath
found = False
if new_minipath in target_path_per_axis_tuple:
found = True
pathextras = target_path_per_axis_tuple[new_minipath]
new_targetpath = targetpath | pathextras
node, ncpt = self._node_from_path(new_targetpath)
newmypath += ((node, ncpt),)
else:
new_targetpath = targetpath
# NOTE: This is NOT the final target path. This only targets
# the thing *before* the indexing took place. Subsequent indexing
# requires composition.

if newmypath in self.target_paths:
new_target_path_per_axis_tuple[new_minipath] = self.target_paths[
newmypath
]
newmypath = ()

if found:
new_minipath = ()

if subaxis := indexed_axes.child(indexed_axis, cpt):
retval = self.parse_target_paths(
if subaxis := indexed_axes.child(axis, component):
retval = self.parse_bits(
indexed_axes,
subaxis,
target_path_per_axis_tuple,
new_targetpath,
new_minipath,
newmypath,
)
new_target_path_per_axis_tuple |= retval

else:
assert not new_minipath
assert not newmypath

# should have handled above
pass
return pmap(new_target_path_per_axis_tuple)

def parse_index_exprs(
self,
indexed_axes,
indexed_axis,
target_path_per_axis_tuple,
index_expr_replace_map,
target_path=pmap(),
minipath=(),
first=True,
):
assert False, "old code"
from pyop3.distarray.multiarray import IndexExpressionReplacer

new_index_expr_per_target = {}

###
if first and () in target_path_per_axis_tuple:
target_path = target_path | target_path_per_axis_tuple[()]

leaf = self.orig_axes._node_from_path(target_path)
target_path_with_axes = self.orig_axes.path_with_nodes(*leaf, ordered=True)
assert len(target_path_with_axes) == len(target_path)

for target_axis, target_component in target_path_with_axes:
index_expr = self.index_exprs[target_axis.id, target_component]
# breakpoint()
new_index_expr = IndexExpressionReplacer(index_expr_replace_map[()])(
index_expr
target_path_per_indexed_component,
index_exprs_per_indexed_component,
layout_exprs_per_indexed_component,
axis=subaxis,
partial_target_path=new_partial_target_path,
partial_index_exprs=new_partial_index_exprs,
partial_layout_exprs=new_partial_layout_exprs,
)
new_index_expr_per_target[()] = new_index_expr
new_target_path_per_cpt |= retval[0]
new_index_exprs_per_cpt |= retval[1]
new_layout_exprs_per_cpt |= retval[2]

minipath = ()

###

for cpt in indexed_axis.components:
# breakpoint()
new_target_path = target_path
new_minipath = minipath + ((indexed_axis, cpt),)
if new_minipath in target_path_per_axis_tuple:
# not needed?
new_target_path = target_path | target_path_per_axis_tuple[new_minipath]

leaf = self._node_from_path(new_target_path)
target_path_with_axes = self.path_with_nodes(*leaf, ordered=True)

# FIXME should have loop here so I don't miss anything

target_axis, target_component = target_path_with_axes[-1]

index_expr = self.index_exprs[target_axis.id, target_component]
# breakpoint()
new_index_expr = IndexExpressionReplacer(
index_expr_replace_map[new_minipath]
)(index_expr)
new_index_expr_per_target[indexed_axis.id, cpt.label] = new_index_expr

new_minipath = ()

if subaxis := indexed_axes.child(indexed_axis, cpt):
subresult = self.parse_index_exprs(
indexed_axes,
subaxis,
target_path_per_axis_tuple,
index_expr_replace_map,
new_target_path,
new_minipath,
first=False,
)
new_index_expr_per_target |= subresult
else:
assert not new_minipath
pass
return pmap(new_index_expr_per_target)
return (
new_target_path_per_cpt,
new_index_exprs_per_cpt,
new_layout_exprs_per_cpt,
)

@property
def axis_trees(self):
Expand Down
Loading

0 comments on commit b03cc61

Please sign in to comment.