diff --git a/pyop3/array/petsc.py b/pyop3/array/petsc.py index 641000d..d1a77c2 100644 --- a/pyop3/array/petsc.py +++ b/pyop3/array/petsc.py @@ -318,23 +318,31 @@ def _iter_nest_labels( else: yield (rlabel_acc_, clabel_acc_) - # @cached_property - def _block_axes(self, axes, shape, blocked=False): + def _block_axes(self, axes, shape, nested_index=None): block_axes, target_paths, index_exprs = self._collect_block_axes(axes, shape) - block_axes_unindexed, _, _ = self._collect_block_axes(axes.unindexed, shape) + if self.nested: + block_axes_unindexed, _, _ = self._collect_block_axes( + axes.unindexed[nested_index], shape + ) + else: + block_axes_unindexed, _, _ = self._collect_block_axes( + axes.unindexed, shape + ) return IndexedAxisTree( block_axes.node_map, block_axes_unindexed, target_paths=target_paths, index_exprs=index_exprs, outer_loops=axes.outer_loops, layout_exprs=None) - - # @cached_property - def _nest_axes(self, axes, index, blocked=False): + + def _nest_axes(self, axes, index): + if axes.size > 1: + axes = self._block_axes(axes, axes.size, nested_index=index) axes_unindexed = AxisTree(axes.unindexed[index].node_map) target_paths = thaw(axes.target_paths) to_kill = target_paths.pop(None) - for key, target_paths_per_axis in target_paths.items(): - target_paths[key].pop(to_kill) + for key, _ in target_paths.items(): + for k, _ in to_kill.items(): + target_paths[key].pop(k) index_exprs = dict(axes.index_exprs) index_exprs.pop(None) return IndexedAxisTree( @@ -344,7 +352,6 @@ def _nest_axes(self, axes, index, blocked=False): outer_loops=axes.outer_loops, layout_exprs=None) - def _collect_block_axes(self, axes, shape, axis=None): from pyop3.axtree.layout import _axis_size target_paths = {} @@ -368,6 +375,30 @@ def _collect_block_axes(self, axes, shape, axis=None): index_exprs.update(subindex_exprs) return axis_tree, target_paths, index_exprs + def _map(self, axes): + from pyop3.axtree.layout import my_product + loop_index = just_one(axes.outer_loops) + iterset = AxisTree(loop_index.iterset.node_map) + map_axes = iterset.add_subtree(axes, *iterset.leaf) + global_map = HierarchicalArray(map_axes, dtype=IntType) + global_map = global_map[loop_index.local_index] + for idxs in my_product(axes.outer_loops): + # target_indices = {idx.index.id: idx.target_exprs for idx in idxs} + target_indices = merge_dicts([idx.replace_map for idx in idxs]) + + for p in axes.iter(idxs, include_ghost_points=True): + offset = axes.unindexed.offset( + p.target_exprs, p.target_path, loop_exprs=target_indices + ) + global_map.set_value( + p.source_exprs, + offset, + p.source_path, + loop_exprs=target_indices, + ) + return global_map + + @cached_property @PETSc.Log.EventDecorator() def maps(self): @@ -376,153 +407,25 @@ def maps(self): # TODO: Don't think these need to be lists here. # FIXME: This will only work for singly-nested matrices if self.nested: - #for (row_index, col_index), submat_type in self.mat_type.items(): - - row_index = 0 # for now only! - #col_index = ??? - #submat_type = self.mat_type[row_index, col_index] + row_index = 1 # for now only! + col_index = 1 submat_type = "aij" - - raxes = self._nest_axes(self.raxes, row_index) - - # if self.raxes.unindexed[row_index].size // 2 == 1: - # raxes = unindexed_row_axes - # else: - # raxes = self._block_axes(unindexed_row_axes, self.raxes.unindexed[row_index].size // 2) - - loop_index = just_one(raxes.outer_loops) - iterset = AxisTree(loop_index.iterset.node_map) - rmap_axes = iterset.add_subtree(raxes, *iterset.leaf) - rmap = HierarchicalArray(rmap_axes, dtype=IntType) - rmap = rmap[loop_index.local_index] - for idxs in my_product(raxes.outer_loops): - # target_indices = {idx.index.id: idx.target_exprs for idx in idxs} - target_indices = merge_dicts([idx.replace_map for idx in idxs]) - - for p in raxes.iter(idxs, include_ghost_points=True): - offset = raxes.unindexed.offset( - p.target_exprs, p.target_path, loop_exprs=target_indices - ) - rmap.set_value( - p.source_exprs, - offset, - p.source_path, - loop_exprs=target_indices, - ) - - # # # rfield_axis = self.raxes.unindexed.root - # # cfield_axis = self.caxes.unindexed.root - - # if strictly_all(c.unit for c in rfield_axis.components): - # # This weird trick is because the right target path for the field - # # is actually tied to the root of the axis tree, rather than None. - # # This seems like a limitation of the _compose_bits function. - # rfield = single_valued( - # cpt - # for mycpt in self.raxes.root.components - # for ax, cpt in self.raxes.target_paths[ - # self.raxes.root.id, mycpt.label - # ].items() - # if ax == rfield_axis.label - # ) - # orig_raxes = AxisTree(self.raxes.unindexed[rfield].node_map) - # orig_raxess = [orig_raxes] - # dropped_rkeys = {rfield_axis.label} - # else: - # orig_raxess = [self.raxes.unindexed] - # dropped_rkeys = frozenset() - - # if strictly_all(c.unit for c in cfield_axis.components): - # cfield = single_valued( - # cpt - # for mycpt in self.caxes.root.components - # for ax, cpt in self.caxes.target_paths[ - # self.caxes.root.id, mycpt.label - # ].items() - # if ax == cfield_axis.label - # ) - # orig_caxes = AxisTree(self.caxes.unindexed[cfield].node_map) - # orig_caxess = [orig_caxes] - # dropped_ckeys = {cfield_axis.label} - # else: - # orig_caxess = [self.caxes.unindexed] - # dropped_ckeys = set() - elif self.mat_type == "baij": - raxes = self._block_axes(self.raxes, self.block_shape) - caxes = self._block_axes(self.caxes, self.block_shape) - elif self.mat_type == "aij": - raxes = self.raxes - caxes = self.caxes + caxes = self._nest_axes(self.caxes, col_index) + rmap = self._map(raxes) + cmap = self._map(caxes) + elif self.mat_type == "baij" or self.mat_type == "aij": + if self.mat_type == "baij": + raxes = self._block_axes(self.raxes, self.block_shape) + caxes = self._block_axes(self.caxes, self.block_shape) + else: + raxes = self.raxes + caxes = self.caxes + rmap = self._map(raxes) + cmap = self._map(caxes) else: raise NotImplementedError - orig_raxess = [raxes.unindexed] - orig_caxess = [caxes.unindexed] - dropped_rkeys = set() - dropped_ckeys = set() - - # TODO: are dropped_rkeys and dropped_ckeys still needed? - loop_index = just_one(raxes.outer_loops) - iterset = AxisTree(loop_index.iterset.node_map) - - rmap_axes = iterset.add_subtree(raxes, *iterset.leaf) - rmap = HierarchicalArray(rmap_axes, dtype=IntType) - rmap = rmap[loop_index.local_index] - - loop_index = just_one(caxes.outer_loops) - iterset = AxisTree(loop_index.iterset.node_map) - - cmap_axes = iterset.add_subtree(caxes, *iterset.leaf) - cmap = HierarchicalArray(cmap_axes, dtype=IntType) - cmap = cmap[loop_index.local_index] - - # TODO: Make the code below go into a separate function distinct - # from mat_type logic. Then can also share code for rmap and cmap. - for orig_raxes in orig_raxess: - for idxs in my_product(raxes.outer_loops): - # target_indices = {idx.index.id: idx.target_exprs for idx in idxs} - target_indices = merge_dicts([idx.replace_map for idx in idxs]) - - for p in raxes.iter(idxs, include_ghost_points=True): # seems to fix things - target_path = p.target_path - target_exprs = p.target_exprs - for key in dropped_rkeys: - target_path = target_path.remove(key) - target_exprs = target_exprs.remove(key) - - offset = orig_raxes.offset( - target_exprs, target_path, loop_exprs=target_indices - ) - rmap.set_value( - p.source_exprs, - offset, - p.source_path, - loop_exprs=target_indices, - ) - - for orig_caxes in orig_caxess: - for idxs in my_product(self.caxes.outer_loops): - # target_indices = {idx.index.id: idx.target_exprs for idx in idxs} - target_indices = merge_dicts([idx.replace_map for idx in idxs]) - - # for p in self.caxes.iter(idxs): - for p in caxes.iter(idxs, include_ghost_points=True): # seems to fix things - target_path = p.target_path - target_exprs = p.target_exprs - for key in dropped_ckeys: - target_path = target_path.remove(key) - target_exprs = target_exprs.remove(key) - - offset = orig_caxes.offset( - target_exprs, target_path, loop_exprs=target_indices - ) - cmap.set_value( - p.source_exprs, - offset, - p.source_path, - loop_exprs=target_indices, - ) return (rmap, cmap) @property