Skip to content

Commit

Permalink
Add domain_index_exprs to allow indexing ragged maps nicely
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Dec 15, 2023
1 parent 16a323e commit 634c548
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 91 deletions.
15 changes: 12 additions & 3 deletions pyop3/array/harray.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ def __init__(
*,
data=None,
max_value=None,
layouts=None,
target_paths=None,
index_exprs=None,
layouts=None,
domain_index_exprs=pmap(),
name=None,
prefix=None,
):
Expand Down Expand Up @@ -157,6 +158,7 @@ def __init__(

self._target_paths = target_paths or axes._default_target_paths()
self._index_exprs = index_exprs or axes._default_index_exprs()
self.domain_index_exprs = domain_index_exprs

self.layouts = layouts or axes.layouts

Expand All @@ -173,13 +175,15 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]:
)

loop_contexts = collect_loop_contexts(indices)
# breakpoint()
if not loop_contexts:
index_tree = just_one(as_index_forest(indices, axes=self.axes))
(
indexed_axes,
target_path_per_indexed_cpt,
index_exprs_per_indexed_cpt,
layout_exprs_per_indexed_cpt,
domain_index_exprs,
) = _index_axes(self.axes, index_tree, pmap())
target_paths, index_exprs, layout_exprs = _compose_bits(
self.axes,
Expand All @@ -198,6 +202,7 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]:
max_value=self.max_value,
target_paths=target_paths,
index_exprs=index_exprs,
domain_index_exprs=domain_index_exprs,
layouts=self.layouts,
name=self.name,
)
Expand All @@ -210,6 +215,7 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]:
target_path_per_indexed_cpt,
index_exprs_per_indexed_cpt,
layout_exprs_per_indexed_cpt,
domain_index_exprs,
) = _index_axes(self.axes, index_tree, loop_context)

(
Expand All @@ -230,11 +236,12 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]:
array_per_context[loop_context] = HierarchicalArray(
indexed_axes,
data=self.array,
max_value=self.max_value,
layouts=self.layouts,
target_paths=target_paths,
index_exprs=index_exprs,
layouts=self.layouts,
domain_index_exprs=domain_index_exprs,
name=self.name,
max_value=self.max_value,
)
return ContextSensitiveMultiArray(array_per_context)

Expand Down Expand Up @@ -461,6 +468,7 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray:
target_path_per_indexed_cpt,
index_exprs_per_indexed_cpt,
layout_exprs_per_indexed_cpt,
domain_index_exprs,
) = _index_axes(array.axes, index_tree, loop_context)

(
Expand All @@ -483,6 +491,7 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray:
max_value=self.max_value,
target_paths=target_paths,
index_exprs=index_exprs,
domain_index_exprs=domain_index_exprs,
layouts=self.layouts,
name=self.name,
)
Expand Down
2 changes: 2 additions & 0 deletions pyop3/array/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __getitem__(self, indices):
target_paths,
index_exprs,
layout_exprs_per_indexed_cpt,
domain_index_exprs,
) = _index_axes(self.axes, index_tree, loop_context)

indexed_axes = indexed_axes.set_up()
Expand All @@ -107,6 +108,7 @@ def __getitem__(self, indices):
data=packed,
target_paths=target_paths,
index_exprs=index_exprs,
domain_index_exprs=domain_index_exprs,
name=self.name,
)

Expand Down
7 changes: 5 additions & 2 deletions pyop3/axtree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ class AxisTree(PartialAxisTree, Indexed, ContextFreeLoopIterable):
"target_paths",
"index_exprs",
"layout_exprs",
"domain_index_exprs",
}

def __init__(
Expand All @@ -613,6 +614,7 @@ def __init__(
target_paths=None,
index_exprs=None,
layout_exprs=None,
domain_index_exprs=pmap(),
):
if some_but_not_all(
arg is None for arg in [target_paths, index_exprs, layout_exprs]
Expand All @@ -623,6 +625,7 @@ def __init__(
self._target_paths = target_paths or self._default_target_paths()
self._index_exprs = index_exprs or self._default_index_exprs()
self.layout_exprs = layout_exprs or self._default_layout_exprs()
self.domain_index_exprs = domain_index_exprs

def __getitem__(self, indices):
from pyop3.itree.tree import as_index_forest, collect_loop_contexts, index_axes
Expand Down Expand Up @@ -694,8 +697,8 @@ def layouts(self):
new_path = {}
replace_map = {}
for axis, cpt in self.path_with_nodes(*leaf).items():
new_path.update(self.target_paths[axis.id, cpt])
replace_map.update(self.layout_exprs[axis.id, cpt])
new_path.update(self.target_paths.get((axis.id, cpt), {}))
replace_map.update(self.layout_exprs.get((axis.id, cpt), {}))
new_path = freeze(new_path)

orig_layout = layouts[orig_path]
Expand Down
59 changes: 16 additions & 43 deletions pyop3/ir/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,36 +426,20 @@ def parse_loop_properly_this_time(
axis = axes.root

for component in axis.components:
# Maps "know" about indices that aren't otherwise available. Eg map(p)
# knows about p and this isn't accessible to axes.index_exprs except via
# the index expression

input_index_exprs = {}
axis_index_exprs = axes.index_exprs.get((axis.id, component.label), {})
for ax, iexpr in axis_index_exprs.items():
# TODO define as abstract property
if isinstance(iexpr, CalledMapVariable):
input_index_exprs.update(iexpr.input_index_exprs)

index_exprs_ = index_exprs | axis_index_exprs

# extra_target_replace_map = {}
# for axlabel, index_expr in index_exprs.items():
# # TODO define as abstract property
# if isinstance(index_expr, CalledMapVariable):
# replacer = JnameSubstitutor(
# outer_replace_map | target_replace_map, codegen_context
# )
# for axis_label, index_expr in index_expr.in_index_exprs.items():
# extra_target_replace_map[axis_label] = replacer(index_expr)
# extra_target_replace_map = freeze(extra_target_replace_map)

# breakpoint()
# Maps "know" about indices that aren't otherwise available. Eg map(p)
# knows about p and this isn't accessible to axes.index_exprs except via
# the index expression
domain_index_exprs = axes.domain_index_exprs.get(
(axis.id, component.label), pmap()
)

iname = codegen_context.unique_name("i")
extent_var = register_extent(
component.count,
index_exprs | input_index_exprs,
index_exprs | domain_index_exprs,
# TODO just put these in the default replace map
iname_replace_map | outer_replace_map,
codegen_context,
Expand Down Expand Up @@ -775,13 +759,6 @@ def parse_assignment_properly_this_time(
axis = axes.root
target_path = target_path | ctx_free_array.target_paths.get(None, pmap())
index_exprs = ctx_free_array.index_exprs.get(None, pmap())
# jname_extras = {}
# for axis_label, index_expr in my_index_exprs.items():
# jname_expr = JnameSubstitutor(
# iname_replace_map | jname_replace_map, codegen_context
# )(index_expr)
# jname_extras[axis_label] = jname_expr
# jname_replace_map = jname_replace_map | jname_extras

if axes.is_empty:
add_leaf_assignment(
Expand All @@ -801,9 +778,16 @@ def parse_assignment_properly_this_time(

for component in axis.components:
iname = codegen_context.unique_name("i")
# TODO also do the magic for ragged things here

# map magic
domain_index_exprs = ctx_free_array.domain_index_exprs.get(
(axis.id, component.label), pmap()
)
extent_var = register_extent(
component.count, index_exprs, iname_replace_map, codegen_context
component.count,
index_exprs | domain_index_exprs,
iname_replace_map,
codegen_context,
)
codegen_context.add_domain(iname, extent_var)

Expand All @@ -814,20 +798,9 @@ def parse_assignment_properly_this_time(

new_iname_replace_map = iname_replace_map | {axis.label: pym.var(iname)}

# I don't like that I need to do this here and also when I emit the layout
# instructions.
# Do I need the jnames on the way down? Think so for things like ragged...
index_exprs_ = index_exprs | ctx_free_array.index_exprs.get(
(axis.id, component.label), {}
)
# jname_extras = {}
# for axis_label, index_expr in my_index_exprs.items():
# jname_expr = JnameSubstitutor(
# new_iname_replace_map | jname_replace_map, codegen_context
# )(index_expr)
# jname_extras[axis_label] = jname_expr
# new_jname_replace_map = jname_replace_map | jname_extras
# new_jname_replace_map = new_iname_replace_map

with codegen_context.within_inames({iname}):
if subaxis := axes.child(axis, component):
Expand Down
Loading

0 comments on commit 634c548

Please sign in to comment.