From 15e870da71c09bd5e88ee96f23e7b4a641602a9b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 26 Apr 2024 21:19:10 +0100 Subject: [PATCH] Fix ghost points (#26) --- pyop3/axtree/tree.py | 10 +++++----- pyop3/itree/tree.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index 58a1cb1..05b451e 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -436,8 +436,8 @@ def ghost_count_per_component(self): def owned(self): return self._tree.owned.root - def index(self): - return self._tree.index() + def index(self, *, include_ghost_points=False): + return self._tree.index(include_ghost_points=include_ghost_points) def iter(self, *, include_ghost_points=False): return self._tree.iter(include_ghost_points=include_ghost_points) @@ -709,10 +709,10 @@ def outer_loops(self): def subst_layouts(self): pass - def index(self, ghost=False): + def index(self, *, include_ghost_points=False): from pyop3.itree.tree import ContextFreeLoopIndex, LoopIndex - iterset = self if ghost else self.owned + iterset = self if include_ghost_points else self.owned # If the iterset is linear (single-component for every axis) then we # can consider the loop to be "context-free". if len(iterset.leaves) == 1: @@ -1294,7 +1294,7 @@ def tabulated_offsets(self): rmap_axes = iterset.add_subtree(self, *iterset.leaf) rmap = HierarchicalArray(rmap_axes, dtype=IntType) rmap = rmap[loop_index.local_index] - for idx in loop_index.iter(): + for idx in loop_index.iter(include_ghost_points=True): target_indices = idx.replace_map # for p in self.iter(idxs): for p in self.iter([idx], include_ghost_points=True): # seems to fix thing diff --git a/pyop3/itree/tree.py b/pyop3/itree/tree.py index f4af31c..97653e4 100644 --- a/pyop3/itree/tree.py +++ b/pyop3/itree/tree.py @@ -429,10 +429,11 @@ def layout_exprs(self): def datamap(self): return self.iterset.datamap - def iter(self, stuff=pmap()): + def iter(self, stuff=pmap(), *, include_ghost_points=False): + iterset = self.iterset if include_ghost_points else self.iterset.owned return iter_axis_tree( self, - self.iterset, + iterset, self.iterset.target_paths, self.iterset.index_exprs, stuff,