
Commit

Improve parallel threading implementation
connorjward committed Nov 22, 2023
1 parent 0a37498 commit 894d1d4
Showing 1 changed file with 77 additions and 123 deletions.
200 changes: 77 additions & 123 deletions pyop3/lang.py
@@ -145,33 +145,40 @@ def __call__(self, **kwargs):
         self.index, [a for a, _ in self.all_function_arguments]
     )

-    print_with_rank("icore", icore.data)
-    print_with_rank("iroot", iroot.data)
-    print_with_rank("ileaf", ileaf.data)
-
     assert self.index.id == new_index.id

     # substitute subsets into loopexpr, should maybe be done in partition_iterset
     parallel_loop = self.copy(index=new_index)
     code = compile(parallel_loop)

     # interleave communication and computation
-    with self._updates_in_flight(0):
-        # replace the parallel axis subset with one for the specific indices here
-        extent = just_one(icore.axes.root.components).count
-        core_kwargs = merge_dicts(
-            [kwargs, {icore.name: icore, extent.name: extent}]
-        )
-        code(**core_kwargs)
+    initializers, finalizerss = self._array_updates()
+
+    for init in initializers:
+        init()
+
+    # replace the parallel axis subset with one for the specific indices here
+    extent = just_one(icore.axes.root.components).count
+    core_kwargs = merge_dicts(
+        [kwargs, {icore.name: icore, extent.name: extent}]
+    )
+    code(**core_kwargs)
+
+    # await reductions
+    for fin in finalizerss[0]:
+        fin()

     # roots
-    with self._updates_in_flight(1):
-        # replace the parallel axis subset with one for the specific indices here
-        root_extent = just_one(iroot.axes.root.components).count
-        root_kwargs = merge_dicts(
-            [kwargs, {icore.name: iroot, extent.name: root_extent}]
-        )
-        code(**root_kwargs)
+    # replace the parallel axis subset with one for the specific indices here
+    root_extent = just_one(iroot.axes.root.components).count
+    root_kwargs = merge_dicts(
+        [kwargs, {icore.name: iroot, extent.name: root_extent}]
+    )
+    code(**root_kwargs)
+
+    # await broadcasts
+    for fin in finalizerss[1]:
+        fin()

     # leaves
     leaf_extent = just_one(ileaf.axes.root.components).count
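The new `__call__` body replaces the `_updates_in_flight` context manager with an explicit schedule: start every exchange, compute the core points (which touch no shared data), await the reductions, compute the root points, await the broadcasts, then compute the leaf points. A toy, self-contained sketch of that schedule (the names `fused_loop`, `compute` etc. are mine, not pyop3's):

```python
import threading
import time

def fused_loop(compute, initializers, finalizerss):
    """Toy version of the schedule above."""
    for init in initializers:
        init()                    # kick off all exchanges
    compute("core")               # core points touch no shared data
    for fin in finalizerss[0]:
        fin()                     # await reductions
    compute("roots")              # owned values are now correct
    for fin in finalizerss[1]:
        fin()                     # await broadcasts
    compute("leaves")             # ghost values are now correct

# one array needing only a reduction, mimicked by a sleeping thread
reduce_thread = threading.Thread(target=lambda: time.sleep(0.01))
fused_loop(
    compute=lambda part: print("computing", part),
    initializers=[reduce_thread.start],
    finalizerss=([reduce_thread.join], []),
)
```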
@@ -239,63 +246,66 @@ def _distarray_args(self):
     )

 def _array_updates(self):
+    """Collect appropriate callables for updating shared values in the right order.
+
+    Returns
+    -------
+    (initializers, (finalizers0, finalizers1))
+        Collections of callables to be executed at the right times.
+
+    Notes
+    -----
+    To avoid blocking, updates are done using a different Python thread per
+    array per operation. The function returns a tuple of initializer callables
+    that trigger the operations, as well as two collections of finalizer
+    callables. These are separated into those that must be complete for
+    "root" points to be valid, and those that must be complete for "leaf"
+    points to be valid.
+
+    This does not release the GIL, but that is acceptable for this
+    application.
+    """
     # NOTE: It is safe to include reductions in the finalizers because
     # core entities (in the iterset) are defined as being those that do
     # not overlap with any points in the star forest.
+    if config["thread_model"] in {"SINGLE", "SERIALIZED"}:
+        # SINGLE means other threads cannot be used and SERIALIZED means
+        # that ghost exchanges can only be done on a single thread
+        raise NotImplementedError

-    # NOTE: The following is now slightly out-of-date. We now distinguish
-    # *per array* and also split each generation into starts and finalizers.
-
-    # TODO update this comment to account for different threading models
-
-    # Since we sometimes have to do a reduce and then a broadcast the messages
-    # are organised into generations with each generation being executed in
-    # turn.
-    # As an example consider needing to update 2 arrays, one with a
-    # reduce-then-broadcast and the other with a reduction. This will produce
-    # the following collection of messages (the final generation is always -1):
-    #
-    # [generation 0] : [array1.reduce_begin, array2.reduce_begin]
-    # [generation 1] : [array1.reduce_end, array1.broadcast_begin]
-    # [generation -1] : [array1.broadcast_end, array2.reduce_end]
-    #
-    # To avoid blocking the operations are executed on a separate thread. Once
-    # the thread terminates, all messages will have been sent and execution
-    # may continue.
-
-    # maps array to messages split by generation
-    messages = defaultdict(partial(defaultdict, list))
+    initializers = []
+    finalizerss = ([], [])
     for array, intent, touches_ghost_points in self._distarray_args:
         if intent in {READ, RW}:
             if touches_ghost_points:
                 if not array._roots_valid:
-                    # 2-tuple of inits (bad name) and finalizers
-                    messages[array][0] = (
-                        # init
-                        [array._reduce_leaves_to_roots_begin],
-                        # finalizer
-                        [
-                            array._reduce_leaves_to_roots_end,
-                            array._broadcast_roots_to_leaves_begin,
-                        ],
-                    )
-                    messages[array][1] = (
-                        [],
-                        [array._broadcast_roots_to_leaves_end],
-                    )
+                    bcast_thread = threading.Thread(
+                        target=array._broadcast_roots_to_leaves,
+                    )
+                    # As soon as the reduction is done we start the broadcast.
+                    # This is done on a different thread so we can use the
+                    # termination of the original thread as a finalizer. The
+                    # callables are bound eagerly (as a default argument) so
+                    # that later loop iterations cannot rebind them.
+                    reduce_thread = threading.Thread(
+                        target=lambda fns=(
+                            array._reduce_leaves_to_roots,
+                            bcast_thread.start,
+                        ): [fn() for fn in fns],
+                    )
+                    initializers.append(reduce_thread.start)
+                    finalizerss[0].append(reduce_thread.join)
+                    finalizerss[1].append(bcast_thread.join)
                 else:
-                    messages[array][0] = (
-                        [array._broadcast_roots_to_leaves_begin],
-                        [],
-                    )
-                    messages[array][1] = (
-                        [],
-                        [array._broadcast_roots_to_leaves_end],
-                    )
+                    thread = threading.Thread(
+                        target=array._broadcast_roots_to_leaves,
+                    )
+                    initializers.append(thread.start)
+                    finalizerss[1].append(thread.join)
             else:
                 if not array._roots_valid:
-                    messages[array][0] = ([array._reduce_leaves_to_roots_begin], [])
-                    messages[array][1] = ([], [array._reduce_leaves_to_roots_end])
+                    thread = threading.Thread(
+                        target=array._reduce_leaves_to_roots,
+                    )
+                    initializers.append(thread.start)
+                    finalizerss[0].append(thread.join)

         elif intent == WRITE:
             # Assumes that all points are written to (i.e. not a subset). If
@@ -318,8 +328,11 @@
             # explained in the documentation.
             if intent in {INC, MIN_RW, MAX_RW}:
                 assert array._pending_reduction is not None
-                messages[array][0] = ([array._reduce_leaves_to_roots_begin], [])
-                messages[array][1] = ([], [array._reduce_leaves_to_roots_end])
+                thread = threading.Thread(
+                    target=array._reduce_leaves_to_roots,
+                )
+                initializers.append(thread.start)
+                finalizerss[0].append(thread.join)

             # We are modifying owned values so the leaves must now be wrong
             array._leaves_valid = False
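The `intent` branches above decide which star-forest exchanges each argument needs around the local computation. A rough paraphrase as a standalone decision function (the `Intent` enum and `required_exchanges` are mine, not pyop3 API; the real descriptors are `READ`, `WRITE`, `RW`, `INC`, `MIN_RW` and `MAX_RW`):

```python
from enum import Enum

# hypothetical stand-ins for pyop3's access descriptors
Intent = Enum("Intent", ["READ", "WRITE", "RW", "INC", "MIN_RW", "MAX_RW"])

def required_exchanges(intent, touches_ghost_points, roots_valid):
    """Return the exchanges an argument needs, in order."""
    exchanges = []
    if intent in {Intent.READ, Intent.RW}:
        if touches_ghost_points:
            # Ghost values are read, so bring them up to date: reduce
            # outstanding contributions (if any), then broadcast.
            if not roots_valid:
                exchanges.append("reduce_leaves_to_roots")
            exchanges.append("broadcast_roots_to_leaves")
        elif not roots_valid:
            # Only owned values are read; a reduction suffices.
            exchanges.append("reduce_leaves_to_roots")
    elif intent is Intent.WRITE:
        # Every point is overwritten, so stale ghosts do not matter.
        pass
    elif intent in {Intent.INC, Intent.MIN_RW, Intent.MAX_RW}:
        # A pending reduction must be flushed before accumulating further.
        exchanges.append("reduce_leaves_to_roots")
    return exchanges

assert required_exchanges(Intent.READ, True, False) == [
    "reduce_leaves_to_roots",
    "broadcast_roots_to_leaves",
]
assert required_exchanges(Intent.WRITE, True, False) == []
```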
@@ -343,66 +356,7 @@
         else:
             raise AssertionError

-    return messages
-
-@contextlib.contextmanager
-def _updates_in_flight(self, generation):
-    """Context manager for interleaving computation and communication."""
-    sendrecvs = self._array_updates()
-
-    if config["thread_model"] in {"SINGLE", "SERIALIZED"}:
-        raise NotImplementedError("Untested")
-        # now out-of-date
-        # cannot do a thread per-array, compress the sendrecvs
-        new_sendrecvs = defaultdict(list)
-        for sendrecvs_per_array in sendrecvs.values():
-            for generation, messages in sendrecvs_per_array.items():
-                new_sendrecvs[generation].extend(messages)
-        sendrecvs = new_sendrecvs
-
-    # TODO make an enum
-    if config["thread_model"] == "SINGLE":
-        raise NotImplementedError("Untested")
-        # multithreading is not supported, do all of the sendrecvs apart
-        # from the final generation eagerly
-        ngenerations = len(sendrecvs) - 1
-        for gen in range(ngenerations):
-            for sendrecv in sendrecvs[gen]:
-                sendrecv()
-        finalizers = messages[-1]
-    elif config["thread_model"] == "SERIALIZED":
-        raise NotImplementedError("Untested")
-        # ghost exchanges can only be done on a single thread
-        thread = threading.Thread(
-            target=self.__class__._sendrecv, args=(sendrecvs,)
-        )
-        thread.start()
-        finalizers = [thread.join]
-    elif config["thread_model"] == "MULTIPLE":
-        # different thread per array
-        finalizers = []
-        for sendrecvs_per_array in sendrecvs.values():
-            inits, finalizers_ = sendrecvs_per_array[generation]
-            thread = threading.Thread(
-                target=self.__class__._sendrecv, args=(inits, finalizers_)
-            )
-            thread.start()
-            finalizers.append(thread.join)
-    else:
-        raise AssertionError
-
-    yield
-
-    # wait for everything
-    for f in finalizers:
-        f()
-
-@staticmethod
-def _sendrecv(inits, finalizers):
-    for msg in inits:
-        msg()
-    for fin in finalizers:
-        fin()
+    return initializers, finalizerss


# TODO singledispatch
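The key device in the new `_array_updates` is the chained pair of threads for reduce-then-broadcast arrays: the broadcast thread is started at the end of the reduction thread, so each `join` doubles as a "generation" finalizer without any generation bookkeeping. A self-contained sketch of that pattern (hypothetical names, same eager-binding trick as in the diff above):

```python
import threading

def chained_exchange(reduce_fn, broadcast_fn):
    """The broadcast runs on its own thread that is started by the
    reduction thread, so joining the broadcast thread is a valid
    "leaves are valid" finalizer."""
    bcast_thread = threading.Thread(target=broadcast_fn)

    def reduce_then_start_broadcast(
        # bind eagerly so later rebinding of reduce_fn cannot interfere
        reduce=reduce_fn, start_broadcast=bcast_thread.start
    ):
        reduce()
        start_broadcast()

    reduce_thread = threading.Thread(target=reduce_then_start_broadcast)
    # (initializer, reduction finalizer, broadcast finalizer)
    return reduce_thread.start, reduce_thread.join, bcast_thread.join

start, await_reduction, await_broadcast = chained_exchange(
    reduce_fn=lambda: print("reducing leaves to roots"),
    broadcast_fn=lambda: print("broadcasting roots to leaves"),
)
start()
await_reduction()   # roots now valid
await_broadcast()   # leaves now valid
```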

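The `NotImplementedError` for the "SINGLE" and "SERIALIZED" thread models corresponds to MPI's initialization thread levels: per-array communication threads need `MPI_THREAD_MULTIPLE`, since with SERIALIZED only one thread may call MPI at a time and with SINGLE only the main thread may do so. Assuming the MPI layer is reachable through mpi4py (as is usual in the PETSc ecosystem pyop3 belongs to), the guard could be expressed directly; a sketch, not pyop3's actual `config` mechanism:

```python
from mpi4py import MPI

# Query the thread level the MPI implementation actually provided.
provided = MPI.Query_thread()
if provided < MPI.THREAD_MULTIPLE:
    raise NotImplementedError(
        "thread model does not permit concurrent ghost exchanges"
    )
```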