From 540985c26106750e629130179fb5c3b77af07468 Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Mon, 29 Apr 2024 11:39:48 +0100 Subject: [PATCH] fix halo exchange (#30) * fix halo exchange * cleanup --------- Co-authored-by: Connor Ward --- pyop3/lang.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/pyop3/lang.py b/pyop3/lang.py index f632690..be45e76 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -346,7 +346,7 @@ def _buffer_exchanges(buffer, intent, *, touches_ghost_points): # reductions assert intent in {INC, MIN_WRITE, MIN_RW, MAX_WRITE, MAX_RW} # We don't need to update roots if performing the same reduction - # again. For example we can increment into an buffer as many times + # again. For example we can increment into a buffer as many times # as we want. The reduction only needs to be done when the # data is read. if buffer._roots_valid or intent == buffer._pending_reduction: @@ -361,6 +361,20 @@ def _buffer_exchanges(buffer, intent, *, touches_ghost_points): initializers.append(buffer._reduce_leaves_to_roots_begin) reductions.append(buffer._reduce_leaves_to_roots_end) + # set leaves to appropriate nil value + if intent == INC: + nil = 0 + elif intent in {MIN_WRITE, MIN_RW}: + nil = dtype_limits(buffer.dtype).max + else: + assert intent in {MAX_WRITE, MAX_RW} + nil = dtype_limits(buffer.dtype).min + + def _init_nil(): + buffer._data[buffer.sf.ileaf] = nil + + reductions.append(_init_nil) + # We are modifying owned values so the leaves must now be wrong buffer._leaves_valid = False @@ -370,15 +384,6 @@ def _buffer_exchanges(buffer, intent, *, touches_ghost_points): else: buffer._pending_reduction = intent - # set leaves to appropriate nil value - if intent == INC: - buffer._data[buffer.sf.ileaf] = 0 - elif intent in {MIN_WRITE, MIN_RW}: - buffer._data[buffer.sf.ileaf] = dtype_limits(buffer.dtype).max - else: - assert intent in {MAX_WRITE, MAX_RW} - buffer._data[buffer.sf.ileaf] = dtype_limits(buffer.dtype).min - return tuple(initializers), tuple(reductions), tuple(broadcasts)