From 634c5482e16325c0da83c0a0fbbd29c155eb6bbe Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 15 Dec 2023 14:15:17 +0000 Subject: [PATCH] Add domain_index_exprs to allow indexing ragged maps nicely --- pyop3/array/harray.py | 15 +++++-- pyop3/array/petsc.py | 2 + pyop3/axtree/tree.py | 7 ++- pyop3/ir/lower.py | 59 +++++++----------------- pyop3/itree/tree.py | 101 ++++++++++++++++++++++++------------------ 5 files changed, 93 insertions(+), 91 deletions(-) diff --git a/pyop3/array/harray.py b/pyop3/array/harray.py index 8efe58ff..e0479c69 100644 --- a/pyop3/array/harray.py +++ b/pyop3/array/harray.py @@ -113,9 +113,10 @@ def __init__( *, data=None, max_value=None, + layouts=None, target_paths=None, index_exprs=None, - layouts=None, + domain_index_exprs=pmap(), name=None, prefix=None, ): @@ -157,6 +158,7 @@ def __init__( self._target_paths = target_paths or axes._default_target_paths() self._index_exprs = index_exprs or axes._default_index_exprs() + self.domain_index_exprs = domain_index_exprs self.layouts = layouts or axes.layouts @@ -173,6 +175,7 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: ) loop_contexts = collect_loop_contexts(indices) + # breakpoint() if not loop_contexts: index_tree = just_one(as_index_forest(indices, axes=self.axes)) ( @@ -180,6 +183,7 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: target_path_per_indexed_cpt, index_exprs_per_indexed_cpt, layout_exprs_per_indexed_cpt, + domain_index_exprs, ) = _index_axes(self.axes, index_tree, pmap()) target_paths, index_exprs, layout_exprs = _compose_bits( self.axes, @@ -198,6 +202,7 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: max_value=self.max_value, target_paths=target_paths, index_exprs=index_exprs, + domain_index_exprs=domain_index_exprs, layouts=self.layouts, name=self.name, ) @@ -210,6 +215,7 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: target_path_per_indexed_cpt, index_exprs_per_indexed_cpt, layout_exprs_per_indexed_cpt, + domain_index_exprs, ) = _index_axes(self.axes, index_tree, loop_context) ( @@ -230,11 +236,12 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: array_per_context[loop_context] = HierarchicalArray( indexed_axes, data=self.array, - max_value=self.max_value, + layouts=self.layouts, target_paths=target_paths, index_exprs=index_exprs, - layouts=self.layouts, + domain_index_exprs=domain_index_exprs, name=self.name, + max_value=self.max_value, ) return ContextSensitiveMultiArray(array_per_context) @@ -461,6 +468,7 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray: target_path_per_indexed_cpt, index_exprs_per_indexed_cpt, layout_exprs_per_indexed_cpt, + domain_index_exprs, ) = _index_axes(array.axes, index_tree, loop_context) ( @@ -483,6 +491,7 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray: max_value=self.max_value, target_paths=target_paths, index_exprs=index_exprs, + domain_index_exprs=domain_index_exprs, layouts=self.layouts, name=self.name, ) diff --git a/pyop3/array/petsc.py b/pyop3/array/petsc.py index 10a59d31..878d6f29 100644 --- a/pyop3/array/petsc.py +++ b/pyop3/array/petsc.py @@ -96,6 +96,7 @@ def __getitem__(self, indices): target_paths, index_exprs, layout_exprs_per_indexed_cpt, + domain_index_exprs, ) = _index_axes(self.axes, index_tree, loop_context) indexed_axes = indexed_axes.set_up() @@ -107,6 +108,7 @@ def __getitem__(self, indices): data=packed, target_paths=target_paths, index_exprs=index_exprs, + domain_index_exprs=domain_index_exprs, name=self.name, ) diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index 4b2d1a4e..5d48ba71 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -605,6 +605,7 @@ class AxisTree(PartialAxisTree, Indexed, ContextFreeLoopIterable): "target_paths", "index_exprs", "layout_exprs", + "domain_index_exprs", } def __init__( @@ -613,6 +614,7 @@ def __init__( target_paths=None, index_exprs=None, layout_exprs=None, + domain_index_exprs=pmap(), ): if some_but_not_all( arg is None for arg in [target_paths, index_exprs, layout_exprs] @@ -623,6 +625,7 @@ def __init__( self._target_paths = target_paths or self._default_target_paths() self._index_exprs = index_exprs or self._default_index_exprs() self.layout_exprs = layout_exprs or self._default_layout_exprs() + self.domain_index_exprs = domain_index_exprs def __getitem__(self, indices): from pyop3.itree.tree import as_index_forest, collect_loop_contexts, index_axes @@ -694,8 +697,8 @@ def layouts(self): new_path = {} replace_map = {} for axis, cpt in self.path_with_nodes(*leaf).items(): - new_path.update(self.target_paths[axis.id, cpt]) - replace_map.update(self.layout_exprs[axis.id, cpt]) + new_path.update(self.target_paths.get((axis.id, cpt), {})) + replace_map.update(self.layout_exprs.get((axis.id, cpt), {})) new_path = freeze(new_path) orig_layout = layouts[orig_path] diff --git a/pyop3/ir/lower.py b/pyop3/ir/lower.py index 058fa0dd..1a5f51ad 100644 --- a/pyop3/ir/lower.py +++ b/pyop3/ir/lower.py @@ -426,36 +426,20 @@ def parse_loop_properly_this_time( axis = axes.root for component in axis.components: - # Maps "know" about indices that aren't otherwise available. Eg map(p) - # knows about p and this isn't accessible to axes.index_exprs except via - # the index expression - - input_index_exprs = {} axis_index_exprs = axes.index_exprs.get((axis.id, component.label), {}) - for ax, iexpr in axis_index_exprs.items(): - # TODO define as abstract property - if isinstance(iexpr, CalledMapVariable): - input_index_exprs.update(iexpr.input_index_exprs) - index_exprs_ = index_exprs | axis_index_exprs - # extra_target_replace_map = {} - # for axlabel, index_expr in index_exprs.items(): - # # TODO define as abstract property - # if isinstance(index_expr, CalledMapVariable): - # replacer = JnameSubstitutor( - # outer_replace_map | target_replace_map, codegen_context - # ) - # for axis_label, index_expr in index_expr.in_index_exprs.items(): - # extra_target_replace_map[axis_label] = replacer(index_expr) - # extra_target_replace_map = freeze(extra_target_replace_map) - - # breakpoint() + # Maps "know" about indices that aren't otherwise available. Eg map(p) + # knows about p and this isn't accessible to axes.index_exprs except via + # the index expression + domain_index_exprs = axes.domain_index_exprs.get( + (axis.id, component.label), pmap() + ) iname = codegen_context.unique_name("i") extent_var = register_extent( component.count, - index_exprs | input_index_exprs, + index_exprs | domain_index_exprs, # TODO just put these in the default replace map iname_replace_map | outer_replace_map, codegen_context, @@ -775,13 +759,6 @@ def parse_assignment_properly_this_time( axis = axes.root target_path = target_path | ctx_free_array.target_paths.get(None, pmap()) index_exprs = ctx_free_array.index_exprs.get(None, pmap()) - # jname_extras = {} - # for axis_label, index_expr in my_index_exprs.items(): - # jname_expr = JnameSubstitutor( - # iname_replace_map | jname_replace_map, codegen_context - # )(index_expr) - # jname_extras[axis_label] = jname_expr - # jname_replace_map = jname_replace_map | jname_extras if axes.is_empty: add_leaf_assignment( @@ -801,9 +778,16 @@ def parse_assignment_properly_this_time( for component in axis.components: iname = codegen_context.unique_name("i") - # TODO also do the magic for ragged things here + + # map magic + domain_index_exprs = ctx_free_array.domain_index_exprs.get( + (axis.id, component.label), pmap() + ) extent_var = register_extent( - component.count, index_exprs, iname_replace_map, codegen_context + component.count, + index_exprs | domain_index_exprs, + iname_replace_map, + codegen_context, ) codegen_context.add_domain(iname, extent_var) @@ -814,20 +798,9 @@ def parse_assignment_properly_this_time( new_iname_replace_map = iname_replace_map | {axis.label: pym.var(iname)} - # I don't like that I need to do this here and also when I emit the layout - # instructions. - # Do I need the jnames on the way down? Think so for things like ragged... index_exprs_ = index_exprs | ctx_free_array.index_exprs.get( (axis.id, component.label), {} ) - # jname_extras = {} - # for axis_label, index_expr in my_index_exprs.items(): - # jname_expr = JnameSubstitutor( - # new_iname_replace_map | jname_replace_map, codegen_context - # )(index_expr) - # jname_extras[axis_label] = jname_expr - # new_jname_replace_map = jname_replace_map | jname_extras - # new_jname_replace_map = new_iname_replace_map with codegen_context.within_inames({iname}): if subaxis := axes.child(axis, component): diff --git a/pyop3/itree/tree.py b/pyop3/itree/tree.py index 93c2717e..f59550ee 100644 --- a/pyop3/itree/tree.py +++ b/pyop3/itree/tree.py @@ -63,16 +63,6 @@ def map_multi_array(self, expr): index_exprs = {ax: self.rec(iexpr) for ax, iexpr in expr.index_exprs.items()} return MultiArrayVariable(expr.array, expr.target_path, index_exprs) - def map_called_map(self, expr): - raise NotImplementedError - array = expr.function.map_component.array - - # the inner_expr tells us the right mapping for the temporary, however, - # for maps that are arrays the innermost axis label does not always match - # the label used by the temporary. Therefore we need to do a swap here. - indices = {axis: self.rec(idx) for axis, idx in expr.parameters.items()} - return CalledMapVariable(expr.function, indices) - def map_loop_index(self, expr): # this is hacky, if I make this raise a KeyError then we fail in indexing return self._replace_map.get((expr.name, expr.axis), expr) @@ -337,16 +327,18 @@ def index(self) -> LoopIndex: target_paths, index_exprs, layout_exprs, + domain_index_exprs, ) = collect_shape_index_callback(self, loop_indices=context) # breakpoint() - axes = AxisTree.from_node_map(axes.parent_to_children) axes = AxisTree( axes.parent_to_children, target_paths, index_exprs, layout_exprs, + domain_index_exprs, ) + # breakpoint() context_map[context] = axes context_sensitive_axes = ContextSensitiveAxisTree(context_map) return LoopIndex(context_sensitive_axes) @@ -586,6 +578,8 @@ def _(arg: LocalLoopIndex): @collect_loop_contexts.register def _(arg: LoopIndex, local=False): + # I think that this is wrong! not enough detected + # breakpoint() if isinstance(arg.iterset, ContextSensitiveAxisTree): contexts = [] for loop_context, axis_tree in arg.iterset.context_map.items(): @@ -600,14 +594,13 @@ def _(arg: LoopIndex, local=False): target_path.update( axis_tree.target_paths.get((axis.id, cpt.label), {}) ) - extra_source_context.update(source_path) - extracontext.update(target_path) - if local: - contexts.append( - loop_context | {arg.local_index.id: pmap(extra_source_context)} - ) - else: - contexts.append(loop_context | {arg.id: pmap(extracontext)}) + + if local: + contexts.append( + loop_context | {arg.local_index.id: pmap(source_path)} + ) + else: + contexts.append(loop_context | {arg.id: pmap(target_path)}) return tuple(contexts) else: assert isinstance(arg.iterset, AxisTree) @@ -835,6 +828,7 @@ def _(loop_index: LoopIndex, *, loop_indices, **kwargs): target_path_per_component, index_exprs_per_component, layout_exprs_per_component, + pmap(), ) @@ -860,6 +854,7 @@ def _(local_index: LocalLoopIndex, *args, loop_indices, **kwargs): target_path_per_cpt, index_exprs_per_cpt, layout_exprs_per_cpt, + pmap(), ) @@ -957,6 +952,7 @@ def _(slice_: Slice, *, prev_axes, **kwargs): target_path_per_component, index_exprs_per_component, layout_exprs_per_component, + pmap(), ) @@ -967,6 +963,7 @@ def _(called_map: CalledMap, **kwargs): prior_target_path_per_cpt, prior_index_exprs_per_cpt, _, + prior_domain_index_exprs_per_cpt, ) = collect_shape_index_callback(called_map.from_index, **kwargs) if not prior_axes: @@ -977,35 +974,18 @@ def _(called_map: CalledMap, **kwargs): target_path_per_cpt, index_exprs_per_cpt, layout_exprs_per_cpt, + domain_index_exprs_per_cpt, ) = _make_leaf_axis_from_called_map( called_map, prior_target_path, prior_index_exprs ) axes = PartialAxisTree(axis) - # FIXME I think that this logic is truly awful, and it doesn't even work - # for nested ragged things! - # we need to keep track of the index expressions from the loop indices - # I think this is fundamentally the same thing that we are already doing for - # loop indices - # if None in index_exprs_per_cpt: - # index_exprs_per_cpt = dict(index_exprs_per_cpt) - # breakpoint() - # index_exprs_per_cpt[None] |= prior_index_exprs_per_cpt.get(None, pmap()) - # index_exprs_per_cpt = freeze(index_exprs_per_cpt) - # else: - # index_exprs_per_cpt |= {None: prior_index_exprs_per_cpt.get(None, pmap())} - else: axes = prior_axes target_path_per_cpt = {} - # if None in index_exprs_per_cpt: - # index_exprs_per_cpt = dict(index_exprs_per_cpt) - # index_exprs_per_cpt[None] |= prior_index_exprs_per_cpt.get(None, pmap()) - # index_exprs_per_cpt = freeze(index_exprs_per_cpt) - # else: - # index_exprs_per_cpt = {None: prior_index_exprs_per_cpt.get(None, pmap())} index_exprs_per_cpt = {} layout_exprs_per_cpt = {} + domain_index_exprs_per_cpt = {} for prior_leaf_axis, prior_leaf_cpt in prior_axes.leaves: prior_target_path = prior_target_path_per_cpt.get(None, pmap()) prior_index_exprs = prior_index_exprs_per_cpt.get(None, pmap()) @@ -1025,6 +1005,7 @@ def _(called_map: CalledMap, **kwargs): subtarget_paths, subindex_exprs, sublayout_exprs, + subdomain_index_exprs, ) = _make_leaf_axis_from_called_map( called_map, prior_target_path, prior_index_exprs ) @@ -1034,12 +1015,21 @@ def _(called_map: CalledMap, **kwargs): target_path_per_cpt.update(subtarget_paths) index_exprs_per_cpt.update(subindex_exprs) layout_exprs_per_cpt.update(sublayout_exprs) + domain_index_exprs_per_cpt.update(subdomain_index_exprs) + + # does this work? + # need to track these for nested ragged things + # breakpoint() + # index_exprs_per_cpt.update({(prior_leaf_axis.id, prior_leaf_cpt.label): prior_index_exprs}) + domain_index_exprs_per_cpt.update(prior_domain_index_exprs_per_cpt) + # layout_exprs_per_cpt.update(...) return ( axes, freeze(target_path_per_cpt), freeze(index_exprs_per_cpt), freeze(layout_exprs_per_cpt), + freeze(domain_index_exprs_per_cpt), ) @@ -1051,6 +1041,7 @@ def _make_leaf_axis_from_called_map(called_map, prior_target_path, prior_index_e target_path_per_cpt = {} index_exprs_per_cpt = {} layout_exprs_per_cpt = {} + domain_index_exprs_per_cpt = {} for map_cpt in called_map.map.connectivity[prior_target_path]: cpt = AxisComponent(map_cpt.arity, label=map_cpt.label) @@ -1093,6 +1084,7 @@ def _make_leaf_axis_from_called_map(called_map, prior_target_path, prior_index_e my_index_exprs[axlabel] = replacer(index_expr) new_inner_index_expr = my_index_exprs + # breakpoint() map_var = CalledMapVariable( map_cpt.array, my_target_path, prior_index_exprs, new_inner_index_expr ) @@ -1107,9 +1099,17 @@ def _make_leaf_axis_from_called_map(called_map, prior_target_path, prior_index_e called_map.name: pym.primitives.NaN(IntType) } + domain_index_exprs_per_cpt[axis_id, cpt.label] = prior_index_exprs + axis = Axis(components, label=called_map.name, id=axis_id) - return axis, target_path_per_cpt, index_exprs_per_cpt, layout_exprs_per_cpt + return ( + axis, + target_path_per_cpt, + index_exprs_per_cpt, + layout_exprs_per_cpt, + domain_index_exprs_per_cpt, + ) def _index_axes(axes, indices: IndexTree, loop_context): @@ -1118,6 +1118,7 @@ def _index_axes(axes, indices: IndexTree, loop_context): tpaths, index_expr_per_target, layout_expr_per_target, + domain_index_exprs, ) = _index_axes_rec( indices, current_index=indices.root, @@ -1134,7 +1135,13 @@ def _index_axes(axes, indices: IndexTree, loop_context): raise ValueError("incorrect/insufficient indices") # return the new axes plus the new index expressions per leaf - return indexed_axes, tpaths, index_expr_per_target, layout_expr_per_target + return ( + indexed_axes, + tpaths, + index_expr_per_target, + layout_expr_per_target, + domain_index_exprs, + ) def _index_axes_rec( @@ -1150,6 +1157,7 @@ def _index_axes_rec( target_path_per_cpt_per_index, index_exprs_per_cpt_per_index, layout_exprs_per_cpt_per_index, + domain_index_exprs_per_cpt_per_index, ) = tuple(map(dict, rest)) if axes_per_index: @@ -1185,9 +1193,13 @@ def _index_axes_rec( index_exprs_per_cpt_per_index.update({key: retval[2][key]}) layout_exprs_per_cpt_per_index.update({key: retval[3][key]}) - target_path_per_component = pmap(target_path_per_cpt_per_index) - index_exprs_per_component = pmap(index_exprs_per_cpt_per_index) - layout_exprs_per_component = pmap(layout_exprs_per_cpt_per_index) + assert key not in domain_index_exprs_per_cpt_per_index + domain_index_exprs_per_cpt_per_index[key] = retval[4].get(key, pmap()) + + target_path_per_component = freeze(target_path_per_cpt_per_index) + index_exprs_per_component = freeze(index_exprs_per_cpt_per_index) + layout_exprs_per_component = freeze(layout_exprs_per_cpt_per_index) + domain_index_exprs_per_cpt_per_index = freeze(domain_index_exprs_per_cpt_per_index) axes = axes_per_index for k, subax in subaxes.items(): @@ -1202,6 +1214,7 @@ def _index_axes_rec( target_path_per_component, index_exprs_per_component, layout_exprs_per_component, + domain_index_exprs_per_cpt_per_index, ) @@ -1211,6 +1224,7 @@ def index_axes(axes, index_tree): target_path_per_indexed_cpt, index_exprs_per_indexed_cpt, layout_exprs_per_indexed_cpt, + domain_index_exprs, ) = _index_axes(axes, index_tree, loop_context=index_tree.loop_context) target_paths, index_exprs, layout_exprs = _compose_bits( @@ -1228,6 +1242,7 @@ def index_axes(axes, index_tree): target_paths, index_exprs, layout_exprs, + domain_index_exprs, )