diff --git a/pyop3/array/petsc.py b/pyop3/array/petsc.py index d1a77c2..08631a2 100644 --- a/pyop3/array/petsc.py +++ b/pyop3/array/petsc.py @@ -380,8 +380,8 @@ def _map(self, axes): loop_index = just_one(axes.outer_loops) iterset = AxisTree(loop_index.iterset.node_map) map_axes = iterset.add_subtree(axes, *iterset.leaf) - global_map = HierarchicalArray(map_axes, dtype=IntType) - global_map = global_map[loop_index.local_index] + mapping = HierarchicalArray(map_axes, dtype=IntType) + mapping = mapping[loop_index.local_index] for idxs in my_product(axes.outer_loops): # target_indices = {idx.index.id: idx.target_exprs for idx in idxs} target_indices = merge_dicts([idx.replace_map for idx in idxs]) @@ -390,20 +390,18 @@ def _map(self, axes): offset = axes.unindexed.offset( p.target_exprs, p.target_path, loop_exprs=target_indices ) - global_map.set_value( + mapping.set_value( p.source_exprs, offset, p.source_path, loop_exprs=target_indices, ) - return global_map + return mapping @cached_property @PETSc.Log.EventDecorator() def maps(self): - from pyop3.axtree.layout import my_product - # TODO: Don't think these need to be lists here. # FIXME: This will only work for singly-nested matrices if self.nested: